use std::collections::HashSet;
use anyhow::Context;
use async_trait::async_trait;
use mas_data_model::Device;
use mas_matrix::ProvisionRequest;
use mas_storage::{
compat::CompatSessionFilter,
oauth2::OAuth2SessionFilter,
queue::{
DeleteDeviceJob, ProvisionDeviceJob, ProvisionUserJob, QueueJobRepositoryExt as _,
SyncDevicesJob,
},
user::{UserEmailRepository, UserRepository},
Pagination, RepositoryAccess,
};
use tracing::info;
use crate::{
new_queue::{JobContext, JobError, RunnableJob},
State,
};
#[async_trait]
impl RunnableJob for ProvisionUserJob {
    /// Creates or updates the user on the Matrix homeserver.
    ///
    /// The provisioning request contains the user's MXID, their `sub`, every
    /// confirmed email address and, when the job carries one, a display name
    /// to set. After the homeserver call succeeds, a [`SyncDevicesJob`] is
    /// scheduled for the user and the repository is saved.
    ///
    /// # Errors
    ///
    /// Fails permanently ([`JobError::fail`]) when the user does not exist;
    /// storage and homeserver errors are returned as retryable
    /// ([`JobError::retry`]).
    #[tracing::instrument(
        name = "job.provision_user",
        fields(user.id = %self.user_id()),
        skip_all,
        err,
    )]
    async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
        let matrix = state.matrix_connection();
        let mut repo = state.repository().await.map_err(JobError::retry)?;
        let mut rng = state.rng();
        let clock = state.clock();

        let user = repo
            .user()
            .lookup(self.user_id())
            .await
            .map_err(JobError::retry)?
            .context("User not found")
            .map_err(JobError::fail)?;

        let mxid = matrix.mxid(&user.username);

        // Only confirmed addresses are sent to the homeserver.
        let emails = repo
            .user_email()
            .all(&user)
            .await
            .map_err(JobError::retry)?
            .into_iter()
            .filter(|email| email.confirmed_at.is_some())
            .map(|email| email.email)
            .collect();

        let mut request = ProvisionRequest::new(mxid.clone(), user.sub.clone()).set_emails(emails);
        if let Some(display_name) = self.display_name_to_set() {
            request = request.set_displayname(display_name.to_owned());
        }

        // `provision_user` reports whether the user was newly created.
        let created = matrix
            .provision_user(&request)
            .await
            .map_err(JobError::retry)?;

        if created {
            info!(%user.id, %mxid, "User created");
        } else {
            info!(%user.id, %mxid, "User updated");
        }

        // Follow up with a device sync so the homeserver's device list
        // matches the user's active sessions.
        let sync_device_job = SyncDevicesJob::new(&user);
        repo.queue_job()
            .schedule_job(&mut rng, &clock, sync_device_job)
            .await
            .map_err(JobError::retry)?;

        repo.save().await.map_err(JobError::retry)?;

        Ok(())
    }
}
#[async_trait]
impl RunnableJob for ProvisionDeviceJob {
    /// Handles a request to provision a single device for a user.
    ///
    /// This job does not provision the device directly. It schedules a
    /// [`SyncDevicesJob`] for the owning user, which sends the user's full set
    /// of active devices to the homeserver. It then saves the repository so
    /// the scheduled job is persisted.
    ///
    /// # Errors
    ///
    /// Fails permanently ([`JobError::fail`]) when the user does not exist;
    /// storage errors are returned as retryable ([`JobError::retry`]).
    #[tracing::instrument(
        name = "job.provision_device",
        fields(
            user.id = %self.user_id(),
            device.id = %self.device_id(),
        ),
        skip_all,
        err,
    )]
    async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
        let mut repo = state.repository().await.map_err(JobError::retry)?;
        let mut rng = state.rng();
        let clock = state.clock();

        let user = repo
            .user()
            .lookup(self.user_id())
            .await
            .map_err(JobError::retry)?
            .context("User not found")
            .map_err(JobError::fail)?;

        repo.queue_job()
            .schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user))
            .await
            .map_err(JobError::retry)?;

        // Commit the repository so the scheduled sync job is persisted, as the
        // other jobs in this module do after writing through `repo`. The
        // previous version scheduled the job but never called `save`.
        repo.save().await.map_err(JobError::retry)?;

        Ok(())
    }
}
#[async_trait]
impl RunnableJob for DeleteDeviceJob {
    /// Handles a request to delete a single device for a user.
    ///
    /// This job does not delete the device directly. It schedules a
    /// [`SyncDevicesJob`] for the owning user, which sends the user's full set
    /// of active devices to the homeserver. It then saves the repository so
    /// the scheduled job is persisted.
    ///
    /// # Errors
    ///
    /// Fails permanently ([`JobError::fail`]) when the user does not exist;
    /// storage errors are returned as retryable ([`JobError::retry`]).
    #[tracing::instrument(
        name = "job.delete_device",
        fields(
            user.id = %self.user_id(),
            device.id = %self.device_id(),
        ),
        skip_all,
        err,
    )]
    async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
        let mut rng = state.rng();
        let clock = state.clock();
        let mut repo = state.repository().await.map_err(JobError::retry)?;

        let user = repo
            .user()
            .lookup(self.user_id())
            .await
            .map_err(JobError::retry)?
            .context("User not found")
            .map_err(JobError::fail)?;

        repo.queue_job()
            .schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user))
            .await
            .map_err(JobError::retry)?;

        // Commit the repository so the scheduled sync job is persisted, as the
        // other jobs in this module do after writing through `repo`. The
        // previous version scheduled the job but never called `save`.
        repo.save().await.map_err(JobError::retry)?;

        Ok(())
    }
}
#[async_trait]
impl RunnableJob for SyncDevicesJob {
    /// Sends the user's complete set of active device IDs to the homeserver.
    ///
    /// Device IDs come from two sources:
    /// - the user's active compatibility sessions;
    /// - the device scope tokens of the user's active OAuth 2.0 sessions.
    ///
    /// Both sources are read in pages of 100. The merged set goes to the
    /// homeserver in a single `sync_devices` call, and the repository is saved
    /// afterwards.
    ///
    /// # Errors
    ///
    /// Fails permanently ([`JobError::fail`]) when the user does not exist;
    /// storage and homeserver errors are returned as retryable
    /// ([`JobError::retry`]).
    #[tracing::instrument(
        name = "job.sync_devices",
        fields(user.id = %self.user_id()),
        skip_all,
        err,
    )]
    async fn run(&self, state: &State, _context: JobContext) -> Result<(), JobError> {
        let matrix = state.matrix_connection();
        let mut repo = state.repository().await.map_err(JobError::retry)?;

        let user = repo
            .user()
            .lookup(self.user_id())
            .await
            .map_err(JobError::retry)?
            .context("User not found")
            .map_err(JobError::fail)?;

        // Take the per-user sync lock before any sessions are read.
        // NOTE(review): presumably this serialises concurrent device syncs for
        // the same user — confirm against `acquire_lock_for_sync`.
        repo.user()
            .acquire_lock_for_sync(&user)
            .await
            .map_err(JobError::retry)?;

        let mut device_ids = HashSet::new();

        // Pass 1: devices backing the user's active compatibility sessions.
        let mut compat_cursor = Pagination::first(100);
        loop {
            let compat_page = repo
                .compat_session()
                .list(
                    CompatSessionFilter::new().for_user(&user).active_only(),
                    compat_cursor,
                )
                .await
                .map_err(JobError::retry)?;
            let more_pages = compat_page.has_next_page;

            for (session, _) in compat_page.edges {
                device_ids.insert(session.device.as_str().to_owned());
                compat_cursor = compat_cursor.after(session.id);
            }

            if !more_pages {
                break;
            }
        }

        // Pass 2: devices encoded as scope tokens on active OAuth 2.0 sessions.
        let mut oauth2_cursor = Pagination::first(100);
        loop {
            let oauth2_page = repo
                .oauth2_session()
                .list(
                    OAuth2SessionFilter::new().for_user(&user).active_only(),
                    oauth2_cursor,
                )
                .await
                .map_err(JobError::retry)?;
            let more_pages = oauth2_page.has_next_page;

            for session in oauth2_page.edges {
                for token in &*session.scope {
                    // Tokens that are not device scopes yield `None` and add nothing.
                    device_ids.extend(
                        Device::from_scope_token(token).map(|device| device.as_str().to_owned()),
                    );
                }
                oauth2_cursor = oauth2_cursor.after(session.id);
            }

            if !more_pages {
                break;
            }
        }

        let mxid = matrix.mxid(&user.username);
        matrix
            .sync_devices(&mxid, device_ids)
            .await
            .map_err(JobError::retry)?;

        repo.save().await.map_err(JobError::retry)?;

        Ok(())
    }
}