diff --git a/Cargo.lock b/Cargo.lock index 823b314..ceb6816 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1191,6 +1191,7 @@ name = "nostr-rs-relay" version = "0.7.17" dependencies = [ "anyhow", + "async-trait", "bitcoin_hashes", "clap", "config", diff --git a/Cargo.toml b/Cargo.toml index ed75a56..3fc859c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,6 +42,7 @@ parse_duration = "2" rand = "0.8" const_format = "0.2.28" regex = "1" +async-trait = "0.1.60" [dev-dependencies] anyhow = "1" diff --git a/src/bin/bulkloader.rs b/src/bin/bulkloader.rs index cb9d3cd..59a3bd8 100644 --- a/src/bin/bulkloader.rs +++ b/src/bin/bulkloader.rs @@ -1,14 +1,13 @@ use std::io; use std::path::Path; use nostr_rs_relay::utils::is_lower_hex; -use tracing::*; +use tracing::info; use nostr_rs_relay::config; use nostr_rs_relay::event::{Event,single_char_tagname}; use nostr_rs_relay::error::{Error, Result}; -use nostr_rs_relay::db::build_pool; -use nostr_rs_relay::schema::{curr_db_version, DB_VERSION}; +use nostr_rs_relay::repo::sqlite::{PooledConnection, build_pool}; +use nostr_rs_relay::repo::sqlite_migration::{curr_db_version, DB_VERSION}; use rusqlite::{OpenFlags, Transaction}; -use nostr_rs_relay::db::PooledConnection; use std::sync::mpsc; use std::thread; use rusqlite::params; @@ -67,7 +66,7 @@ pub fn main() -> Result<()> { info!("finished parsing events"); event_tx.send(None).ok(); let ok: Result<()> = Ok(()); - return ok; + ok }); let mut conn: PooledConnection = pool.get()?; let mut events_read = 0; diff --git a/src/config.rs b/src/config.rs index 3d22984..47be07c 100644 --- a/src/config.rs +++ b/src/config.rs @@ -18,6 +18,7 @@ pub struct Info { #[allow(unused)] pub struct Database { pub data_directory: String, + pub engine: String, pub in_memory: bool, pub min_conn: u32, pub max_conn: u32, @@ -206,6 +207,7 @@ impl Default for Settings { diagnostics: Diagnostics { tracing: false }, database: Database { data_directory: ".".to_owned(), + engine: "sqlite".to_owned(), in_memory: false, min_conn: 4, max_conn: 8, diff --git a/src/conn.rs b/src/conn.rs index c957495..dcd6ffb 100644 --- a/src/conn.rs +++ b/src/conn.rs @@ -14,7 +14,7 @@ const MAX_SUBSCRIPTION_ID_LEN: usize = 256; /// State for a client connection pub struct ClientConn { /// Client IP (either from socket, or configured proxy header - client_ip: String, + client_ip_addr: String, /// Unique client identifier generated at connection time client_id: Uuid, /// The current set of active client subscriptions @@ -32,22 +32,22 @@ impl Default for ClientConn { impl ClientConn { /// Create a new, empty connection state. #[must_use] - pub fn new(client_ip: String) -> Self { + pub fn new(client_ip_addr: String) -> Self { let client_id = Uuid::new_v4(); ClientConn { - client_ip, + client_ip_addr, client_id, subscriptions: HashMap::new(), max_subs: 32, } } - pub fn subscriptions(&self) -> &HashMap { + #[must_use] pub fn subscriptions(&self) -> &HashMap { &self.subscriptions } /// Check if the given subscription already exists - pub fn has_subscription(&self, sub: &Subscription) -> bool { + #[must_use] pub fn has_subscription(&self, sub: &Subscription) -> bool { self.subscriptions.values().any(|x| x == sub) } @@ -60,7 +60,7 @@ impl ClientConn { #[must_use] pub fn ip(&self) -> &str { - &self.client_ip + &self.client_ip_addr } /// Add a new subscription for this connection. diff --git a/src/db.rs b/src/db.rs index 56a2c1c..f09a5d5 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,32 +1,16 @@ //! Event persistence and querying -//use crate::config::SETTINGS; use crate::config::Settings; use crate::error::{Error, Result}; -use crate::event::{single_char_tagname, Event}; -use crate::hexrange::hex_range; -use crate::hexrange::HexSearch; -use crate::nip05; +use crate::event::Event; use crate::notice::Notice; -use crate::schema::{upgrade_db, STARTUP_SQL}; -use crate::subscription::ReqFilter; -use crate::subscription::Subscription; -use crate::utils::{is_hex, is_lower_hex}; use governor::clock::Clock; use governor::{Quota, RateLimiter}; -use hex; use r2d2; -use r2d2_sqlite::SqliteConnectionManager; -use rusqlite::params; -use rusqlite::types::ToSql; -use rusqlite::OpenFlags; -use tokio::sync::{Mutex, MutexGuard}; -use std::fmt::Write as _; -use std::path::Path; use std::sync::Arc; use std::thread; -use std::time::Duration; +use crate::repo::sqlite::SqliteRepo; +use crate::repo::NostrRepo; use std::time::Instant; -use tokio::task; use tracing::{debug, info, trace, warn}; pub type SqlitePool = r2d2::Pool; @@ -41,220 +25,122 @@ pub struct SubmittedEvent { /// Database file pub const DB_FILE: &str = "nostr.db"; -/// How frequently to attempt checkpointing -pub const CHECKPOINT_FREQ_SEC: u64 = 60; - -/// Build a database connection pool. +/// Build repo /// # Panics /// /// Will panic if the pool could not be created. -#[must_use] -pub fn build_pool( - name: &str, - settings: &Settings, - flags: OpenFlags, - min_size: u32, - max_size: u32, - wait_for_db: bool, -) -> SqlitePool { - let db_dir = &settings.database.data_directory; - let full_path = Path::new(db_dir).join(DB_FILE); - // small hack; if the database doesn't exist yet, that means the - // writer thread hasn't finished. Give it a chance to work. This - // is only an issue with the first time we run. - if !settings.database.in_memory { - while !full_path.exists() && wait_for_db { - debug!("Database reader pool is waiting on the database to be created..."); - thread::sleep(Duration::from_millis(500)); - } - } - let manager = if settings.database.in_memory { - SqliteConnectionManager::memory() - .with_flags(flags) - .with_init(|c| c.execute_batch(STARTUP_SQL)) - } else { - SqliteConnectionManager::file(&full_path) - .with_flags(flags) - .with_init(|c| c.execute_batch(STARTUP_SQL)) - }; - let pool: SqlitePool = r2d2::Pool::builder() - .test_on_check_out(true) // no noticeable performance hit - .min_idle(Some(min_size)) - .max_size(max_size) - .max_lifetime(Some(Duration::from_secs(30))) - .build(manager) - .unwrap(); - info!( - "Built a connection pool {:?} (min={}, max={})", - name, min_size, max_size - ); - pool -} - -/// Display database pool stats every 1 minute -pub async fn monitor_pool(name: &str, pool: SqlitePool) { - let sleep_dur = Duration::from_secs(60); - loop { - log_pool_stats(name, &pool); - tokio::time::sleep(sleep_dur).await; +pub async fn build_repo(settings: &Settings) -> Arc { + match settings.database.engine.as_str() { + "sqlite" => {Arc::new(build_sqlite_pool(settings).await)}, + _ => panic!("Unknown database engine"), } } - -/// Perform normal maintenance -pub fn optimize_db(conn: &mut PooledConnection) -> Result<()> { - let start = Instant::now(); - conn.execute_batch("PRAGMA optimize;")?; - info!("optimize ran in {:?}", start.elapsed()); - Ok(()) -} -#[derive(Debug)] -enum SqliteStatus { - Ok, - Busy, - Error, - Other(u64), +async fn build_sqlite_pool(settings: &Settings) -> SqliteRepo { + let repo = SqliteRepo::new(settings); + repo.start().await.ok(); + repo.migrate_up().await.ok(); + repo } -/// Checkpoint/Truncate WAL. Returns the number of WAL pages remaining. -pub fn checkpoint_db(conn: &mut PooledConnection) -> Result { - let query = "PRAGMA wal_checkpoint(TRUNCATE);"; - let start = Instant::now(); - let (cp_result, wal_size, _frames_checkpointed) = conn.query_row(query, [], |row| { - let checkpoint_result: u64 = row.get(0)?; - let wal_size: u64 = row.get(1)?; - let frames_checkpointed: u64 = row.get(2)?; - Ok((checkpoint_result, wal_size, frames_checkpointed)) - })?; - let result = match cp_result { - 0 => SqliteStatus::Ok, - 1 => SqliteStatus::Busy, - 2 => SqliteStatus::Error, - x => SqliteStatus::Other(x), - }; - info!( - "checkpoint ran in {:?} (result: {:?}, WAL size: {})", - start.elapsed(), - result, - wal_size - ); - Ok(wal_size as usize) -} - -/// Spawn a database writer that persists events to the SQLite store. +/// Spawn a database writer that persists events to the `SQLite` store. pub async fn db_writer( + repo: Arc, settings: Settings, mut event_rx: tokio::sync::mpsc::Receiver, bcast_tx: tokio::sync::broadcast::Sender, metadata_tx: tokio::sync::broadcast::Sender, mut shutdown: tokio::sync::broadcast::Receiver<()>, -) -> tokio::task::JoinHandle> { +) -> Result<()> { // are we performing NIP-05 checking? let nip05_active = settings.verified_users.is_active(); // are we requriing NIP-05 user verification? let nip05_enabled = settings.verified_users.is_enabled(); - task::spawn_blocking(move || { - let db_dir = &settings.database.data_directory; - let full_path = Path::new(db_dir).join(DB_FILE); - // create a connection pool - let pool = build_pool( - "event writer", - &settings, - OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE, - 1, - 2, - false, - ); - if settings.database.in_memory { - info!("using in-memory database, this will not persist a restart!"); - } else { - info!("opened database {:?} for writing", full_path); + //upgrade_db(&mut pool.get()?)?; + + // Make a copy of the whitelist + let whitelist = &settings.authorization.pubkey_whitelist.clone(); + + // get rate limit settings + let rps_setting = settings.limits.messages_per_sec; + let mut most_recent_rate_limit = Instant::now(); + let mut lim_opt = None; + let clock = governor::clock::QuantaClock::default(); + if let Some(rps) = rps_setting { + if rps > 0 { + info!("Enabling rate limits for event creation ({}/sec)", rps); + let quota = core::num::NonZeroU32::new(rps * 60).unwrap(); + lim_opt = Some(RateLimiter::direct(Quota::per_minute(quota))); } - upgrade_db(&mut pool.get()?)?; - - // Make a copy of the whitelist - let whitelist = &settings.authorization.pubkey_whitelist.clone(); - - // get rate limit settings - let rps_setting = settings.limits.messages_per_sec; - let mut most_recent_rate_limit = Instant::now(); - let mut lim_opt = None; - let clock = governor::clock::QuantaClock::default(); - if let Some(rps) = rps_setting { - if rps > 0 { - info!("Enabling rate limits for event creation ({}/sec)", rps); - let quota = core::num::NonZeroU32::new(rps * 60).unwrap(); - lim_opt = Some(RateLimiter::direct(Quota::per_minute(quota))); + } + loop { + if shutdown.try_recv().is_ok() { + info!("shutting down database writer"); + break; + } + // call blocking read on channel + let next_event = event_rx.recv().await; + // if the channel has closed, we will never get work + if next_event.is_none() { + break; + } + // track if an event write occurred; this is used to + // update the rate limiter + let mut event_write = false; + let subm_event = next_event.unwrap(); + let event = subm_event.event; + let notice_tx = subm_event.notice_tx; + // check if this event is authorized. + if let Some(allowed_addrs) = whitelist { + // TODO: incorporate delegated pubkeys + // if the event address is not in allowed_addrs. + if !allowed_addrs.contains(&event.pubkey) { + debug!( + "rejecting event: {}, unauthorized author", + event.get_event_id_prefix() + ); + notice_tx + .try_send(Notice::blocked( + event.id, + "pubkey is not allowed to publish to this relay", + )) + .ok(); + continue; } } - loop { - if shutdown.try_recv().is_ok() { - info!("shutting down database writer"); - break; - } - // call blocking read on channel - let next_event = event_rx.blocking_recv(); - // if the channel has closed, we will never get work - if next_event.is_none() { - break; - } - // track if an event write occurred; this is used to - // update the rate limiter - let mut event_write = false; - let subm_event = next_event.unwrap(); - let event = subm_event.event; - let notice_tx = subm_event.notice_tx; - // check if this event is authorized. - if let Some(allowed_addrs) = whitelist { - // TODO: incorporate delegated pubkeys - // if the event address is not in allowed_addrs. - if !allowed_addrs.contains(&event.pubkey) { - debug!( - "rejecting event: {}, unauthorized author", - event.get_event_id_prefix() - ); - notice_tx - .try_send(Notice::blocked( - event.id, - "pubkey is not allowed to publish to this relay", - )) - .ok(); - continue; - } - } - // Check that event kind isn't blacklisted - let kinds_blacklist = &settings.limits.event_kind_blacklist.clone(); - if let Some(event_kind_blacklist) = kinds_blacklist { - if event_kind_blacklist.contains(&event.kind) { - debug!( - "rejecting event: {}, blacklisted kind: {}", - &event.get_event_id_prefix(), - &event.kind - ); - notice_tx - .try_send(Notice::blocked( - event.id, - "event kind is blocked by relay" - )) - .ok(); - continue; - } + // Check that event kind isn't blacklisted + let kinds_blacklist = &settings.limits.event_kind_blacklist.clone(); + if let Some(event_kind_blacklist) = kinds_blacklist { + if event_kind_blacklist.contains(&event.kind) { + debug!( + "rejecting event: {}, blacklisted kind: {}", + &event.get_event_id_prefix(), + &event.kind + ); + notice_tx + .try_send(Notice::blocked( + event.id, + "event kind is blocked by relay" + )) + .ok(); + continue; } + } - // send any metadata events to the NIP-05 verifier - if nip05_active && event.is_kind_metadata() { - // we are sending this prior to even deciding if we - // persist it. this allows the nip05 module to - // inspect it, update if necessary, or persist a new - // event and broadcast it itself. - metadata_tx.send(event.clone()).ok(); - } + // send any metadata events to the NIP-05 verifier + if nip05_active && event.is_kind_metadata() { + // we are sending this prior to even deciding if we + // persist it. this allows the nip05 module to + // inspect it, update if necessary, or persist a new + // event and broadcast it itself. + metadata_tx.send(event.clone()).ok(); + } // check for NIP-05 verification - if nip05_enabled { - match nip05::query_latest_user_verification(pool.get()?, event.pubkey.to_owned()) { + if nip05_enabled { + match repo.get_latest_user_verification(&event.pubkey).await { Ok(uv) => { if uv.is_valid(&settings.verified_users) { info!( @@ -297,198 +183,69 @@ pub async fn db_writer( } } // TODO: cache recent list of authors to remove a DB call. - let start = Instant::now(); - if event.kind >= 20000 && event.kind < 30000 { - bcast_tx.send(event.clone()).ok(); - info!( - "published ephemeral event: {:?} from: {:?} in: {:?}", - event.get_event_id_prefix(), - event.get_author_prefix(), - start.elapsed() - ); - event_write = true - } else { - match write_event(&mut pool.get()?, &event) { - Ok(updated) => { - if updated == 0 { - trace!("ignoring duplicate or deleted event"); - notice_tx.try_send(Notice::duplicate(event.id)).ok(); - } else { - info!( - "persisted event: {:?} (kind: {}) from: {:?} in: {:?}", - event.get_event_id_prefix(), - event.kind, - event.get_author_prefix(), - start.elapsed() - ); - event_write = true; - // send this out to all clients - bcast_tx.send(event.clone()).ok(); - notice_tx.try_send(Notice::saved(event.id)).ok(); - } - } - Err(err) => { - warn!("event insert failed: {:?}", err); - let msg = "relay experienced an error trying to publish the latest event"; - notice_tx.try_send(Notice::error(event.id, msg)).ok(); - } - } - } - - // use rate limit, if defined, and if an event was actually written. - if event_write { - if let Some(ref lim) = lim_opt { - if let Err(n) = lim.check() { - let wait_for = n.wait_time_from(clock.now()); - // check if we have recently logged rate - // limits, but print out a message only once - // per second. - if most_recent_rate_limit.elapsed().as_secs() > 10 { - warn!( - "rate limit reached for event creation (sleep for {:?}) (suppressing future messages for 10 seconds)", - wait_for - ); - // reset last rate limit message - most_recent_rate_limit = Instant::now(); - } - // block event writes, allowing them to queue up - thread::sleep(wait_for); - continue; - } - } - } - } - info!("database connection closed"); - Ok(()) - }) -} - -/// Persist an event to the database, returning rows added. -pub fn write_event(conn: &mut PooledConnection, e: &Event) -> Result { - // enable auto vacuum - conn.execute_batch("pragma auto_vacuum = FULL")?; - - // start transaction - let tx = conn.transaction()?; - // get relevant fields from event and convert to blobs. - let id_blob = hex::decode(&e.id).ok(); - let pubkey_blob: Option> = hex::decode(&e.pubkey).ok(); - let delegator_blob: Option> = e.delegated_by.as_ref().and_then(|d| hex::decode(d).ok()); - let event_str = serde_json::to_string(&e).ok(); - // check for replaceable events that would hide this one; we won't even attempt to insert these. - if e.is_replaceable() { - let repl_count = tx.query_row( - "SELECT e.id FROM event e INDEXED BY author_index WHERE e.author=? AND e.kind=? AND e.created_at > ? LIMIT 1;", - params![pubkey_blob, e.kind, e.created_at], |row| row.get::(0)); - if repl_count.ok().is_some() { - return Ok(0); - } - } - // ignore if the event hash is a duplicate. - let mut ins_count = tx.execute( - "INSERT OR IGNORE INTO event (event_hash, created_at, kind, author, delegated_by, content, first_seen, hidden) VALUES (?1, ?2, ?3, ?4, ?5, ?6, strftime('%s','now'), FALSE);", - params![id_blob, e.created_at, e.kind, pubkey_blob, delegator_blob, event_str] - )?; - if ins_count == 0 { - // if the event was a duplicate, no need to insert event or - // pubkey references. - tx.rollback().ok(); - return Ok(ins_count); - } - // remember primary key of the event most recently inserted. - let ev_id = tx.last_insert_rowid(); - // add all tags to the tag table - for tag in e.tags.iter() { - // ensure we have 2 values. - if tag.len() >= 2 { - let tagname = &tag[0]; - let tagval = &tag[1]; - // only single-char tags are searchable - let tagchar_opt = single_char_tagname(tagname); - match &tagchar_opt { - Some(_) => { - // if tagvalue is lowercase hex; - if is_lower_hex(tagval) && (tagval.len() % 2 == 0) { - tx.execute( - "INSERT OR IGNORE INTO tag (event_id, name, value_hex) VALUES (?1, ?2, ?3)", - params![ev_id, &tagname, hex::decode(tagval).ok()], - )?; + let start = Instant::now(); + if event.kind >= 20000 && event.kind < 30000 { + bcast_tx.send(event.clone()).ok(); + info!( + "published ephemeral event: {:?} from: {:?} in: {:?}", + event.get_event_id_prefix(), + event.get_author_prefix(), + start.elapsed() + ); + event_write = true; + } else { + match repo.write_event(&event).await { + Ok(updated) => { + if updated == 0 { + trace!("ignoring duplicate or deleted event"); + notice_tx.try_send(Notice::duplicate(event.id)).ok(); } else { - tx.execute( - "INSERT OR IGNORE INTO tag (event_id, name, value) VALUES (?1, ?2, ?3)", - params![ev_id, &tagname, &tagval], - )?; + info!( + "persisted event: {:?} (kind: {}) from: {:?} in: {:?}", + event.get_event_id_prefix(), + event.kind, + event.get_author_prefix(), + start.elapsed() + ); + event_write = true; + // send this out to all clients + bcast_tx.send(event.clone()).ok(); + notice_tx.try_send(Notice::saved(event.id)).ok(); } } - None => {} + Err(err) => { + warn!("event insert failed: {:?}", err); + let msg = "relay experienced an error trying to publish the latest event"; + notice_tx.try_send(Notice::error(event.id, msg)).ok(); + } + } + } + + // use rate limit, if defined, and if an event was actually written. + if event_write { + if let Some(ref lim) = lim_opt { + if let Err(n) = lim.check() { + let wait_for = n.wait_time_from(clock.now()); + // check if we have recently logged rate + // limits, but print out a message only once + // per second. + if most_recent_rate_limit.elapsed().as_secs() > 10 { + warn!( + "rate limit reached for event creation (sleep for {:?}) (suppressing future messages for 10 seconds)", + wait_for + ); + // reset last rate limit message + most_recent_rate_limit = Instant::now(); + } + // block event writes, allowing them to queue up + thread::sleep(wait_for); + continue; + } } } } - // if this event is replaceable update, remove other replaceable - // event with the same kind from the same author that was issued - // earlier than this. - if e.is_replaceable() { - let author = hex::decode(&e.pubkey).ok(); - // this is a backwards check - hide any events that were older. - let update_count = tx.execute( - "DELETE FROM event WHERE kind=? and author=? and id NOT IN (SELECT id FROM event INDEXED BY author_kind_index WHERE kind=? AND author=? ORDER BY created_at DESC LIMIT 1)", - params![e.kind, author, e.kind, author], - )?; - if update_count > 0 { - info!( - "removed {} older replaceable kind {} events for author: {:?}", - update_count, - e.kind, - e.get_author_prefix() - ); - } - } - // if this event is a deletion, hide the referenced events from the same author. - if e.kind == 5 { - let event_candidates = e.tag_values_by_name("e"); - // first parameter will be author - let mut params: Vec> = vec![Box::new(hex::decode(&e.pubkey)?)]; - event_candidates - .iter() - .filter(|x| is_hex(x) && x.len() == 64) - .filter_map(|x| hex::decode(x).ok()) - .for_each(|x| params.push(Box::new(x))); - let query = format!( - "UPDATE event SET hidden=TRUE WHERE kind!=5 AND author=? AND event_hash IN ({})", - repeat_vars(params.len() - 1) - ); - let mut stmt = tx.prepare(&query)?; - let update_count = stmt.execute(rusqlite::params_from_iter(params))?; - info!( - "hid {} deleted events for author {:?}", - update_count, - e.get_author_prefix() - ); - } else { - // check if a deletion has already been recorded for this event. - // Only relevant for non-deletion events - let del_count = tx.query_row( - "SELECT e.id FROM event e LEFT JOIN tag t ON e.id=t.event_id WHERE e.author=? AND t.name='e' AND e.kind=5 AND t.value_hex=? LIMIT 1;", - params![pubkey_blob, id_blob], |row| row.get::(0)); - // check if a the query returned a result, meaning we should - // hid the current event - if del_count.ok().is_some() { - // a deletion already existed, mark original event as hidden. - info!( - "hid event: {:?} due to existing deletion by author: {:?}", - e.get_event_id_prefix(), - e.get_author_prefix() - ); - let _update_count = - tx.execute("UPDATE event SET hidden=TRUE WHERE id=?", params![ev_id])?; - // event was deleted, so let caller know nothing new - // arrived, preventing this from being sent to active - // subscriptions - ins_count = 0; - } - } - tx.commit()?; - Ok(ins_count) + info!("database connection closed"); + Ok(()) } /// Serialized event associated with a specific subscription request. @@ -499,459 +256,3 @@ pub struct QueryResult { /// Serialized event pub event: String, } - -/// Produce a arbitrary list of '?' parameters. -fn repeat_vars(count: usize) -> String { - if count == 0 { - return "".to_owned(); - } - let mut s = "?,".repeat(count); - // Remove trailing comma - s.pop(); - s -} - -/// Decide if there is an index that should be used explicitly -fn override_index(f: &ReqFilter) -> Option { - // queries for multiple kinds default to kind_index, which is - // significantly slower than kind_created_at_index. - if let Some(ks) = &f.kinds { - if f.ids.is_none() && - ks.len() > 1 && - f.since.is_none() && - f.until.is_none() && - f.tags.is_none() && - f.authors.is_none() { - return Some("kind_created_at_index".into()); - } - } - // if there is an author, it is much better to force the authors index. - if let Some(_) = &f.authors { - if f.since.is_none() && f.until.is_none() { - if f.kinds.is_none() { - // with no use of kinds/created_at, just author - return Some("author_index".into()); - } else { - // prefer author_kind if there are kinds - return Some("author_kind_index".into()); - } - } else { - // finally, prefer author_created_at if time is provided - return Some("author_created_at_index".into()); - } - } - None -} - -/// Create a dynamic SQL subquery and params from a subscription filter (and optional explicit index used) -fn query_from_filter(f: &ReqFilter) -> (String, Vec>, Option) { - // build a dynamic SQL query. all user-input is either an integer - // (sqli-safe), or a string that is filtered to only contain - // hexadecimal characters. Strings that require escaping (tag - // names/values) use parameters. - - // if the filter is malformed, don't return anything. - if f.force_no_match { - let empty_query = "SELECT e.content, e.created_at FROM event e WHERE 1=0".to_owned(); - // query parameters for SQLite - let empty_params: Vec> = vec![]; - return (empty_query, empty_params, None); - } - - // check if the index needs to be overriden - let idx_name = override_index(f); - let idx_stmt = idx_name.as_ref().map_or_else(|| "".to_owned(), |i| format!("INDEXED BY {}",i)); - let mut query = format!("SELECT e.content, e.created_at FROM event e {}", idx_stmt); - // query parameters for SQLite - let mut params: Vec> = vec![]; - - // individual filter components (single conditions such as an author or event ID) - let mut filter_components: Vec = Vec::new(); - // Query for "authors", allowing prefix matches - if let Some(authvec) = &f.authors { - // take each author and convert to a hexsearch - let mut auth_searches: Vec = vec![]; - for auth in authvec { - match hex_range(auth) { - Some(HexSearch::Exact(ex)) => { - auth_searches.push("author=?".to_owned()); - params.push(Box::new(ex)); - } - Some(HexSearch::Range(lower, upper)) => { - auth_searches.push( - "(author>? AND author { - auth_searches.push("author>?".to_owned()); - params.push(Box::new(lower)); - } - None => { - info!("Could not parse hex range from author {:?}", auth); - } - } - } - if !authvec.is_empty() { - let auth_clause = format!("({})", auth_searches.join(" OR ")); - filter_components.push(auth_clause); - } else { - filter_components.push("false".to_owned()); - } - } - // Query for Kind - if let Some(ks) = &f.kinds { - // kind is number, no escaping needed - let str_kinds: Vec = ks.iter().map(|x| x.to_string()).collect(); - let kind_clause = format!("kind IN ({})", str_kinds.join(", ")); - filter_components.push(kind_clause); - } - // Query for event, allowing prefix matches - if let Some(idvec) = &f.ids { - // take each author and convert to a hexsearch - let mut id_searches: Vec = vec![]; - for id in idvec { - match hex_range(id) { - Some(HexSearch::Exact(ex)) => { - id_searches.push("event_hash=?".to_owned()); - params.push(Box::new(ex)); - } - Some(HexSearch::Range(lower, upper)) => { - id_searches.push("(event_hash>? AND event_hash { - id_searches.push("event_hash>?".to_owned()); - params.push(Box::new(lower)); - } - None => { - info!("Could not parse hex range from id {:?}", id); - } - } - } - if !idvec.is_empty() { - let id_clause = format!("({})", id_searches.join(" OR ")); - filter_components.push(id_clause); - } else { - // if the ids list was empty, we should never return - // any results. - filter_components.push("false".to_owned()); - } - } - // Query for tags - if let Some(map) = &f.tags { - for (key, val) in map.iter() { - let mut str_vals: Vec> = vec![]; - let mut blob_vals: Vec> = vec![]; - for v in val { - if (v.len() % 2 == 0) && is_lower_hex(v) { - if let Ok(h) = hex::decode(v) { - blob_vals.push(Box::new(h)); - } - } else { - str_vals.push(Box::new(v.to_owned())); - } - } - // create clauses with "?" params for each tag value being searched - let str_clause = format!("value IN ({})", repeat_vars(str_vals.len())); - let blob_clause = format!("value_hex IN ({})", repeat_vars(blob_vals.len())); - // find evidence of the target tag name/value existing for this event. - let tag_clause = format!( - "e.id IN (SELECT e.id FROM event e LEFT JOIN tag t on e.id=t.event_id WHERE hidden!=TRUE and (name=? AND ({} OR {})))", - str_clause, blob_clause - ); - // add the tag name as the first parameter - params.push(Box::new(key.to_string())); - // add all tag values that are plain strings as params - params.append(&mut str_vals); - // add all tag values that are blobs as params - params.append(&mut blob_vals); - filter_components.push(tag_clause); - } - } - // Query for timestamp - if f.since.is_some() { - let created_clause = format!("created_at > {}", f.since.unwrap()); - filter_components.push(created_clause); - } - // Query for timestamp - if f.until.is_some() { - let until_clause = format!("created_at < {}", f.until.unwrap()); - filter_components.push(until_clause); - } - // never display hidden events - query.push_str(" WHERE hidden!=TRUE"); - // build filter component conditions - if !filter_components.is_empty() { - query.push_str(" AND "); - query.push_str(&filter_components.join(" AND ")); - } - // Apply per-filter limit to this subquery. - // The use of a LIMIT implies a DESC order, to capture only the most recent events. - if let Some(lim) = f.limit { - let _ = write!(query, " ORDER BY e.created_at DESC LIMIT {}", lim); - } else { - query.push_str(" ORDER BY e.created_at ASC") - } - (query, params, idx_name) -} - -/// Create a dynamic SQL query string and params from a subscription. -fn query_from_sub(sub: &Subscription) -> (String, Vec>, Vec) { - // build a dynamic SQL query for an entire subscription, based on - // SQL subqueries for filters. - let mut subqueries: Vec = Vec::new(); - let mut indexes = vec![]; - // subquery params - let mut params: Vec> = vec![]; - // for every filter in the subscription, generate a subquery - for f in sub.filters.iter() { - let (f_subquery, mut f_params, index) = query_from_filter(f); - if let Some(i) = index { - indexes.push(i); - } - subqueries.push(f_subquery); - params.append(&mut f_params); - } - // encapsulate subqueries into select statements - let subqueries_selects: Vec = subqueries - .iter() - .map(|s| format!("SELECT distinct content, created_at FROM ({})", s)) - .collect(); - let query: String = subqueries_selects.join(" UNION "); - (query, params,indexes) -} - -/// Check if the pool is fully utilized -fn _pool_at_capacity(pool: &SqlitePool) -> bool { - let state: r2d2::State = pool.state(); - state.idle_connections == 0 -} - -/// Log pool stats -fn log_pool_stats(name: &str, pool: &SqlitePool) { - let state: r2d2::State = pool.state(); - let in_use_cxns = state.connections - state.idle_connections; - debug!( - "DB pool {:?} usage (in_use: {}, available: {}, max: {})", - name, - in_use_cxns, - state.connections, - pool.max_size() - ); -} - - -/// Perform database maintenance on a regular basis -pub async fn db_optimize_task(pool: SqlitePool) { - tokio::task::spawn(async move { - loop { - tokio::select! { - _ = tokio::time::sleep(Duration::from_secs(60*60)) => { - if let Ok(mut conn) = pool.get() { - // the busy timer will block writers, so don't set - // this any higher than you want max latency for event - // writes. - info!("running database optimizer"); - optimize_db(&mut conn).ok(); - } - } - }; - } - }); -} - -/// Perform database WAL checkpoint on a regular basis -pub async fn db_checkpoint_task(pool: SqlitePool, safe_to_read: Arc>) { - tokio::task::spawn(async move { - // WAL size in pages. - let mut current_wal_size = 0; - // WAL threshold for more aggressive checkpointing (10,000 pages, or about 40MB) - let wal_threshold = 1000*10; - // default threshold for the busy timer - let busy_wait_default = Duration::from_secs(1); - // if the WAL file is getting too big, switch to this - let busy_wait_default_long = Duration::from_secs(10); - loop { - tokio::select! { - _ = tokio::time::sleep(Duration::from_secs(CHECKPOINT_FREQ_SEC)) => { - if let Ok(mut conn) = pool.get() { - let mut _guard:Option> = None; - // the busy timer will block writers, so don't set - // this any higher than you want max latency for event - // writes. - if current_wal_size <= wal_threshold { - conn.busy_timeout(busy_wait_default).ok(); - } else { - // if the wal size has exceeded a threshold, increase the busy timeout. - conn.busy_timeout(busy_wait_default_long).ok(); - // take a lock that will prevent new readers. - info!("blocking new readers to perform wal_checkpoint"); - _guard = Some(safe_to_read.lock().await); - } - debug!("running wal_checkpoint(TRUNCATE)"); - if let Ok(new_size) = checkpoint_db(&mut conn) { - current_wal_size = new_size; - } - } - } - }; - } - }); -} - -/// Perform a database query using a subscription. -/// -/// The [`Subscription`] is converted into a SQL query. Each result -/// is published on the `query_tx` channel as it is returned. If a -/// message becomes available on the `abandon_query_rx` channel, the -/// query is immediately aborted. -pub async fn db_query( - sub: Subscription, - client_id: String, - pool: SqlitePool, - query_tx: tokio::sync::mpsc::Sender, - mut abandon_query_rx: tokio::sync::oneshot::Receiver<()>, - safe_to_read: Arc>, -) { - let pre_spawn_start = Instant::now(); - task::spawn_blocking(move || { - { - // if we are waiting on a checkpoint, stop until it is complete - let _ = safe_to_read.blocking_lock(); - } - let db_queue_time = pre_spawn_start.elapsed(); - // if the queue time was very long (>5 seconds), spare the DB and abort. - if db_queue_time > Duration::from_secs(5) { - info!( - "shedding DB query load queued for {:?} (cid: {}, sub: {:?})", - db_queue_time, client_id, sub.id - ); - return Ok(()); - } - // otherwise, report queuing time if it is slow - else if db_queue_time > Duration::from_secs(1) { - debug!( - "(slow) DB query queued for {:?} (cid: {}, sub: {:?})", - db_queue_time, client_id, sub.id - ); - } - let start = Instant::now(); - let mut row_count: usize = 0; - // generate SQL query - let (q, p, idxs) = query_from_sub(&sub); - let sql_gen_elapsed = start.elapsed(); - - if sql_gen_elapsed > Duration::from_millis(10) { - debug!("SQL (slow) generated in {:?}", start.elapsed()); - } - // cutoff for displaying slow queries - let slow_cutoff = Duration::from_millis(2000); - // any client that doesn't cause us to generate new rows in 5 - // seconds gets dropped. - let abort_cutoff = Duration::from_secs(5); - let start = Instant::now(); - let mut slow_first_event; - let mut last_successful_send = Instant::now(); - if let Ok(mut conn) = pool.get() { - // execute the query. - // make the actual SQL query (with parameters inserted) available - conn.trace(Some(|x| {trace!("SQL trace: {:?}", x)})); - let mut stmt = conn.prepare_cached(&q)?; - let mut event_rows = stmt.query(rusqlite::params_from_iter(p))?; - - let mut first_result = true; - while let Some(row) = event_rows.next()? { - let first_event_elapsed = start.elapsed(); - slow_first_event = first_event_elapsed >= slow_cutoff; - if first_result { - debug!( - "first result in {:?} (cid: {}, sub: {:?}) [used indexes: {:?}]", - first_event_elapsed, client_id, sub.id, idxs - ); - first_result = false; - } - // logging for slow queries; show sub and SQL. - // to reduce logging; only show 1/16th of clients (leading 0) - if row_count == 0 && slow_first_event && client_id.starts_with('0') { - debug!( - "query req (slow): {:?} (cid: {}, sub: {:?})", - sub, client_id, sub.id - ); - } - // check if a checkpoint is trying to run, and abort - if row_count % 100 == 0 { - { - if safe_to_read.try_lock().is_err() { - // lock was held, abort this query - debug!("query aborted due to checkpoint (cid: {}, sub: {:?})", client_id, sub.id); - return Ok(()); - } - } - } - - // check if this is still active; every 100 rows - if row_count % 100 == 0 && abandon_query_rx.try_recv().is_ok() { - debug!("query aborted (cid: {}, sub: {:?})", client_id, sub.id); - return Ok(()); - } - row_count += 1; - let event_json = row.get(0)?; - loop { - if query_tx.capacity() != 0 { - // we have capacity to add another item - break; - } else { - // the queue is full - trace!("db reader thread is stalled"); - if last_successful_send + abort_cutoff < Instant::now() { - // the queue has been full for too long, abort - info!("aborting database query due to slow client (cid: {}, sub: {:?})", - client_id, sub.id); - let ok: Result<()> = Ok(()); - return ok; - } - // check if a checkpoint is trying to run, and abort - if safe_to_read.try_lock().is_err() { - // lock was held, abort this query - debug!("query aborted due to checkpoint (cid: {}, sub: {:?})", client_id, sub.id); - return Ok(()); - } - // give the queue a chance to clear before trying again - thread::sleep(Duration::from_millis(100)); - } - } - // TODO: we could use try_send, but we'd have to juggle - // getting the query result back as part of the error - // result. - query_tx - .blocking_send(QueryResult { - sub_id: sub.get_id(), - event: event_json, - }) - .ok(); - last_successful_send = Instant::now(); - } - query_tx - .blocking_send(QueryResult { - sub_id: sub.get_id(), - event: "EOSE".to_string(), - }) - .ok(); - debug!( - "query completed in {:?} (cid: {}, sub: {:?}, db_time: {:?}, rows: {})", - pre_spawn_start.elapsed(), - client_id, - sub.id, - start.elapsed(), - row_count - ); - } else { - warn!("Could not get a database connection for querying"); - } - let ok: Result<()> = Ok(()); - ok - }); -} diff --git a/src/delegation.rs b/src/delegation.rs index e682b40..640cb00 100644 --- a/src/delegation.rs +++ b/src/delegation.rs @@ -84,7 +84,7 @@ pub struct ConditionQuery { } impl ConditionQuery { - pub fn allows_event(&self, event: &Event) -> bool { + #[must_use] pub fn allows_event(&self, event: &Event) -> bool { // check each condition, to ensure that the event complies // with the restriction. for c in &self.conditions { @@ -101,7 +101,7 @@ impl ConditionQuery { } // Verify that the delegator approved the delegation; return a ConditionQuery if so. -pub fn validate_delegation( +#[must_use] pub fn validate_delegation( delegator: &str, delegatee: &str, cond_query: &str, @@ -133,8 +133,8 @@ pub fn validate_delegation( } /// Parsed delegation condition -/// see https://github.com/nostr-protocol/nips/pull/28#pullrequestreview-1084903800 -/// An example complex condition would be: kind=1,2,3&created_at<1665265999 +/// see +/// An example complex condition would be: `kind=1,2,3&created_at<1665265999` #[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)] pub struct Condition { pub field: Field, @@ -144,7 +144,7 @@ pub struct Condition { impl Condition { /// Check if this condition allows the given event to be delegated - pub fn allows_event(&self, event: &Event) -> bool { + #[must_use] pub fn allows_event(&self, event: &Event) -> bool { // determine what the right-hand side of the operator is let resolved_field = match &self.field { Field::Kind => event.kind, @@ -323,7 +323,7 @@ mod tests { Condition { field: Field::CreatedAt, operator: Operator::LessThan, - values: vec![1665867123], + values: vec![1_665_867_123], }, ], }; diff --git a/src/event.rs b/src/event.rs index d476bee..88765c6 100644 --- a/src/event.rs +++ b/src/event.rs @@ -1,6 +1,6 @@ //! Event parsing and validation use crate::delegation::validate_delegation; -use crate::error::Error::*; +use crate::error::Error::{CommandUnknownError, EventCouldNotCanonicalize, EventInvalidId, EventInvalidSignature, EventMalformedPubkey}; use crate::error::Result; use crate::nip05; use crate::utils::unix_time; @@ -28,7 +28,7 @@ pub struct EventCmd { } impl EventCmd { - pub fn event_id(&self) -> &str { + #[must_use] pub fn event_id(&self) -> &str { &self.event.id } } @@ -65,7 +65,7 @@ where } /// Attempt to form a single-char tag name. -pub fn single_char_tagname(tagname: &str) -> Option { +#[must_use] pub fn single_char_tagname(tagname: &str) -> Option { // We return the tag character if and only if the tagname consists // of a single char. let mut tagnamechars = tagname.chars(); @@ -87,22 +87,22 @@ pub fn single_char_tagname(tagname: &str) -> Option { impl From for Result { fn from(ec: EventCmd) -> Result { // ensure command is correct - if ec.cmd != "EVENT" { - Err(CommandUnknownError) - } else { - ec.event.validate().map(|_| { + if ec.cmd == "EVENT" { + ec.event.validate().map(|_| { let mut e = ec.event; e.build_index(); e.update_delegation(); e }) + } else { + Err(CommandUnknownError) } } } impl Event { #[cfg(test)] - pub fn simple_event() -> Event { + #[must_use] pub fn simple_event() -> Event { Event { id: "0".to_owned(), pubkey: "0".to_owned(), @@ -116,17 +116,17 @@ impl Event { } } - pub fn is_kind_metadata(&self) -> bool { + #[must_use] pub fn is_kind_metadata(&self) -> bool { self.kind == 0 } /// Should this event be replaced with newer timestamps from same author? - pub fn is_replaceable(&self) -> bool { + #[must_use] pub fn is_replaceable(&self) -> bool { self.kind == 0 || self.kind == 3 || self.kind == 41 || (self.kind >= 10000 && self.kind < 20000) } /// Pull a NIP-05 Name out of the event, if one exists - pub fn get_nip05_addr(&self) -> Option { + #[must_use] pub fn get_nip05_addr(&self) -> Option { if self.is_kind_metadata() { // very quick check if we should attempt to parse this json if self.content.contains("\"nip05\"") { @@ -143,7 +143,7 @@ impl Event { // is this event delegated (properly)? // does the signature match, and are conditions valid? // if so, return an alternate author for the event - pub fn delegated_author(&self) -> Option { + #[must_use] pub fn delegated_author(&self) -> Option { // is there a delegation tag? let delegation_tag: Vec = self .tags @@ -151,8 +151,7 @@ impl Event { .filter(|x| x.len() == 4) .filter(|x| x.get(0).unwrap() == "delegation") .take(1) - .next()? - .to_vec(); // get first tag + .next()?.clone(); // get first tag //let delegation_tag = self.tag_values_by_name("delegation"); // delegation tags should have exactly 3 elements after the name (pubkey, condition, sig) @@ -212,24 +211,24 @@ impl Event { } /// Create a short event identifier, suitable for logging. - pub fn get_event_id_prefix(&self) -> String { + #[must_use] pub fn get_event_id_prefix(&self) -> String { self.id.chars().take(8).collect() } - pub fn get_author_prefix(&self) -> String { + #[must_use] pub fn get_author_prefix(&self) -> String { self.pubkey.chars().take(8).collect() } /// Retrieve tag initial values across all tags matching the name - pub fn tag_values_by_name(&self, tag_name: &str) -> Vec { + #[must_use] pub fn tag_values_by_name(&self, tag_name: &str) -> Vec { self.tags .iter() .filter(|x| x.len() > 1) .filter(|x| x.get(0).unwrap() == tag_name) - .map(|x| x.get(1).unwrap().to_owned()) + .map(|x| x.get(1).unwrap().clone()) .collect() } - pub fn is_valid_timestamp(&self, reject_future_seconds: Option) -> bool { + #[must_use] pub fn is_valid_timestamp(&self, reject_future_seconds: Option) -> bool { if let Some(allowable_future) = reject_future_seconds { let curr_time = unix_time(); // calculate difference, plus how far future we allow @@ -291,7 +290,7 @@ impl Event { let id = Number::from(0_u64); c.push(serde_json::Value::Number(id)); // public key - c.push(Value::String(self.pubkey.to_owned())); + c.push(Value::String(self.pubkey.clone())); // creation time let created_at = Number::from(self.created_at); c.push(serde_json::Value::Number(created_at)); @@ -301,7 +300,7 @@ impl Event { // tags c.push(self.tags_to_canonical()); // content - c.push(Value::String(self.content.to_owned())); + c.push(Value::String(self.content.clone())); serde_json::to_string(&Value::Array(c)).ok() } @@ -309,11 +308,11 @@ impl Event { fn tags_to_canonical(&self) -> Value { let mut tags = Vec::::new(); // iterate over self tags, - for t in self.tags.iter() { + for t in &self.tags { // each tag is a vec of strings let mut a = Vec::::new(); for v in t.iter() { - a.push(serde_json::Value::String(v.to_owned())); + a.push(serde_json::Value::String(v.clone())); } tags.push(serde_json::Value::Array(a)); } @@ -321,7 +320,7 @@ impl Event { } /// Determine if the given tag and value set intersect with tags in this event. - pub fn generic_tag_val_intersect(&self, tagname: char, check: &HashSet) -> bool { + #[must_use] pub fn generic_tag_val_intersect(&self, tagname: char, check: &HashSet) -> bool { match &self.tagidx { // check if this is indexable tagname Some(idx) => match idx.get(&tagname) { @@ -413,7 +412,7 @@ mod tests { id: "999".to_owned(), pubkey: "012345".to_owned(), delegated_by: None, - created_at: 501234, + created_at: 501_234, kind: 1, tags: vec![], content: "this is a test".to_owned(), @@ -431,7 +430,7 @@ mod tests { id: "999".to_owned(), pubkey: "012345".to_owned(), delegated_by: None, - created_at: 501234, + created_at: 501_234, kind: 1, tags: vec![ vec!["j".to_owned(), "abc".to_owned()], @@ -458,7 +457,7 @@ mod tests { id: "999".to_owned(), pubkey: "012345".to_owned(), delegated_by: None, - created_at: 501234, + created_at: 501_234, kind: 1, tags: vec![ vec!["j".to_owned(), "abc".to_owned()], @@ -485,7 +484,7 @@ mod tests { id: "999".to_owned(), pubkey: "012345".to_owned(), delegated_by: None, - created_at: 501234, + created_at: 501_234, kind: 1, tags: vec![ vec!["#e".to_owned(), "aoeu".to_owned()], diff --git a/src/hexrange.rs b/src/hexrange.rs index e571778..fa9742b 100644 --- a/src/hexrange.rs +++ b/src/hexrange.rs @@ -19,7 +19,7 @@ fn is_all_fs(s: &str) -> bool { } /// Find the next hex sequence greater than the argument. -pub fn hex_range(s: &str) -> Option { +#[must_use] pub fn hex_range(s: &str) -> Option { // handle special cases if !is_hex(s) || s.len() > 64 { return None; diff --git a/src/info.rs b/src/info.rs index e0b12f9..29bc7fb 100644 --- a/src/info.rs +++ b/src/info.rs @@ -37,7 +37,7 @@ impl From for RelayInfo { contact: i.contact, supported_nips: Some(vec![1, 2, 9, 11, 12, 15, 16, 20, 22]), software: Some("https://git.sr.ht/~gheartsfield/nostr-rs-relay".to_owned()), - version: CARGO_PKG_VERSION.map(|x| x.to_owned()), + version: CARGO_PKG_VERSION.map(std::borrow::ToOwned::to_owned), } } } diff --git a/src/lib.rs b/src/lib.rs index 45b94e4..7b13866 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,7 +10,7 @@ pub mod hexrange; pub mod info; pub mod nip05; pub mod notice; -pub mod schema; +pub mod repo; pub mod subscription; pub mod utils; // Public API for creating relays programatically diff --git a/src/main.rs b/src/main.rs index 5da2b02..35dc094 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,6 @@ //! Server process use clap::Parser; -use nostr_rs_relay::cli::*; +use nostr_rs_relay::cli::CLIArgs; use nostr_rs_relay::config; use nostr_rs_relay::server::start_server; use std::sync::mpsc as syncmpsc; @@ -37,12 +37,15 @@ fn main() { if let Some(db_dir) = db_dir_arg { settings.database.data_directory = db_dir; } + // we should have a 'control plane' channel to monitor and bump + // the server. this will let us do stuff like clear the database, + // shutdown, etc.; for now all this does is initiate shutdown if + // `()` is sent. This will change in the future, this is just a + // stopgap to shutdown the relay when it is used as a library. let (_, ctrl_rx): (MpscSender<()>, MpscReceiver<()>) = syncmpsc::channel(); // run this in a new thread - let handle = thread::spawn(|| { - // we should have a 'control plane' channel to monitor and bump the server. - // this will let us do stuff like clear the database, shutdown, etc. - let _svr = start_server(settings, ctrl_rx); + let handle = thread::spawn(move || { + let _svr = start_server(&settings, ctrl_rx); }); // block on nostr thread to finish. handle.join().unwrap(); diff --git a/src/nip05.rs b/src/nip05.rs index 8f73751..3247be7 100644 --- a/src/nip05.rs +++ b/src/nip05.rs @@ -5,16 +5,14 @@ //! consumes a stream of metadata events, and keeps a database table //! updated with the current NIP-05 verification status. use crate::config::VerifiedUsers; -use crate::db; use crate::error::{Error, Result}; use crate::event::Event; -use crate::utils::unix_time; +use crate::repo::NostrRepo; +use std::sync::Arc; use hyper::body::HttpBody; use hyper::client::connect::HttpConnector; use hyper::Client; use hyper_tls::HttpsConnector; -use rand::Rng; -use rusqlite::params; use std::time::Duration; use std::time::Instant; use std::time::SystemTime; @@ -23,14 +21,12 @@ use tracing::{debug, info, warn}; /// NIP-05 verifier state pub struct Verifier { + /// Repository for saving/retrieving events and records + repo: Arc, /// Metadata events for us to inspect metadata_rx: tokio::sync::broadcast::Receiver, /// Newly validated events get written and then broadcast on this channel to subscribers event_tx: tokio::sync::broadcast::Sender, - /// SQLite read query pool - read_pool: db::SqlitePool, - /// SQLite write query pool - write_pool: db::SqlitePool, /// Settings settings: crate::config::Settings, /// HTTP client @@ -52,7 +48,7 @@ pub struct Nip05Name { impl Nip05Name { /// Does this name represent the entire domain? - pub fn is_domain_only(&self) -> bool { + #[must_use] pub fn is_domain_only(&self) -> bool { self.local == "_" } @@ -73,16 +69,11 @@ impl std::convert::TryFrom<&str> for Nip05Name { fn try_from(inet: &str) -> Result { // break full name at the @ boundary. let components: Vec<&str> = inet.split('@').collect(); - if components.len() != 2 { - Err(Error::CustomError("too many/few components".to_owned())) - } else { - // check if local name is valid + if components.len() == 2 { + // check if local name is valid let local = components[0]; let domain = components[1]; - if local - .chars() - .all(|x| x.is_alphanumeric() || x == '_' || x == '-' || x == '.') - { + if local.chars().all(|x| x.is_alphanumeric() || x == '_' || x == '-' || x == '.') { if domain .chars() .all(|x| x.is_alphanumeric() || x == '-' || x == '.') @@ -101,6 +92,8 @@ impl std::convert::TryFrom<&str> for Nip05Name { "invalid character in local part".to_owned(), )) } + } else { + Err(Error::CustomError("too many/few components".to_owned())) } } } @@ -111,55 +104,30 @@ impl std::fmt::Display for Nip05Name { } } -// Current time, with a slight foward jitter in seconds -fn now_jitter(sec: u64) -> u64 { - // random time between now, and 10min in future. - let mut rng = rand::thread_rng(); - let jitter_amount = rng.gen_range(0..sec); - let now = unix_time(); - now.saturating_add(jitter_amount) -} - /// Check if the specified username and address are present and match in this response body -fn body_contains_user(username: &str, address: &str, bytes: hyper::body::Bytes) -> Result { +fn body_contains_user(username: &str, address: &str, bytes: &hyper::body::Bytes) -> Result { // convert the body into json let body: serde_json::Value = serde_json::from_slice(&bytes)?; // ensure we have a names object. let names_map = body .as_object() .and_then(|x| x.get("names")) - .and_then(|x| x.as_object()) + .and_then(serde_json::Value::as_object) .ok_or_else(|| Error::CustomError("not a map".to_owned()))?; // get the pubkey for the requested user - let check_name = names_map.get(username).and_then(|x| x.as_str()); + let check_name = names_map.get(username).and_then(serde_json::Value::as_str); // ensure the address is a match - Ok(check_name.map(|x| x == address).unwrap_or(false)) + Ok(check_name.map_or(false, |x| x == address)) } impl Verifier { pub fn new( + repo: Arc, metadata_rx: tokio::sync::broadcast::Receiver, event_tx: tokio::sync::broadcast::Sender, settings: crate::config::Settings, ) -> Result { info!("creating NIP-05 verifier"); - // build a database connection for reading and writing. - let write_pool = db::build_pool( - "nip05 writer", - &settings, - rusqlite::OpenFlags::SQLITE_OPEN_READ_WRITE, - 1, // min conns - 4, // max conns - true, // wait for DB - ); - let read_pool = db::build_pool( - "nip05 reader", - &settings, - rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY, - 1, // min conns - 8, // max conns - true, // wait for DB - ); // setup hyper client let https = HttpsConnector::new(); let client = Client::builder().build::<_, hyper::Body>(https); @@ -175,10 +143,9 @@ impl Verifier { // duration. let reverify_interval = tokio::time::interval(http_wait_duration); Ok(Verifier { + repo, metadata_rx, event_tx, - read_pool, - write_pool, settings, client, wait_after_finish, @@ -246,44 +213,40 @@ impl Verifier { let response_fut = self.client.request(req); - // HTTP request with timeout - match tokio::time::timeout(Duration::from_secs(5), response_fut).await { - Ok(response_res) => { - // limit size of verification document to 1MB. - const MAX_ALLOWED_RESPONSE_SIZE: u64 = 1024 * 1024; - let response = response_res?; - // determine content length from response - let response_content_length = match response.body().size_hint().upper() { - Some(v) => v, - None => MAX_ALLOWED_RESPONSE_SIZE + 1, // reject missing content length - }; - // TODO: test how hyper handles the client providing an inaccurate content-length. - if response_content_length <= MAX_ALLOWED_RESPONSE_SIZE { - let (parts, body) = response.into_parts(); - // TODO: consider redirects - if parts.status == http::StatusCode::OK { - // parse body, determine if the username / key / address is present - let body_bytes = hyper::body::to_bytes(body).await?; - let body_matches = body_contains_user(&nip.local, pubkey, body_bytes)?; - if body_matches { - return Ok(UserWebVerificationStatus::Verified); - } - // successful response, parsed as a nip-05 - // document, but this name/pubkey was not - // present. - return Ok(UserWebVerificationStatus::Unverified); + if let Ok(response_res) = tokio::time::timeout(Duration::from_secs(5), response_fut).await { + // limit size of verification document to 1MB. + const MAX_ALLOWED_RESPONSE_SIZE: u64 = 1024 * 1024; + let response = response_res?; + // determine content length from response + let response_content_length = match response.body().size_hint().upper() { + Some(v) => v, + None => MAX_ALLOWED_RESPONSE_SIZE + 1, // reject missing content length + }; + // TODO: test how hyper handles the client providing an inaccurate content-length. + if response_content_length <= MAX_ALLOWED_RESPONSE_SIZE { + let (parts, body) = response.into_parts(); + // TODO: consider redirects + if parts.status == http::StatusCode::OK { + // parse body, determine if the username / key / address is present + let body_bytes = hyper::body::to_bytes(body).await?; + let body_matches = body_contains_user(&nip.local, pubkey, &body_bytes)?; + if body_matches { + return Ok(UserWebVerificationStatus::Verified); } - } else { - info!( - "content length missing or exceeded limits for account: {:?}", - nip.to_string() - ); + // successful response, parsed as a nip-05 + // document, but this name/pubkey was not + // present. + return Ok(UserWebVerificationStatus::Unverified); } + } else { + info!( + "content length missing or exceeded limits for account: {:?}", + nip.to_string() + ); } - Err(_) => { - info!("timeout verifying account {:?}", nip); - return Ok(UserWebVerificationStatus::Unknown); - } + } else { + info!("timeout verifying account {:?}", nip); + return Ok(UserWebVerificationStatus::Unknown); } Ok(UserWebVerificationStatus::Unknown) } @@ -309,7 +272,7 @@ impl Verifier { if let Some(naddr) = e.get_nip05_addr() { info!("got metadata event for ({:?},{:?})", naddr.to_string() ,e.get_author_prefix()); // Process a new author, checking if they are verified: - let check_verified = get_latest_user_verification(self.read_pool.get().expect("could not get connection"), &e.pubkey).await; + let check_verified = self.repo.get_latest_user_verification(&e.pubkey).await; // ensure the event we got is more recent than the one we have, otherwise we can ignore it. if let Ok(last_check) = check_verified { if e.created_at <= last_check.event_created { @@ -370,7 +333,7 @@ impl Verifier { .duration_since(SystemTime::UNIX_EPOCH) .map(|x| x.as_secs()) .unwrap_or(0); - let vr = get_oldest_user_verification(self.read_pool.get()?, earliest_epoch).await; + let vr = self.repo.get_oldest_user_verification(earliest_epoch).await; match vr { Ok(ref v) => { let new_status = self.get_web_verification(&v.name, &v.address).await; @@ -378,8 +341,10 @@ impl Verifier { UserWebVerificationStatus::Verified => { // freshly verified account, update the // timestamp. - self.update_verification_record(self.write_pool.get()?, v) + self.repo.update_verification_timestamp(v.rowid) .await?; + info!("verification updated for {}", v.to_string()); + } UserWebVerificationStatus::DomainNotAllowed | UserWebVerificationStatus::Unknown => { @@ -394,18 +359,19 @@ impl Verifier { "giving up on verifying {:?} after {} failures", v.name, v.failure_count ); - self.delete_verification_record(self.write_pool.get()?, v) + self.repo.delete_verification(v.rowid) .await?; } else { // record normal failure, incrementing failure count - self.fail_verification_record(self.write_pool.get()?, v) - .await?; + info!("verification failed for {}", v.to_string()); + self.repo.fail_verification(v.rowid).await?; } } UserWebVerificationStatus::Unverified => { // domain has removed the verification, drop // the record on our side. - self.delete_verification_record(self.write_pool.get()?, v) + info!("verification rescinded for {}", v.to_string()); + self.repo.delete_verification(v.rowid) .await?; } } @@ -426,80 +392,6 @@ impl Verifier { Ok(()) } - /// Reset the verification timestamp on a VerificationRecord - pub async fn update_verification_record( - &mut self, - mut conn: db::PooledConnection, - vr: &VerificationRecord, - ) -> Result<()> { - let vr_id = vr.rowid; - let vr_str = vr.to_string(); - tokio::task::spawn_blocking(move || { - // add some jitter to the verification to prevent everything from stacking up together. - let verif_time = now_jitter(600); - let tx = conn.transaction()?; - { - // update verification time and reset any failure count - let query = - "UPDATE user_verification SET verified_at=?, failure_count=0 WHERE id=?"; - let mut stmt = tx.prepare(query)?; - stmt.execute(params![verif_time, vr_id])?; - } - tx.commit()?; - info!("verification updated for {}", vr_str); - let ok: Result<()> = Ok(()); - ok - }) - .await? - } - /// Reset the failure timestamp on a VerificationRecord - pub async fn fail_verification_record( - &mut self, - mut conn: db::PooledConnection, - vr: &VerificationRecord, - ) -> Result<()> { - let vr_id = vr.rowid; - let vr_str = vr.to_string(); - let fail_count = vr.failure_count.saturating_add(1); - tokio::task::spawn_blocking(move || { - // add some jitter to the verification to prevent everything from stacking up together. - let fail_time = now_jitter(600); - let tx = conn.transaction()?; - { - let query = "UPDATE user_verification SET failed_at=?, failure_count=? WHERE id=?"; - let mut stmt = tx.prepare(query)?; - stmt.execute(params![fail_time, fail_count, vr_id])?; - } - tx.commit()?; - info!("verification failed for {}", vr_str); - let ok: Result<()> = Ok(()); - ok - }) - .await? - } - /// Delete a VerificationRecord that is no longer valid - pub async fn delete_verification_record( - &mut self, - mut conn: db::PooledConnection, - vr: &VerificationRecord, - ) -> Result<()> { - let vr_id = vr.rowid; - let vr_str = vr.to_string(); - tokio::task::spawn_blocking(move || { - let tx = conn.transaction()?; - { - let query = "DELETE FROM user_verification WHERE id=?;"; - let mut stmt = tx.prepare(query)?; - stmt.execute(params![vr_id])?; - } - tx.commit()?; - info!("verification rescinded for {}", vr_str); - let ok: Result<()> = Ok(()); - ok - }) - .await? - } - /// Persist an event, create a verification record, and broadcast. // TODO: have more event-writing logic handled in the db module. // Right now, these events avoid the rate limit. That is @@ -513,27 +405,27 @@ impl Verifier { // disabled/passive, the event has already been persisted. let should_write_event = self.settings.verified_users.is_enabled(); if should_write_event { - match db::write_event(&mut self.write_pool.get()?, event) { - Ok(updated) => { - if updated != 0 { - info!( - "persisted event (new verified pubkey): {:?} in {:?}", - event.get_event_id_prefix(), - start.elapsed() - ); - self.event_tx.send(event.clone()).ok(); - } - } - Err(err) => { - warn!("event insert failed: {:?}", err); - if let Error::SqlError(r) = err { - warn!("because: : {:?}", r); - } - } - } + match self.repo.write_event(event).await { + Ok(updated) => { + if updated != 0 { + info!( + "persisted event (new verified pubkey): {:?} in {:?}", + event.get_event_id_prefix(), + start.elapsed() + ); + self.event_tx.send(event.clone()).ok(); + } + } + Err(err) => { + warn!("event insert failed: {:?}", err); + if let Error::SqlError(r) = err { + warn!("because: : {:?}", r); + } + } + } } // write the verification record - save_verification_record(self.write_pool.get()?, event, name).await?; + self.repo.create_verification_record(&event.id, name).await?; Ok(()) } } @@ -563,7 +455,7 @@ pub struct VerificationRecord { /// Check with settings to determine if a given domain is allowed to /// publish. -pub fn is_domain_allowed( +#[must_use] pub fn is_domain_allowed( domain: &str, whitelist: &Option>, blacklist: &Option>, @@ -583,7 +475,7 @@ pub fn is_domain_allowed( impl VerificationRecord { /// Check if the record is recent enough to be considered valid, /// and the domain is allowed. - pub fn is_valid(&self, verified_users_settings: &VerifiedUsers) -> bool { + #[must_use] pub fn is_valid(&self, verified_users_settings: &VerifiedUsers) -> bool { //let settings = SETTINGS.read().unwrap(); // how long a verification record is good for let nip05_expiration = &verified_users_settings.verify_expiration_duration; @@ -630,130 +522,6 @@ impl std::fmt::Display for VerificationRecord { } } -/// Create a new verification record based on an event -pub async fn save_verification_record( - mut conn: db::PooledConnection, - event: &Event, - name: &str, -) -> Result<()> { - let e = hex::decode(&event.id).ok(); - let n = name.to_owned(); - let a_prefix = event.get_author_prefix(); - tokio::task::spawn_blocking(move || { - let tx = conn.transaction()?; - { - // if we create a /new/ one, we should get rid of any old ones. or group the new ones by name and only consider the latest. - let query = "INSERT INTO user_verification (metadata_event, name, verified_at) VALUES ((SELECT id from event WHERE event_hash=?), ?, strftime('%s','now'));"; - let mut stmt = tx.prepare(query)?; - stmt.execute(params![e, n])?; - // get the row ID - let v_id = tx.last_insert_rowid(); - // delete everything else by this name - let del_query = "DELETE FROM user_verification WHERE name = ? AND id != ?;"; - let mut del_stmt = tx.prepare(del_query)?; - let count = del_stmt.execute(params![n,v_id])?; - if count > 0 { - info!("removed {} old verification records for ({:?},{:?})", count, n, a_prefix); - } - } - tx.commit()?; - info!("saved new verification record for ({:?},{:?})", n, a_prefix); - let ok: Result<()> = Ok(()); - ok - }).await? -} - -/// Retrieve the most recent verification record for a given pubkey (async). -pub async fn get_latest_user_verification( - conn: db::PooledConnection, - pubkey: &str, -) -> Result { - let p = pubkey.to_owned(); - tokio::task::spawn_blocking(move || query_latest_user_verification(conn, p)).await? -} - -/// Query database for the latest verification record for a given pubkey. -pub fn query_latest_user_verification( - mut conn: db::PooledConnection, - pubkey: String, -) -> Result { - let tx = conn.transaction()?; - let query = "SELECT v.id, v.name, e.event_hash, e.created_at, v.verified_at, v.failed_at, v.failure_count FROM user_verification v LEFT JOIN event e ON e.id=v.metadata_event WHERE e.author=? ORDER BY e.created_at DESC, v.verified_at DESC, v.failed_at DESC LIMIT 1;"; - let mut stmt = tx.prepare_cached(query)?; - let fields = stmt.query_row(params![hex::decode(&pubkey).ok()], |r| { - let rowid: u64 = r.get(0)?; - let rowname: String = r.get(1)?; - let eventid: Vec = r.get(2)?; - let created_at: u64 = r.get(3)?; - // create a tuple since we can't throw non-rusqlite errors in this closure - Ok(( - rowid, - rowname, - eventid, - created_at, - r.get(4).ok(), - r.get(5).ok(), - r.get(6)?, - )) - })?; - Ok(VerificationRecord { - rowid: fields.0, - name: Nip05Name::try_from(&fields.1[..])?, - address: pubkey, - event: hex::encode(fields.2), - event_created: fields.3, - last_success: fields.4, - last_failure: fields.5, - failure_count: fields.6, - }) -} - -/// Retrieve the oldest user verification (async) -pub async fn get_oldest_user_verification( - conn: db::PooledConnection, - earliest: u64, -) -> Result { - tokio::task::spawn_blocking(move || query_oldest_user_verification(conn, earliest)).await? -} - -pub fn query_oldest_user_verification( - mut conn: db::PooledConnection, - earliest: u64, -) -> Result { - let tx = conn.transaction()?; - let query = "SELECT v.id, v.name, e.event_hash, e.author, e.created_at, v.verified_at, v.failed_at, v.failure_count FROM user_verification v INNER JOIN event e ON e.id=v.metadata_event WHERE (v.verified_at < ? OR v.verified_at IS NULL) AND (v.failed_at < ? OR v.failed_at IS NULL) ORDER BY v.verified_at ASC, v.failed_at ASC LIMIT 1;"; - let mut stmt = tx.prepare_cached(query)?; - let fields = stmt.query_row(params![earliest, earliest], |r| { - let rowid: u64 = r.get(0)?; - let rowname: String = r.get(1)?; - let eventid: Vec = r.get(2)?; - let pubkey: Vec = r.get(3)?; - let created_at: u64 = r.get(4)?; - // create a tuple since we can't throw non-rusqlite errors in this closure - Ok(( - rowid, - rowname, - eventid, - pubkey, - created_at, - r.get(5).ok(), - r.get(6).ok(), - r.get(7)?, - )) - })?; - let vr = VerificationRecord { - rowid: fields.0, - name: Nip05Name::try_from(&fields.1[..])?, - address: hex::encode(fields.3), - event: hex::encode(fields.2), - event_created: fields.4, - last_success: fields.5, - last_failure: fields.6, - failure_count: fields.7, - }; - Ok(vr) -} - #[cfg(test)] mod tests { use super::*; @@ -762,7 +530,7 @@ mod tests { fn local_from_inet() { let addr = "bob@example.com"; let parsed = Nip05Name::try_from(addr); - assert!(!parsed.is_err()); + assert!(parsed.is_ok()); let v = parsed.unwrap(); assert_eq!(v.local, "bob"); assert_eq!(v.domain, "example.com"); diff --git a/src/notice.rs b/src/notice.rs index f229eb5..f780683 100644 --- a/src/notice.rs +++ b/src/notice.rs @@ -19,18 +19,14 @@ pub enum Notice { } impl EventResultStatus { - pub fn to_bool(&self) -> bool { + #[must_use] pub fn to_bool(&self) -> bool { match self { - Self::Saved => true, - Self::Duplicate => true, - Self::Invalid => false, - Self::Blocked => false, - Self::RateLimited => false, - Self::Error => false, + Self::Duplicate | Self::Saved => true, + Self::Invalid |Self::Blocked | Self::RateLimited | Self::Error => false, } } - pub fn prefix(&self) -> &'static str { + #[must_use] pub fn prefix(&self) -> &'static str { match self { Self::Saved => "saved", Self::Duplicate => "duplicate", @@ -47,7 +43,7 @@ impl Notice { // Notice::err_msg(format!("{}", err), id) //} - pub fn message(msg: String) -> Notice { + #[must_use] pub fn message(msg: String) -> Notice { Notice::Message(msg) } @@ -56,27 +52,27 @@ impl Notice { Notice::EventResult(EventResult { id, msg, status }) } - pub fn invalid(id: String, msg: &str) -> Notice { + #[must_use] pub fn invalid(id: String, msg: &str) -> Notice { Notice::prefixed(id, msg, EventResultStatus::Invalid) } - pub fn blocked(id: String, msg: &str) -> Notice { + #[must_use] pub fn blocked(id: String, msg: &str) -> Notice { Notice::prefixed(id, msg, EventResultStatus::Blocked) } - pub fn rate_limited(id: String, msg: &str) -> Notice { + #[must_use] pub fn rate_limited(id: String, msg: &str) -> Notice { Notice::prefixed(id, msg, EventResultStatus::RateLimited) } - pub fn duplicate(id: String) -> Notice { + #[must_use] pub fn duplicate(id: String) -> Notice { Notice::prefixed(id, "", EventResultStatus::Duplicate) } - pub fn error(id: String, msg: &str) -> Notice { + #[must_use] pub fn error(id: String, msg: &str) -> Notice { Notice::prefixed(id, msg, EventResultStatus::Error) } - pub fn saved(id: String) -> Notice { + #[must_use] pub fn saved(id: String) -> Notice { Notice::EventResult(EventResult { id, msg: "".into(), diff --git a/src/repo/mod.rs b/src/repo/mod.rs new file mode 100644 index 0000000..191409f --- /dev/null +++ b/src/repo/mod.rs @@ -0,0 +1,67 @@ +use crate::db::QueryResult; +use crate::error::Result; +use crate::event::Event; +use crate::nip05::VerificationRecord; +use crate::subscription::Subscription; +use crate::utils::unix_time; +use async_trait::async_trait; +use rand::Rng; + +pub mod sqlite; +pub mod sqlite_migration; + +#[async_trait] +pub trait NostrRepo: Send + Sync { + /// Start the repository (any initialization or maintenance tasks can be kicked off here) + async fn start(&self) -> Result<()>; + + /// Run migrations and return current version + async fn migrate_up(&self) -> Result; + + /// Persist event to database + async fn write_event(&self, e: &Event) -> Result; + + /// Perform a database query using a subscription. + /// + /// The [`Subscription`] is converted into a SQL query. Each result + /// is published on the `query_tx` channel as it is returned. If a + /// message becomes available on the `abandon_query_rx` channel, the + /// query is immediately aborted. + async fn query_subscription( + &self, + sub: Subscription, + client_id: String, + query_tx: tokio::sync::mpsc::Sender, + mut abandon_query_rx: tokio::sync::oneshot::Receiver<()>, + ) -> Result<()>; + + /// Perform normal maintenance + async fn optimize_db(&self) -> Result<()>; + + /// Create a new verification record connected to a specific event + async fn create_verification_record(&self, event_id: &str, name: &str) -> Result<()>; + + /// Update verification timestamp + async fn update_verification_timestamp(&self, id: u64) -> Result<()>; + + /// Update verification record as failed + async fn fail_verification(&self, id: u64) -> Result<()>; + + /// Delete verification record + async fn delete_verification(&self, id: u64) -> Result<()>; + + /// Get the latest verification record for a given pubkey. + async fn get_latest_user_verification(&self, pub_key: &str) -> Result; + + /// Get oldest verification before timestamp + async fn get_oldest_user_verification(&self, before: u64) -> Result; +} + +// Current time, with a slight forward jitter in seconds +pub(crate) fn now_jitter(sec: u64) -> u64 { + // random time between now, and 10min in future. + let mut rng = rand::thread_rng(); + let jitter_amount = rng.gen_range(0..sec); + let now = unix_time(); + now.saturating_add(jitter_amount) +} diff --git a/src/repo/sqlite.rs b/src/repo/sqlite.rs new file mode 100644 index 0000000..c610812 --- /dev/null +++ b/src/repo/sqlite.rs @@ -0,0 +1,948 @@ +//! Event persistence and querying +//use crate::config::SETTINGS; +use crate::config::Settings; +use crate::error::Result; +use crate::event::{single_char_tagname, Event}; +use crate::hexrange::hex_range; +use crate::hexrange::HexSearch; +use crate::repo::sqlite_migration::{STARTUP_SQL,upgrade_db}; +use crate::utils::{is_hex, is_lower_hex}; +use crate::nip05::{Nip05Name, VerificationRecord}; +use crate::subscription::{ReqFilter, Subscription}; +use hex; +use r2d2; +use r2d2_sqlite::SqliteConnectionManager; +use rusqlite::params; +use rusqlite::types::ToSql; +use rusqlite::OpenFlags; +use tokio::sync::{Mutex, MutexGuard}; +use std::fmt::Write as _; +use std::path::Path; +use std::sync::Arc; +use std::thread; +use std::time::Duration; +use std::time::Instant; +use tokio::task; +use tracing::{debug, info, trace, warn}; +use async_trait::async_trait; +use crate::db::QueryResult; + +use crate::repo::{now_jitter, NostrRepo}; + +pub type SqlitePool = r2d2::Pool; +pub type PooledConnection = r2d2::PooledConnection; +pub const DB_FILE: &str = "nostr.db"; + +#[derive(Clone)] +pub struct SqliteRepo { + /// Pool for reading events and NIP-05 status + read_pool: SqlitePool, + /// Pool for writing events and NIP-05 verification + write_pool: SqlitePool, + /// Pool for performing checkpoints/optimization + maint_pool: SqlitePool, + /// Flag to indicate a checkpoint is underway + checkpoint_in_progress: Arc>, + /// Flag to limit writer concurrency + write_in_progress: Arc>, +} + +impl SqliteRepo { + // build all the pools needed + #[must_use] pub fn new(settings: &Settings) -> SqliteRepo { + let maint_pool = build_pool( + "maintenance", + settings, + OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE, + 1, + 2, + true, + ); + let read_pool = build_pool( + "reader", + settings, + OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE, + settings.database.min_conn, + settings.database.max_conn, + true, + ); + let write_pool = build_pool( + "writer", + settings, + OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE, + 1, + 2, + false, + ); + + // this is used to block new reads during critical checkpoints + let checkpoint_in_progress = Arc::new(Mutex::new(0)); + // SQLite can only effectively write single threaded, so don't + // block multiple worker threads unnecessarily. + let write_in_progress = Arc::new(Mutex::new(0)); + + SqliteRepo { + read_pool, + write_pool, + maint_pool, + checkpoint_in_progress, + write_in_progress, + } + } + + /// Persist an event to the database, returning rows added. + pub fn persist_event(conn: &mut PooledConnection, e: &Event) -> Result { + // enable auto vacuum + conn.execute_batch("pragma auto_vacuum = FULL")?; + + // start transaction + let tx = conn.transaction()?; + // get relevant fields from event and convert to blobs. + let id_blob = hex::decode(&e.id).ok(); + let pubkey_blob: Option> = hex::decode(&e.pubkey).ok(); + let delegator_blob: Option> = e.delegated_by.as_ref().and_then(|d| hex::decode(d).ok()); + let event_str = serde_json::to_string(&e).ok(); + // check for replaceable events that would hide this one; we won't even attempt to insert these. + if e.is_replaceable() { + let repl_count = tx.query_row( + "SELECT e.id FROM event e INDEXED BY author_index WHERE e.author=? AND e.kind=? AND e.created_at > ? LIMIT 1;", + params![pubkey_blob, e.kind, e.created_at], |row| row.get::(0)); + if repl_count.ok().is_some() { + return Ok(0); + } + } + // ignore if the event hash is a duplicate. + let mut ins_count = tx.execute( + "INSERT OR IGNORE INTO event (event_hash, created_at, kind, author, delegated_by, content, first_seen, hidden) VALUES (?1, ?2, ?3, ?4, ?5, ?6, strftime('%s','now'), FALSE);", + params![id_blob, e.created_at, e.kind, pubkey_blob, delegator_blob, event_str] + )? as u64; + if ins_count == 0 { + // if the event was a duplicate, no need to insert event or + // pubkey references. + tx.rollback().ok(); + return Ok(ins_count); + } + // remember primary key of the event most recently inserted. + let ev_id = tx.last_insert_rowid(); + // add all tags to the tag table + for tag in &e.tags { + // ensure we have 2 values. + if tag.len() >= 2 { + let tagname = &tag[0]; + let tagval = &tag[1]; + // only single-char tags are searchable + let tagchar_opt = single_char_tagname(tagname); + match &tagchar_opt { + Some(_) => { + // if tagvalue is lowercase hex; + if is_lower_hex(tagval) && (tagval.len() % 2 == 0) { + tx.execute( + "INSERT OR IGNORE INTO tag (event_id, name, value_hex) VALUES (?1, ?2, ?3)", + params![ev_id, &tagname, hex::decode(tagval).ok()], + )?; + } else { + tx.execute( + "INSERT OR IGNORE INTO tag (event_id, name, value) VALUES (?1, ?2, ?3)", + params![ev_id, &tagname, &tagval], + )?; + } + } + None => {} + } + } + } + // if this event is replaceable update, remove other replaceable + // event with the same kind from the same author that was issued + // earlier than this. + if e.is_replaceable() { + let author = hex::decode(&e.pubkey).ok(); + // this is a backwards check - hide any events that were older. + let update_count = tx.execute( + "DELETE FROM event WHERE kind=? and author=? and id NOT IN (SELECT id FROM event INDEXED BY author_kind_index WHERE kind=? AND author=? ORDER BY created_at DESC LIMIT 1)", + params![e.kind, author, e.kind, author], + )?; + if update_count > 0 { + info!( + "removed {} older replaceable kind {} events for author: {:?}", + update_count, + e.kind, + e.get_author_prefix() + ); + } + } + // if this event is a deletion, hide the referenced events from the same author. + if e.kind == 5 { + let event_candidates = e.tag_values_by_name("e"); + // first parameter will be author + let mut params: Vec> = vec![Box::new(hex::decode(&e.pubkey)?)]; + event_candidates + .iter() + .filter(|x| is_hex(x) && x.len() == 64) + .filter_map(|x| hex::decode(x).ok()) + .for_each(|x| params.push(Box::new(x))); + let query = format!( + "UPDATE event SET hidden=TRUE WHERE kind!=5 AND author=? AND event_hash IN ({})", + repeat_vars(params.len() - 1) + ); + let mut stmt = tx.prepare(&query)?; + let update_count = stmt.execute(rusqlite::params_from_iter(params))?; + info!( + "hid {} deleted events for author {:?}", + update_count, + e.get_author_prefix() + ); + } else { + // check if a deletion has already been recorded for this event. + // Only relevant for non-deletion events + let del_count = tx.query_row( + "SELECT e.id FROM event e LEFT JOIN tag t ON e.id=t.event_id WHERE e.author=? AND t.name='e' AND e.kind=5 AND t.value_hex=? LIMIT 1;", + params![pubkey_blob, id_blob], |row| row.get::(0)); + // check if a the query returned a result, meaning we should + // hid the current event + if del_count.ok().is_some() { + // a deletion already existed, mark original event as hidden. + info!( + "hid event: {:?} due to existing deletion by author: {:?}", + e.get_event_id_prefix(), + e.get_author_prefix() + ); + let _update_count = + tx.execute("UPDATE event SET hidden=TRUE WHERE id=?", params![ev_id])?; + // event was deleted, so let caller know nothing new + // arrived, preventing this from being sent to active + // subscriptions + ins_count = 0; + } + } + tx.commit()?; + Ok(ins_count) + } +} + +#[async_trait] +impl NostrRepo for SqliteRepo { + + async fn start(&self) -> Result<()> { + db_checkpoint_task(self.maint_pool.clone(), Duration::from_secs(60), self.checkpoint_in_progress.clone()).await + } + + async fn migrate_up(&self) -> Result { + let _write_guard = self.write_in_progress.lock().await; + let mut conn = self.write_pool.get()?; + task::spawn_blocking(move || { + upgrade_db(&mut conn) + }).await? + } + /// Persist event to database + async fn write_event(&self, e: &Event) -> Result { + let _write_guard = self.write_in_progress.lock().await; + // spawn a blocking thread + let mut conn = self.write_pool.get()?; + let e = e.clone(); + task::spawn_blocking(move || { + SqliteRepo::persist_event(&mut conn, &e) + }).await? + } + + /// Perform a database query using a subscription. + /// + /// The [`Subscription`] is converted into a SQL query. Each result + /// is published on the `query_tx` channel as it is returned. If a + /// message becomes available on the `abandon_query_rx` channel, the + /// query is immediately aborted. + async fn query_subscription( + &self, + sub: Subscription, + client_id: String, + query_tx: tokio::sync::mpsc::Sender, + mut abandon_query_rx: tokio::sync::oneshot::Receiver<()>, + ) -> Result<()> { + let pre_spawn_start = Instant::now(); + let self=self.clone(); + task::spawn_blocking(move || { + { + // if we are waiting on a checkpoint, stop until it is complete + let _x = self.checkpoint_in_progress.blocking_lock(); + } + let db_queue_time = pre_spawn_start.elapsed(); + // if the queue time was very long (>5 seconds), spare the DB and abort. + if db_queue_time > Duration::from_secs(5) { + info!( + "shedding DB query load queued for {:?} (cid: {}, sub: {:?})", + db_queue_time, client_id, sub.id + ); + return Ok(()); + } + // otherwise, report queuing time if it is slow + else if db_queue_time > Duration::from_secs(1) { + debug!( + "(slow) DB query queued for {:?} (cid: {}, sub: {:?})", + db_queue_time, client_id, sub.id + ); + } + let start = Instant::now(); + let mut row_count: usize = 0; + // generate SQL query + let (q, p, idxs) = query_from_sub(&sub); + let sql_gen_elapsed = start.elapsed(); + + if sql_gen_elapsed > Duration::from_millis(10) { + debug!("SQL (slow) generated in {:?}", start.elapsed()); + } + // cutoff for displaying slow queries + let slow_cutoff = Duration::from_millis(2000); + // any client that doesn't cause us to generate new rows in 5 + // seconds gets dropped. + let abort_cutoff = Duration::from_secs(5); + let start = Instant::now(); + let mut slow_first_event; + let mut last_successful_send = Instant::now(); + if let Ok(mut conn) = self.read_pool.get() { + // execute the query. + // make the actual SQL query (with parameters inserted) available + conn.trace(Some(|x| {trace!("SQL trace: {:?}", x)})); + let mut stmt = conn.prepare_cached(&q)?; + let mut event_rows = stmt.query(rusqlite::params_from_iter(p))?; + + let mut first_result = true; + while let Some(row) = event_rows.next()? { + let first_event_elapsed = start.elapsed(); + slow_first_event = first_event_elapsed >= slow_cutoff; + if first_result { + debug!( + "first result in {:?} (cid: {}, sub: {:?}) [used indexes: {:?}]", + first_event_elapsed, client_id, sub.id, idxs + ); + first_result = false; + } + // logging for slow queries; show sub and SQL. + // to reduce logging; only show 1/16th of clients (leading 0) + if row_count == 0 && slow_first_event && client_id.starts_with('0') { + debug!( + "query req (slow): {:?} (cid: {}, sub: {:?})", + sub, client_id, sub.id + ); + } + // check if a checkpoint is trying to run, and abort + if row_count % 100 == 0 { + { + if self.checkpoint_in_progress.try_lock().is_err() { + // lock was held, abort this query + debug!("query aborted due to checkpoint (cid: {}, sub: {:?})", client_id, sub.id); + return Ok(()); + } + } + } + + // check if this is still active; every 100 rows + if row_count % 100 == 0 && abandon_query_rx.try_recv().is_ok() { + debug!("query aborted (cid: {}, sub: {:?})", client_id, sub.id); + return Ok(()); + } + row_count += 1; + let event_json = row.get(0)?; + loop { + if query_tx.capacity() != 0 { + // we have capacity to add another item + break; + } + // the queue is full + trace!("db reader thread is stalled"); + if last_successful_send + abort_cutoff < Instant::now() { + // the queue has been full for too long, abort + info!("aborting database query due to slow client (cid: {}, sub: {:?})", + client_id, sub.id); + let ok: Result<()> = Ok(()); + return ok; + } + // check if a checkpoint is trying to run, and abort + if self.checkpoint_in_progress.try_lock().is_err() { + // lock was held, abort this query + debug!("query aborted due to checkpoint (cid: {}, sub: {:?})", client_id, sub.id); + return Ok(()); + } + // give the queue a chance to clear before trying again + thread::sleep(Duration::from_millis(100)); + } + // TODO: we could use try_send, but we'd have to juggle + // getting the query result back as part of the error + // result. + query_tx + .blocking_send(QueryResult { + sub_id: sub.get_id(), + event: event_json, + }) + .ok(); + last_successful_send = Instant::now(); + } + query_tx + .blocking_send(QueryResult { + sub_id: sub.get_id(), + event: "EOSE".to_string(), + }) + .ok(); + debug!( + "query completed in {:?} (cid: {}, sub: {:?}, db_time: {:?}, rows: {})", + pre_spawn_start.elapsed(), + client_id, + sub.id, + start.elapsed(), + row_count + ); + } else { + warn!("Could not get a database connection for querying"); + } + let ok: Result<()> = Ok(()); + ok + }); + Ok(()) + } + + /// Perform normal maintenance + async fn optimize_db(&self) -> Result<()> { + let conn = self.write_pool.get()?; + task::spawn_blocking(move || { + let start = Instant::now(); + conn.execute_batch("PRAGMA optimize;").ok(); + info!("optimize ran in {:?}", start.elapsed()); + }).await?; + Ok(()) + } + + /// Create a new verification record connected to a specific event + async fn create_verification_record(&self, event_id: &str, name: &str) -> Result<()> { + let e = hex::decode(event_id).ok(); + let n = name.to_owned(); + let mut conn = self.write_pool.get()?; + tokio::task::spawn_blocking(move || { + let tx = conn.transaction()?; + { + // if we create a /new/ one, we should get rid of any old ones. or group the new ones by name and only consider the latest. + let query = "INSERT INTO user_verification (metadata_event, name, verified_at) VALUES ((SELECT id from event WHERE event_hash=?), ?, strftime('%s','now'));"; + let mut stmt = tx.prepare(query)?; + stmt.execute(params![e, n])?; + // get the row ID + let v_id = tx.last_insert_rowid(); + // delete everything else by this name + let del_query = "DELETE FROM user_verification WHERE name = ? AND id != ?;"; + let mut del_stmt = tx.prepare(del_query)?; + let count = del_stmt.execute(params![n,v_id])?; + if count > 0 { + info!("removed {} old verification records for ({:?})", count, n); + } + } + tx.commit()?; + info!("saved new verification record for ({:?})", n); + let ok: Result<()> = Ok(()); + ok + }).await? + } + + /// Update verification timestamp + async fn update_verification_timestamp(&self, id: u64) -> Result<()> { + let mut conn = self.write_pool.get()?; + tokio::task::spawn_blocking(move || { + // add some jitter to the verification to prevent everything from stacking up together. + let verif_time = now_jitter(600); + let tx = conn.transaction()?; + { + // update verification time and reset any failure count + let query = + "UPDATE user_verification SET verified_at=?, failure_count=0 WHERE id=?"; + let mut stmt = tx.prepare(query)?; + stmt.execute(params![verif_time, id])?; + } + tx.commit()?; + let ok: Result<()> = Ok(()); + ok + }) + .await? + + } + + /// Update verification record as failed + async fn fail_verification(&self, id: u64) -> Result<()> { + let mut conn = self.write_pool.get()?; + tokio::task::spawn_blocking(move || { + // add some jitter to the verification to prevent everything from stacking up together. + let fail_time = now_jitter(600); + let tx = conn.transaction()?; + { + let query = "UPDATE user_verification SET failed_at=?, failure_count=failure_count+1 WHERE id=?"; + let mut stmt = tx.prepare(query)?; + stmt.execute(params![fail_time, id])?; + } + tx.commit()?; + let ok: Result<()> = Ok(()); + ok + }) + .await? + } + + /// Delete verification record + async fn delete_verification(&self, id: u64) -> Result<()> { + let mut conn = self.write_pool.get()?; + tokio::task::spawn_blocking(move || { + let tx = conn.transaction()?; + { + let query = "DELETE FROM user_verification WHERE id=?;"; + let mut stmt = tx.prepare(query)?; + stmt.execute(params![id])?; + } + tx.commit()?; + let ok: Result<()> = Ok(()); + ok + }) + .await? + } + + /// Get the latest verification record for a given pubkey. + async fn get_latest_user_verification(&self, pub_key: &str) -> Result { + let mut conn = self.read_pool.get()?; + let pub_key = pub_key.to_owned(); + tokio::task::spawn_blocking(move || { + let tx = conn.transaction()?; + let query = "SELECT v.id, v.name, e.event_hash, e.created_at, v.verified_at, v.failed_at, v.failure_count FROM user_verification v LEFT JOIN event e ON e.id=v.metadata_event WHERE e.author=? ORDER BY e.created_at DESC, v.verified_at DESC, v.failed_at DESC LIMIT 1;"; + let mut stmt = tx.prepare_cached(query)?; + let fields = stmt.query_row(params![hex::decode(&pub_key).ok()], |r| { + let rowid: u64 = r.get(0)?; + let rowname: String = r.get(1)?; + let eventid: Vec = r.get(2)?; + let created_at: u64 = r.get(3)?; + // create a tuple since we can't throw non-rusqlite errors in this closure + Ok(( + rowid, + rowname, + eventid, + created_at, + r.get(4).ok(), + r.get(5).ok(), + r.get(6)?, + )) + })?; + Ok(VerificationRecord { + rowid: fields.0, + name: Nip05Name::try_from(&fields.1[..])?, + address: pub_key, + event: hex::encode(fields.2), + event_created: fields.3, + last_success: fields.4, + last_failure: fields.5, + failure_count: fields.6, + }) + }).await? + } + + /// Get oldest verification before timestamp + async fn get_oldest_user_verification(&self, before: u64) -> Result { + let mut conn = self.read_pool.get()?; + tokio::task::spawn_blocking(move || { + let tx = conn.transaction()?; + let query = "SELECT v.id, v.name, e.event_hash, e.author, e.created_at, v.verified_at, v.failed_at, v.failure_count FROM user_verification v INNER JOIN event e ON e.id=v.metadata_event WHERE (v.verified_at < ? OR v.verified_at IS NULL) AND (v.failed_at < ? OR v.failed_at IS NULL) ORDER BY v.verified_at ASC, v.failed_at ASC LIMIT 1;"; + let mut stmt = tx.prepare_cached(query)?; + let fields = stmt.query_row(params![before, before], |r| { + let rowid: u64 = r.get(0)?; + let rowname: String = r.get(1)?; + let eventid: Vec = r.get(2)?; + let pubkey: Vec = r.get(3)?; + let created_at: u64 = r.get(4)?; + // create a tuple since we can't throw non-rusqlite errors in this closure + Ok(( + rowid, + rowname, + eventid, + pubkey, + created_at, + r.get(5).ok(), + r.get(6).ok(), + r.get(7)?, + )) + })?; + let vr = VerificationRecord { + rowid: fields.0, + name: Nip05Name::try_from(&fields.1[..])?, + address: hex::encode(fields.3), + event: hex::encode(fields.2), + event_created: fields.4, + last_success: fields.5, + last_failure: fields.6, + failure_count: fields.7, + }; + Ok(vr) + }).await? + } +} + +/// Decide if there is an index that should be used explicitly +fn override_index(f: &ReqFilter) -> Option { + // queries for multiple kinds default to kind_index, which is + // significantly slower than kind_created_at_index. + if let Some(ks) = &f.kinds { + if f.ids.is_none() && + ks.len() > 1 && + f.since.is_none() && + f.until.is_none() && + f.tags.is_none() && + f.authors.is_none() { + return Some("kind_created_at_index".into()); + } + } + // if there is an author, it is much better to force the authors index. + if f.authors.is_some() { + if f.since.is_none() && f.until.is_none() { + if f.kinds.is_none() { + // with no use of kinds/created_at, just author + return Some("author_index".into()); + } + // prefer author_kind if there are kinds + return Some("author_kind_index".into()); + } + // finally, prefer author_created_at if time is provided + return Some("author_created_at_index".into()); + } + None +} + +/// Create a dynamic SQL subquery and params from a subscription filter (and optional explicit index used) +fn query_from_filter(f: &ReqFilter) -> (String, Vec>, Option) { + // build a dynamic SQL query. all user-input is either an integer + // (sqli-safe), or a string that is filtered to only contain + // hexadecimal characters. Strings that require escaping (tag + // names/values) use parameters. + + // if the filter is malformed, don't return anything. + if f.force_no_match { + let empty_query = "SELECT e.content, e.created_at FROM event e WHERE 1=0".to_owned(); + // query parameters for SQLite + let empty_params: Vec> = vec![]; + return (empty_query, empty_params, None); + } + + // check if the index needs to be overriden + let idx_name = override_index(f); + let idx_stmt = idx_name.as_ref().map_or_else(|| "".to_owned(), |i| format!("INDEXED BY {}",i)); + let mut query = format!("SELECT e.content, e.created_at FROM event e {}", idx_stmt); + // query parameters for SQLite + let mut params: Vec> = vec![]; + + // individual filter components (single conditions such as an author or event ID) + let mut filter_components: Vec = Vec::new(); + // Query for "authors", allowing prefix matches + if let Some(authvec) = &f.authors { + // take each author and convert to a hexsearch + let mut auth_searches: Vec = vec![]; + for auth in authvec { + match hex_range(auth) { + Some(HexSearch::Exact(ex)) => { + auth_searches.push("author=?".to_owned()); + params.push(Box::new(ex)); + } + Some(HexSearch::Range(lower, upper)) => { + auth_searches.push( + "(author>? AND author { + auth_searches.push("author>?".to_owned()); + params.push(Box::new(lower)); + } + None => { + info!("Could not parse hex range from author {:?}", auth); + } + } + } + if !authvec.is_empty() { + let auth_clause = format!("({})", auth_searches.join(" OR ")); + filter_components.push(auth_clause); + } else { + filter_components.push("false".to_owned()); + } + } + // Query for Kind + if let Some(ks) = &f.kinds { + // kind is number, no escaping needed + let str_kinds: Vec = ks.iter().map(std::string::ToString::to_string).collect(); + let kind_clause = format!("kind IN ({})", str_kinds.join(", ")); + filter_components.push(kind_clause); + } + // Query for event, allowing prefix matches + if let Some(idvec) = &f.ids { + // take each author and convert to a hexsearch + let mut id_searches: Vec = vec![]; + for id in idvec { + match hex_range(id) { + Some(HexSearch::Exact(ex)) => { + id_searches.push("event_hash=?".to_owned()); + params.push(Box::new(ex)); + } + Some(HexSearch::Range(lower, upper)) => { + id_searches.push("(event_hash>? AND event_hash { + id_searches.push("event_hash>?".to_owned()); + params.push(Box::new(lower)); + } + None => { + info!("Could not parse hex range from id {:?}", id); + } + } + } + if idvec.is_empty() { + // if the ids list was empty, we should never return + // any results. + filter_components.push("false".to_owned()); + } else { + let id_clause = format!("({})", id_searches.join(" OR ")); + filter_components.push(id_clause); + } + } + // Query for tags + if let Some(map) = &f.tags { + for (key, val) in map.iter() { + let mut str_vals: Vec> = vec![]; + let mut blob_vals: Vec> = vec![]; + for v in val { + if (v.len() % 2 == 0) && is_lower_hex(v) { + if let Ok(h) = hex::decode(v) { + blob_vals.push(Box::new(h)); + } + } else { + str_vals.push(Box::new(v.clone())); + } + } + // create clauses with "?" params for each tag value being searched + let str_clause = format!("value IN ({})", repeat_vars(str_vals.len())); + let blob_clause = format!("value_hex IN ({})", repeat_vars(blob_vals.len())); + // find evidence of the target tag name/value existing for this event. + let tag_clause = format!( + "e.id IN (SELECT e.id FROM event e LEFT JOIN tag t on e.id=t.event_id WHERE hidden!=TRUE and (name=? AND ({} OR {})))", + str_clause, blob_clause + ); + // add the tag name as the first parameter + params.push(Box::new(key.to_string())); + // add all tag values that are plain strings as params + params.append(&mut str_vals); + // add all tag values that are blobs as params + params.append(&mut blob_vals); + filter_components.push(tag_clause); + } + } + // Query for timestamp + if f.since.is_some() { + let created_clause = format!("created_at > {}", f.since.unwrap()); + filter_components.push(created_clause); + } + // Query for timestamp + if f.until.is_some() { + let until_clause = format!("created_at < {}", f.until.unwrap()); + filter_components.push(until_clause); + } + // never display hidden events + query.push_str(" WHERE hidden!=TRUE"); + // build filter component conditions + if !filter_components.is_empty() { + query.push_str(" AND "); + query.push_str(&filter_components.join(" AND ")); + } + // Apply per-filter limit to this subquery. + // The use of a LIMIT implies a DESC order, to capture only the most recent events. + if let Some(lim) = f.limit { + let _ = write!(query, " ORDER BY e.created_at DESC LIMIT {}", lim); + } else { + query.push_str(" ORDER BY e.created_at ASC"); + } + (query, params, idx_name) +} + +/// Create a dynamic SQL query string and params from a subscription. +fn query_from_sub(sub: &Subscription) -> (String, Vec>, Vec) { + // build a dynamic SQL query for an entire subscription, based on + // SQL subqueries for filters. + let mut subqueries: Vec = Vec::new(); + let mut indexes = vec![]; + // subquery params + let mut params: Vec> = vec![]; + // for every filter in the subscription, generate a subquery + for f in &sub.filters { + let (f_subquery, mut f_params, index) = query_from_filter(f); + if let Some(i) = index { + indexes.push(i); + } + subqueries.push(f_subquery); + params.append(&mut f_params); + } + // encapsulate subqueries into select statements + let subqueries_selects: Vec = subqueries + .iter() + .map(|s| format!("SELECT distinct content, created_at FROM ({})", s)) + .collect(); + let query: String = subqueries_selects.join(" UNION "); + (query, params,indexes) +} + +/// Build a database connection pool. +/// # Panics +/// +/// Will panic if the pool could not be created. +#[must_use] +pub fn build_pool( + name: &str, + settings: &Settings, + flags: OpenFlags, + min_size: u32, + max_size: u32, + wait_for_db: bool, +) -> SqlitePool { + let db_dir = &settings.database.data_directory; + let full_path = Path::new(db_dir).join(DB_FILE); + // small hack; if the database doesn't exist yet, that means the + // writer thread hasn't finished. Give it a chance to work. This + // is only an issue with the first time we run. + if !settings.database.in_memory { + while !full_path.exists() && wait_for_db { + debug!("Database reader pool is waiting on the database to be created..."); + thread::sleep(Duration::from_millis(500)); + } + } + let manager = if settings.database.in_memory { + SqliteConnectionManager::memory() + .with_flags(flags) + .with_init(|c| c.execute_batch(STARTUP_SQL)) + } else { + SqliteConnectionManager::file(&full_path) + .with_flags(flags) + .with_init(|c| c.execute_batch(STARTUP_SQL)) + }; + let pool: SqlitePool = r2d2::Pool::builder() + .test_on_check_out(true) // no noticeable performance hit + .min_idle(Some(min_size)) + .max_size(max_size) + .max_lifetime(Some(Duration::from_secs(30))) + .build(manager) + .unwrap(); + info!( + "Built a connection pool {:?} (min={}, max={})", + name, min_size, max_size + ); + pool +} + +/// Perform database WAL checkpoint on a regular basis +pub async fn db_checkpoint_task(pool: SqlitePool, frequency: Duration, checkpoint_in_progress: Arc>) -> Result<()> { + + tokio::task::spawn(async move { + // WAL size in pages. + let mut current_wal_size = 0; + // WAL threshold for more aggressive checkpointing (10,000 pages, or about 40MB) + let wal_threshold = 1000*10; + // default threshold for the busy timer + let busy_wait_default = Duration::from_secs(1); + // if the WAL file is getting too big, switch to this + let busy_wait_default_long = Duration::from_secs(10); + loop { + tokio::select! { + _ = tokio::time::sleep(frequency) => { + if let Ok(mut conn) = pool.get() { + let mut _guard:Option> = None; + // the busy timer will block writers, so don't set + // this any higher than you want max latency for event + // writes. + if current_wal_size <= wal_threshold { + conn.busy_timeout(busy_wait_default).ok(); + } else { + // if the wal size has exceeded a threshold, increase the busy timeout. + conn.busy_timeout(busy_wait_default_long).ok(); + // take a lock that will prevent new readers. + info!("blocking new readers to perform wal_checkpoint"); + _guard = Some(checkpoint_in_progress.lock().await); + } + debug!("running wal_checkpoint(TRUNCATE)"); + if let Ok(new_size) = checkpoint_db(&mut conn) { + current_wal_size = new_size; + } + } + } + }; + } + }); + + Ok(()) +} + +#[derive(Debug)] +enum SqliteStatus { + Ok, + Busy, + Error, + Other(u64), +} + +/// Checkpoint/Truncate WAL. Returns the number of WAL pages remaining. +pub fn checkpoint_db(conn: &mut PooledConnection) -> Result { + let query = "PRAGMA wal_checkpoint(TRUNCATE);"; + let start = Instant::now(); + let (cp_result, wal_size, _frames_checkpointed) = conn.query_row(query, [], |row| { + let checkpoint_result: u64 = row.get(0)?; + let wal_size: u64 = row.get(1)?; + let frames_checkpointed: u64 = row.get(2)?; + Ok((checkpoint_result, wal_size, frames_checkpointed)) + })?; + let result = match cp_result { + 0 => SqliteStatus::Ok, + 1 => SqliteStatus::Busy, + 2 => SqliteStatus::Error, + x => SqliteStatus::Other(x), + }; + info!( + "checkpoint ran in {:?} (result: {:?}, WAL size: {})", + start.elapsed(), + result, + wal_size + ); + Ok(wal_size as usize) +} + + +/// Produce a arbitrary list of '?' parameters. +fn repeat_vars(count: usize) -> String { + if count == 0 { + return "".to_owned(); + } + let mut s = "?,".repeat(count); + // Remove trailing comma + s.pop(); + s +} + +/// Display database pool stats every 1 minute +pub async fn monitor_pool(name: &str, pool: SqlitePool) { + let sleep_dur = Duration::from_secs(60); + loop { + log_pool_stats(name, &pool); + tokio::time::sleep(sleep_dur).await; + } +} + +/// Log pool stats +fn log_pool_stats(name: &str, pool: &SqlitePool) { + let state: r2d2::State = pool.state(); + let in_use_cxns = state.connections - state.idle_connections; + debug!( + "DB pool {:?} usage (in_use: {}, available: {}, max: {})", + name, + in_use_cxns, + state.connections, + pool.max_size() + ); +} + + +/// Check if the pool is fully utilized +fn _pool_at_capacity(pool: &SqlitePool) -> bool { + let state: r2d2::State = pool.state(); + state.idle_connections == 0 +} diff --git a/src/schema.rs b/src/repo/sqlite_migration.rs similarity index 98% rename from src/schema.rs rename to src/repo/sqlite_migration.rs index 6e86222..d1722b1 100644 --- a/src/schema.rs +++ b/src/repo/sqlite_migration.rs @@ -113,7 +113,7 @@ pub fn db_tag_count(conn: &mut Connection) -> Result { Ok(count) } -fn mig_init(conn: &mut PooledConnection) -> Result { +fn mig_init(conn: &mut PooledConnection) -> usize { match conn.execute_batch(INIT_SQL) { Ok(()) => { info!( @@ -126,11 +126,11 @@ fn mig_init(conn: &mut PooledConnection) -> Result { panic!("database could not be initialized"); } } - Ok(DB_VERSION) + DB_VERSION } /// Upgrade DB to latest version, and execute pragma settings -pub fn upgrade_db(conn: &mut PooledConnection) -> Result<()> { +pub fn upgrade_db(conn: &mut PooledConnection) -> Result { // check the version. let mut curr_version = curr_db_version(conn)?; info!("DB version = {:?}", curr_version); @@ -141,11 +141,11 @@ pub fn upgrade_db(conn: &mut PooledConnection) -> Result<()> { ); debug!( "SQLite max table/blob/text length: {} MB", - (conn.limit(Limit::SQLITE_LIMIT_LENGTH) as f64 / (1024 * 1024) as f64).floor() + (f64::from(conn.limit(Limit::SQLITE_LIMIT_LENGTH)) / f64::from(1024 * 1024)).floor() ); debug!( "SQLite max SQL length: {} MB", - (conn.limit(Limit::SQLITE_LIMIT_SQL_LENGTH) as f64 / (1024 * 1024) as f64).floor() + (f64::from(conn.limit(Limit::SQLITE_LIMIT_SQL_LENGTH)) / f64::from(1024 * 1024)).floor() ); match curr_version.cmp(&DB_VERSION) { @@ -153,7 +153,7 @@ pub fn upgrade_db(conn: &mut PooledConnection) -> Result<()> { Ordering::Less => { // initialize from scratch if curr_version == 0 { - curr_version = mig_init(conn)?; + curr_version = mig_init(conn); } // for initialized but out-of-date schemas, proceed to // upgrade sequentially until we are current. @@ -223,7 +223,7 @@ pub fn upgrade_db(conn: &mut PooledConnection) -> Result<()> { // Setup PRAGMA conn.execute_batch(STARTUP_SQL)?; debug!("SQLite PRAGMA startup completed"); - Ok(()) + Ok(DB_VERSION) } pub fn rebuild_tags(conn: &mut PooledConnection) -> Result<()> { diff --git a/src/server.rs b/src/server.rs index 66e170c..3053db0 100644 --- a/src/server.rs +++ b/src/server.rs @@ -3,6 +3,7 @@ use crate::close::Close; use crate::close::CloseCmd; use crate::config::{Settings, VerifiedUsersMode}; use crate::conn; +use crate::repo::NostrRepo; use crate::db; use crate::db::SubmittedEvent; use crate::error::{Error, Result}; @@ -22,10 +23,8 @@ use hyper::upgrade::Upgraded; use hyper::{ header, server::conn::AddrStream, upgrade, Body, Request, Response, Server, StatusCode, }; -use rusqlite::OpenFlags; use serde::{Deserialize, Serialize}; use serde_json::json; -use tokio::sync::Mutex; use std::collections::HashMap; use std::convert::Infallible; use std::net::SocketAddr; @@ -40,23 +39,22 @@ use tokio::sync::broadcast::{self, Receiver, Sender}; use tokio::sync::mpsc; use tokio::sync::oneshot; use tokio_tungstenite::WebSocketStream; -use tracing::*; +use tracing::{debug, error, info, trace, warn}; use tungstenite::error::CapacityError::MessageTooLong; use tungstenite::error::Error as WsError; use tungstenite::handshake; use tungstenite::protocol::Message; use tungstenite::protocol::WebSocketConfig; -/// Handle arbitrary HTTP requests, including for WebSocket upgrades. +/// Handle arbitrary HTTP requests, including for `WebSocket` upgrades. async fn handle_web_request( mut request: Request, - pool: db::SqlitePool, + repo: Arc, settings: Settings, remote_addr: SocketAddr, broadcast: Sender, event_tx: tokio::sync::mpsc::Sender, shutdown: Receiver<()>, - safe_to_read: Arc>, ) -> Result, Infallible> { match ( request.uri().path(), @@ -111,14 +109,13 @@ async fn handle_web_request( }; // spawn a nostr server with our websocket tokio::spawn(nostr_server( - pool, + repo, client_info, settings, ws_stream, broadcast, event_tx, shutdown, - safe_to_read, )); } // todo: trace, don't print... @@ -184,7 +181,7 @@ async fn handle_web_request( fn get_header_string(header: &str, headers: &HeaderMap) -> Option { headers .get(header) - .and_then(|x| x.to_str().ok().map(|x| x.to_string())) + .and_then(|x| x.to_str().ok().map(std::string::ToString::to_string)) } // return on a control-c or internally requested shutdown signal @@ -211,7 +208,7 @@ async fn ctrl_c_or_signal(mut shutdown_signal: Receiver<()>) { } /// Start running a Nostr relay server. -pub fn start_server(settings: Settings, shutdown_rx: MpscReceiver<()>) -> Result<(), Error> { +pub fn start_server(settings: &Settings, shutdown_rx: MpscReceiver<()>) -> Result<(), Error> { trace!("Config: {:?}", settings); // do some config validation. if !Path::new(&settings.database.data_directory).is_dir() { @@ -274,8 +271,6 @@ pub fn start_server(settings: Settings, shutdown_rx: MpscReceiver<()>) -> Result let broadcast_buffer_limit = settings.limits.broadcast_buffer; let persist_buffer_limit = settings.limits.event_persist_buffer; let verified_users_active = settings.verified_users.is_active(); - let db_min_conn = settings.database.min_conn; - let db_max_conn = settings.database.max_conn; let settings = settings.clone(); info!("listening on: {}", socket_addr); // all client-submitted valid events are broadcast to every @@ -298,23 +293,26 @@ pub fn start_server(settings: Settings, shutdown_rx: MpscReceiver<()>) -> Result // overwhelming this will drop events and won't register // metadata events. let (metadata_tx, metadata_rx) = broadcast::channel::(4096); - // start the database writer thread. Give it a channel for + // build a repository for events + let repo = db::build_repo(&settings).await; + // start the database writer task. Give it a channel for // writing events, and for publishing events that have been // written (to all connected clients). - db::db_writer( - settings.clone(), - event_rx, - bcast_tx.clone(), - metadata_tx.clone(), - shutdown_listen, - ) - .await; + tokio::task::spawn( + db::db_writer( + repo.clone(), + settings.clone(), + event_rx, + bcast_tx.clone(), + metadata_tx.clone(), + shutdown_listen, + )); info!("db writer created"); // create a nip-05 verifier thread; if enabled. if settings.verified_users.mode != VerifiedUsersMode::Disabled { let verifier_opt = - nip05::Verifier::new(metadata_rx, bcast_tx.clone(), settings.clone()); + nip05::Verifier::new(repo.clone(), metadata_rx, bcast_tx.clone(), settings.clone()); if let Ok(mut v) = verifier_opt { if verified_users_active { tokio::task::spawn(async move { @@ -324,35 +322,19 @@ pub fn start_server(settings: Settings, shutdown_rx: MpscReceiver<()>) -> Result } } } - // build a connection pool for DB maintenance - let maintenance_pool = db::build_pool( - "maintenance writer", - &settings, - OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE, - 1, - 2, - false, - ); - - // Create a mutex that will block readers, so that a - // checkpoint can be performed quickly. - let safe_to_read = Arc::new(Mutex::new(0)); - - db::db_optimize_task(maintenance_pool.clone()).await; - db::db_checkpoint_task(maintenance_pool, safe_to_read.clone()).await; // listen for (external to tokio) shutdown request let controlled_shutdown = invoke_shutdown.clone(); tokio::spawn(async move { info!("control message listener started"); + // we only have good "shutdown" messages propagation from this-> controlled shutdown. Not from controlled_shutdown-> this. Which means we have a task that is stuck waiting on a sync receive. recv is blocking, and this is async. match shutdown_rx.recv() { Ok(()) => { info!("control message requesting shutdown"); controlled_shutdown.send(()).ok(); - } + }, Err(std::sync::mpsc::RecvError) => { - // FIXME: spurious error on startup? - debug!("shutdown requestor is disconnected"); + trace!("shutdown requestor is disconnected (this is normal)"); } }; }); @@ -366,41 +348,30 @@ pub fn start_server(settings: Settings, shutdown_rx: MpscReceiver<()>) -> Result info!("shutting down due to SIGINT (main)"); ctrl_c_shutdown.send(()).ok(); }); - // build a connection pool for sqlite connections - let pool = db::build_pool( - "client query", - &settings, - rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY, - db_min_conn, - db_max_conn, - true, - ); // spawn a task to check the pool size. - let pool_monitor = pool.clone(); - tokio::spawn(async move {db::monitor_pool("reader", pool_monitor).await;}); + //let pool_monitor = pool.clone(); + //tokio::spawn(async move {db::monitor_pool("reader", pool_monitor).await;}); // A `Service` is needed for every connection, so this // creates one from our `handle_request` function. let make_svc = make_service_fn(|conn: &AddrStream| { - let svc_pool = pool.clone(); + let repo = repo.clone(); let remote_addr = conn.remote_addr(); let bcast = bcast_tx.clone(); let event = event_tx.clone(); let stop = invoke_shutdown.clone(); let settings = settings.clone(); - let safe_to_read = safe_to_read.clone(); async move { // service_fn converts our function into a `Service` Ok::<_, Infallible>(service_fn(move |request: Request| { handle_web_request( request, - svc_pool.clone(), + repo.clone(), settings.clone(), remote_addr, bcast.clone(), event.clone(), stop.subscribe(), - safe_to_read.clone(), ) })) } @@ -428,9 +399,9 @@ pub enum NostrMessage { CloseMsg(CloseCmd), } -/// Convert Message to NostrMessage -fn convert_to_msg(msg: String, max_bytes: Option) -> Result { - let parsed_res: Result = serde_json::from_str(&msg).map_err(|e| e.into()); +/// Convert Message to `NostrMessage` +fn convert_to_msg(msg: &str, max_bytes: Option) -> Result { + let parsed_res: Result = serde_json::from_str(msg).map_err(std::convert::Into::into); match parsed_res { Ok(m) => { if let NostrMessage::SubMsg(_) = m { @@ -455,8 +426,8 @@ fn convert_to_msg(msg: String, max_bytes: Option) -> Result } } -/// Turn a string into a NOTICE message ready to send over a WebSocket -fn make_notice_message(notice: Notice) -> Message { +/// Turn a string into a NOTICE message ready to send over a `WebSocket` +fn make_notice_message(notice: &Notice) -> Message { let json = match notice { Notice::Message(ref msg) => json!(["NOTICE", msg]), Notice::EventResult(ref res) => json!(["OK", res.id, res.status.to_bool(), res.msg]), @@ -474,14 +445,13 @@ struct ClientInfo { /// Handle new client connections. This runs through an event loop /// for all client communication. async fn nostr_server( - pool: db::SqlitePool, + repo: Arc, client_info: ClientInfo, settings: Settings, mut ws_stream: WebSocketStream, broadcast: Sender, event_tx: mpsc::Sender, mut shutdown: Receiver<()>, - safe_to_read: Arc>, ) { // the time this websocket nostr server started let orig_start = Instant::now(); @@ -559,7 +529,7 @@ async fn nostr_server( ws_stream.send(Message::Ping(Vec::new())).await.ok(); }, Some(notice_msg) = notice_rx.recv() => { - ws_stream.send(make_notice_message(notice_msg)).await.ok(); + ws_stream.send(make_notice_message(¬ice_msg)).await.ok(); }, Some(query_result) = query_rx.recv() => { // database informed us of a query result we asked for @@ -603,11 +573,11 @@ async fn nostr_server( // Consume text messages from the client, parse into Nostr messages. let nostr_msg = match ws_next { Some(Ok(Message::Text(m))) => { - convert_to_msg(m,settings.limits.max_event_bytes) + convert_to_msg(&m,settings.limits.max_event_bytes) }, Some(Ok(Message::Binary(_))) => { ws_stream.send( - make_notice_message(Notice::message("binary messages are not accepted".into()))).await.ok(); + make_notice_message(&Notice::message("binary messages are not accepted".into()))).await.ok(); continue; }, Some(Ok(Message::Ping(_) | Message::Pong(_))) => { @@ -617,7 +587,7 @@ async fn nostr_server( }, Some(Err(WsError::Capacity(MessageTooLong{size, max_size}))) => { ws_stream.send( - make_notice_message(Notice::message(format!("message too large ({} > {})",size, max_size)))).await.ok(); + make_notice_message(&Notice::message(format!("message too large ({} > {})",size, max_size)))).await.ok(); continue; }, None | @@ -662,13 +632,13 @@ async fn nostr_server( if let Some(fut_sec) = settings.options.reject_future_seconds { let msg = format!("The event created_at field is out of the acceptable range (+{}sec) for this relay.",fut_sec); let notice = Notice::invalid(e.id, &msg); - ws_stream.send(make_notice_message(notice)).await.ok(); + ws_stream.send(make_notice_message(¬ice)).await.ok(); } } }, Err(e) => { info!("client sent an invalid event (cid: {})", cid); - ws_stream.send(make_notice_message(Notice::invalid(evid, &format!("{}", e)))).await.ok(); + ws_stream.send(make_notice_message(&Notice::invalid(evid, &format!("{}", e)))).await.ok(); } } }, @@ -679,31 +649,31 @@ async fn nostr_server( // * registering the subscription so future events can be matched // * making a channel to cancel to request later // * sending a request for a SQL query - // Do nothing if the sub already exists. - if !conn.has_subscription(&s) { - if let Some(ref lim) = sub_lim_opt { - lim.until_ready_with_jitter(jitter).await; - } + // Do nothing if the sub already exists. + if conn.has_subscription(&s) { + info!("client sent duplicate subscription, ignoring (cid: {}, sub: {:?})", cid, s.id); + } else { + if let Some(ref lim) = sub_lim_opt { + lim.until_ready_with_jitter(jitter).await; + } let (abandon_query_tx, abandon_query_rx) = oneshot::channel::<()>(); match conn.subscribe(s.clone()) { - Ok(()) => { + Ok(()) => { // when we insert, if there was a previous query running with the same name, cancel it. - if let Some(previous_query) = running_queries.insert(s.id.to_owned(), abandon_query_tx) { - previous_query.send(()).ok(); + if let Some(previous_query) = running_queries.insert(s.id.clone(), abandon_query_tx) { + previous_query.send(()).ok(); } - if s.needs_historical_events() { - // start a database query. this spawns a blocking database query on a worker thread. - db::db_query(s, cid.to_owned(), pool.clone(), query_tx.clone(), abandon_query_rx,safe_to_read.clone()).await; + if s.needs_historical_events() { + // start a database query. this spawns a blocking database query on a worker thread. + repo.query_subscription(s, cid.clone(), query_tx.clone(), abandon_query_rx).await.ok(); } - }, - Err(e) => { - info!("Subscription error: {} (cid: {}, sub: {:?})", e, cid, s.id); - ws_stream.send(make_notice_message(Notice::message(format!("Subscription error: {}", e)))).await.ok(); - } + }, + Err(e) => { + info!("Subscription error: {} (cid: {}, sub: {:?})", e, cid, s.id); + ws_stream.send(make_notice_message(&Notice::message(format!("Subscription error: {}", e)))).await.ok(); + } } - } else { - info!("client sent duplicate subscription, ignoring (cid: {}, sub: {:?})", cid, s.id); - } + } }, Ok(NostrMessage::CloseMsg(cc)) => { // closing a request simply removes the subscription. @@ -720,7 +690,7 @@ async fn nostr_server( conn.unsubscribe(&c); } else { info!("invalid command ignored"); - ws_stream.send(make_notice_message(Notice::message("could not parse command".into()))).await.ok(); + ws_stream.send(make_notice_message(&Notice::message("could not parse command".into()))).await.ok(); } }, Err(Error::ConnError) => { @@ -729,11 +699,11 @@ async fn nostr_server( } Err(Error::EventMaxLengthError(s)) => { info!("client sent event larger ({} bytes) than max size (cid: {})", s, cid); - ws_stream.send(make_notice_message(Notice::message("event exceeded max size".into()))).await.ok(); + ws_stream.send(make_notice_message(&Notice::message("event exceeded max size".into()))).await.ok(); }, Err(Error::ProtoParseError) => { info!("client sent event that could not be parsed (cid: {})", cid); - ws_stream.send(make_notice_message(Notice::message("could not parse command".into()))).await.ok(); + ws_stream.send(make_notice_message(&Notice::message("could not parse command".into()))).await.ok(); }, Err(e) => { info!("got non-fatal error from client (cid: {}, error: {:?}", cid, e); diff --git a/src/subscription.rs b/src/subscription.rs index f019ce2..db83ae3 100644 --- a/src/subscription.rs +++ b/src/subscription.rs @@ -68,7 +68,7 @@ impl<'de> Deserialize<'de> for ReqFilter { let empty_string = "".into(); let mut ts = None; // iterate through each key, and assign values that exist - for (key, val) in filter.into_iter() { + for (key, val) in filter { // ids if key == "ids" { let raw_ids: Option>= Deserialize::deserialize(val).ok(); @@ -107,7 +107,7 @@ impl<'de> Deserialize<'de> for ReqFilter { if let Some(m) = ts.as_mut() { let tag_vals: Option> = Deserialize::deserialize(val).ok(); if let Some(v) = tag_vals { - let hs = HashSet::from_iter(v.into_iter()); + let hs = v.into_iter().collect::>(); m.insert(tag_search.to_owned(), hs); } }; @@ -197,20 +197,20 @@ impl<'de> Deserialize<'de> for Subscription { impl Subscription { /// Get a copy of the subscription identifier. - pub fn get_id(&self) -> String { + #[must_use] pub fn get_id(&self) -> String { self.id.clone() } /// Determine if any filter is requesting historical (database) /// queries. If every filter has limit:0, we do not need to query the DB. - pub fn needs_historical_events(&self) -> bool { + #[must_use] pub fn needs_historical_events(&self) -> bool { self.filters.iter().any(|f| f.limit!=Some(0)) } /// Determine if this subscription matches a given [`Event`]. Any /// individual filter match is sufficient. - pub fn interested_in_event(&self, event: &Event) -> bool { - for f in self.filters.iter() { + #[must_use] pub fn interested_in_event(&self, event: &Event) -> bool { + for f in &self.filters { if f.interested_in_event(event) { return true; } @@ -233,23 +233,20 @@ impl ReqFilter { fn ids_match(&self, event: &Event) -> bool { self.ids .as_ref() - .map(|vs| prefix_match(vs, &event.id)) - .unwrap_or(true) + .map_or(true, |vs| prefix_match(vs, &event.id)) } fn authors_match(&self, event: &Event) -> bool { self.authors .as_ref() - .map(|vs| prefix_match(vs, &event.pubkey)) - .unwrap_or(true) + .map_or(true, |vs| prefix_match(vs, &event.pubkey)) } fn delegated_authors_match(&self, event: &Event) -> bool { if let Some(delegated_pubkey) = &event.delegated_by { self.authors .as_ref() - .map(|vs| prefix_match(vs, delegated_pubkey)) - .unwrap_or(true) + .map_or(true, |vs| prefix_match(vs, delegated_pubkey)) } else { false } @@ -275,16 +272,15 @@ impl ReqFilter { fn kind_match(&self, kind: u64) -> bool { self.kinds .as_ref() - .map(|ks| ks.contains(&kind)) - .unwrap_or(true) + .map_or(true, |ks| ks.contains(&kind)) } /// Determine if all populated fields in this filter match the provided event. - pub fn interested_in_event(&self, event: &Event) -> bool { + #[must_use] pub fn interested_in_event(&self, event: &Event) -> bool { // self.id.as_ref().map(|v| v == &event.id).unwrap_or(true) self.ids_match(event) - && self.since.map(|t| event.created_at > t).unwrap_or(true) - && self.until.map(|t| event.created_at < t).unwrap_or(true) + && self.since.map_or(true, |t| event.created_at > t) + && self.until.map_or(true, |t| event.created_at < t) && self.kind_match(event.kind) && (self.authors_match(event) || self.delegated_authors_match(event)) && self.tag_match(event) diff --git a/src/utils.rs b/src/utils.rs index 59fbd0e..6eff34d 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -2,7 +2,7 @@ use std::time::SystemTime; /// Seconds since 1970. -pub fn unix_time() -> u64 { +#[must_use] pub fn unix_time() -> u64 { SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) .map(|x| x.as_secs()) @@ -10,12 +10,12 @@ pub fn unix_time() -> u64 { } /// Check if a string contains only hex characters. -pub fn is_hex(s: &str) -> bool { +#[must_use] pub fn is_hex(s: &str) -> bool { s.chars().all(|x| char::is_ascii_hexdigit(&x)) } /// Check if a string contains only lower-case hex chars. -pub fn is_lower_hex(s: &str) -> bool { +#[must_use] pub fn is_lower_hex(s: &str) -> bool { s.chars().all(|x| { (char::is_ascii_lowercase(&x) || char::is_ascii_digit(&x)) && char::is_ascii_hexdigit(&x) })