use matrix_sdk::{ config::SyncSettings, ruma::api::client::{filter::FilterDefinition, session::get_login_types::v3::{IdentityProvider, LoginType}}, Client, Session, }; use rand::{distributions::Alphanumeric, thread_rng, Rng}; use reqwest::Client as http; use rpassword::prompt_password; use serde::{Deserialize, Serialize}; use serde_json::{from_str, Value}; use std::path::{Path, PathBuf}; use tokio::fs; #[derive(Debug, Serialize, Deserialize)] pub struct ClientSession { homeserver: String, db_path: PathBuf, passphrase: String, } #[derive(Debug, Serialize, Deserialize)] pub struct FullSession { client_session: ClientSession, user_session: Session, #[serde(skip_serializing_if = "Option::is_none")] sync_token: Option, } #[derive(Debug, Clone)] enum LoginChoice { Password, Sso, SsoIdp(IdentityProvider), } impl LoginChoice { pub async fn login(&self, client: &Client, user: String, hs: String) -> anyhow::Result<()> { match self { LoginChoice::Password => Self::login_password(client, user, hs).await, LoginChoice::Sso => Self::login_sso(client, None).await, LoginChoice::SsoIdp(idp) => Self::login_sso(client, Some(idp.to_owned())).await, } } async fn login_password(client: &Client, user: String, _hs: String) -> anyhow::Result<()> { loop { let password = prompt_password("Password\n> ")?; match client .login_username(&user, &password) .initial_device_display_name("scam-police") .send() .await { Ok(_) => { println!("[*] Logged in as {user}"); break; } Err(e) => { println!("[!] Error logging in: {e}"); println!("[!] Please try again\n"); } } } Ok(()) } async fn login_sso(client: &Client, idp: Option) -> anyhow::Result<()> { let mut login_builder = client.login_sso(|url| async move { open::that_in_background(url); println!("[*] Waiting for SSO token..."); Ok(()) }).initial_device_display_name("scam-police"); if let Some(idp) = idp.to_owned() { login_builder = login_builder.identity_provider_id(&idp.id); login_builder.send().await?; } else { login_builder.send().await?; } println!("[+] Got SSO token!"); Ok(()) } } // // Matrix Login & Init // pub async fn login(data_dir: &Path, session_file: &Path, mxid: String) -> anyhow::Result { println!("[*] Logging in as {mxid}..."); let (user, hs) = resolve_mxid(mxid).await?; let (client, client_session) = build_client(data_dir, hs.to_owned()).await?; let mut login_choices = Vec::new(); for login_type in client.get_login_types().await?.flows { match login_type { LoginType::Password(_) => { login_choices.push(LoginChoice::Password); }, LoginType::Sso(sso) => { if sso.identity_providers.is_empty() { login_choices.push(LoginChoice::Sso); } else { login_choices.extend(sso.identity_providers.into_iter().map(LoginChoice::SsoIdp)); } }, // Ignore all other types _ => {}, } } match login_choices.to_owned().len() { 0 => anyhow::bail!("No supported login types"), 1 => login_choices.to_owned().get(0).unwrap().login(&client, user, hs.to_owned()).await, _ => { use terminal_menu::*; let mut menu_items = vec![label("----- Scam Police Login -----")]; let choices: Vec<(LoginChoice, String)> = login_choices.into_iter().map(|a| (a.to_owned(), match a { LoginChoice::Password => format!("Password"), LoginChoice::Sso => format!("SSO"), LoginChoice::SsoIdp(idp) => format!("SSO via {}", idp.name), })).collect(); for choice in choices.to_owned() { menu_items.push(button(choice.1)); } menu_items.push(button("Abort login")); menu_items.push(label("-----------------------------")); let menu = menu(menu_items); run(&menu); let menu = mut_menu(&menu); let mut selected: Option = None; let name = menu.selected_item_name().to_string(); if name == "Abort login" { selected = None; } else { for c in choices { if c.1 == name { selected = Some(c.0.to_owned()); } } } match selected { Some(s) => s.login(&client, user, hs.to_owned()).await, None => anyhow::bail!("Aborting login") } } }?; let user_session = client .session() .expect("A logged-in client should have a session"); let serialized_session = serde_json::to_string(&FullSession { client_session, user_session, sync_token: None, })?; fs::write(session_file, serialized_session).await?; Ok(client) } pub async fn build_client(data_dir: &Path, hs: String) -> anyhow::Result<(Client, ClientSession)> { let mut rng = thread_rng(); let db_subfolder: String = (&mut rng) .sample_iter(Alphanumeric) .take(7) .map(char::from) .collect(); let db_path = data_dir.join(db_subfolder); let passphrase: String = (&mut rng) .sample_iter(Alphanumeric) .take(32) .map(char::from) .collect(); match Client::builder() .homeserver_url(&hs) .sled_store(&db_path, Some(&passphrase))? .build() .await { Ok(client) => { println!("[+] Homeserver OK"); return Ok(( client, ClientSession { homeserver: hs, db_path, passphrase, }, )); } Err(error) => match &error { matrix_sdk::ClientBuildError::AutoDiscovery(_) | matrix_sdk::ClientBuildError::Url(_) | matrix_sdk::ClientBuildError::Http(_) => { anyhow::bail!("[!] {error:?}"); } _ => { return Err(error.into()); } }, } } // // Helper Functions // // Resolve mxid into user and hs pub async fn resolve_mxid(mxid: String) -> anyhow::Result<(String, String)> { if mxid.get(0..1).unwrap() != "@" || !mxid.contains(":") { anyhow::bail!("Invalid mxid"); } let sep = mxid.find(":").unwrap(); let user = mxid.get(1..sep).unwrap().to_string(); let hs = resolve_homeserver(mxid.get((sep + 1)..).unwrap().to_string()).await?; Ok((user, hs)) } // Resolve homeserver pub async fn resolve_homeserver(homeserver: String) -> anyhow::Result { let mut hs = homeserver; if !hs.contains("://") { hs = format!("https://{hs}"); } if hs.chars().last().unwrap().to_string() == "/" { hs.pop(); } let ident = http::new() .get(format!("{hs}/.well-known/matrix/client")) .send() .await; match ident { Ok(r) => { let body = r.text().await?; let json: Value = from_str(&body)?; let discovered = json["m.homeserver"]["base_url"].as_str().unwrap(); Ok(discovered.to_string()) } Err(e) => Err(e.into()), } } // // Persistence // pub async fn sync<'a>( client: Client, initial_sync_token: Option, ) -> anyhow::Result<(Client, SyncSettings<'a>)> { println!("[*] Running initial sync..."); let filter = FilterDefinition::empty(); let mut sync_settings = SyncSettings::default().filter(filter.into()); if let Some(sync_token) = initial_sync_token { sync_settings = sync_settings.token(sync_token); } loop { match client.sync_once(sync_settings.clone()).await { Ok(response) => { sync_settings = sync_settings.token(response.next_batch.clone()); persist_sync_token(response.next_batch).await?; break; } Err(error) => { println!("[!] An error occurred during initial sync: {error}"); if error.to_string().contains("[401 / M_UNKNOWN_TOKEN]") { anyhow::bail!("Unknown token. You need to login again"); } println!("[!] Trying again…"); } } } println!("[*] Scam police is now running!"); Ok((client, sync_settings)) } pub async fn persist_sync_token(sync_token: String) -> anyhow::Result<()> { let serialized_session = fs::read_to_string(crate::SESSION_FILE.to_owned()).await?; let mut full_session: FullSession = from_str(&serialized_session)?; full_session.sync_token = Some(sync_token); let serialized_session = serde_json::to_string(&full_session)?; fs::write(crate::SESSION_FILE.to_owned(), serialized_session).await?; Ok(()) } pub async fn restore_session(session_file: &Path) -> anyhow::Result<(Client, Option)> { let serialized_session = fs::read_to_string(session_file).await?; let FullSession { client_session, user_session, sync_token, } = from_str(&serialized_session)?; let client = Client::builder() .homeserver_url(client_session.homeserver) .sled_store(client_session.db_path, Some(&client_session.passphrase))? .build() .await?; println!("[*] Restoring session for {}…", user_session.user_id); client.restore_login(user_session).await?; Ok((client, sync_token)) }