improvement: add NostrRepo trait, with sqlite implementation

This is inspired by the work of
v0l (https://github.com/v0l/nostr-rs-relay/).

A new trait abstracts the storage layer with an async API.  Rusqlite
is still used with worker threads, but this allows for Postgresql or
other backends to be used.

There may be bugs, this has not been rigorously tested.
This commit is contained in:
Greg Heartsfield 2023-01-22 09:49:49 -06:00
parent e996d4c009
commit 6800c2e39d
20 changed files with 1396 additions and 1345 deletions

1
Cargo.lock generated
View File

@ -1191,6 +1191,7 @@ name = "nostr-rs-relay"
version = "0.7.17" version = "0.7.17"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"async-trait",
"bitcoin_hashes", "bitcoin_hashes",
"clap", "clap",
"config", "config",

View File

@ -42,6 +42,7 @@ parse_duration = "2"
rand = "0.8" rand = "0.8"
const_format = "0.2.28" const_format = "0.2.28"
regex = "1" regex = "1"
async-trait = "0.1.60"
[dev-dependencies] [dev-dependencies]
anyhow = "1" anyhow = "1"

View File

@ -1,14 +1,13 @@
use std::io; use std::io;
use std::path::Path; use std::path::Path;
use nostr_rs_relay::utils::is_lower_hex; use nostr_rs_relay::utils::is_lower_hex;
use tracing::*; use tracing::info;
use nostr_rs_relay::config; use nostr_rs_relay::config;
use nostr_rs_relay::event::{Event,single_char_tagname}; use nostr_rs_relay::event::{Event,single_char_tagname};
use nostr_rs_relay::error::{Error, Result}; use nostr_rs_relay::error::{Error, Result};
use nostr_rs_relay::db::build_pool; use nostr_rs_relay::repo::sqlite::{PooledConnection, build_pool};
use nostr_rs_relay::schema::{curr_db_version, DB_VERSION}; use nostr_rs_relay::repo::sqlite_migration::{curr_db_version, DB_VERSION};
use rusqlite::{OpenFlags, Transaction}; use rusqlite::{OpenFlags, Transaction};
use nostr_rs_relay::db::PooledConnection;
use std::sync::mpsc; use std::sync::mpsc;
use std::thread; use std::thread;
use rusqlite::params; use rusqlite::params;
@ -67,7 +66,7 @@ pub fn main() -> Result<()> {
info!("finished parsing events"); info!("finished parsing events");
event_tx.send(None).ok(); event_tx.send(None).ok();
let ok: Result<()> = Ok(()); let ok: Result<()> = Ok(());
return ok; ok
}); });
let mut conn: PooledConnection = pool.get()?; let mut conn: PooledConnection = pool.get()?;
let mut events_read = 0; let mut events_read = 0;

View File

@ -18,6 +18,7 @@ pub struct Info {
#[allow(unused)] #[allow(unused)]
pub struct Database { pub struct Database {
pub data_directory: String, pub data_directory: String,
pub engine: String,
pub in_memory: bool, pub in_memory: bool,
pub min_conn: u32, pub min_conn: u32,
pub max_conn: u32, pub max_conn: u32,
@ -206,6 +207,7 @@ impl Default for Settings {
diagnostics: Diagnostics { tracing: false }, diagnostics: Diagnostics { tracing: false },
database: Database { database: Database {
data_directory: ".".to_owned(), data_directory: ".".to_owned(),
engine: "sqlite".to_owned(),
in_memory: false, in_memory: false,
min_conn: 4, min_conn: 4,
max_conn: 8, max_conn: 8,

View File

@ -14,7 +14,7 @@ const MAX_SUBSCRIPTION_ID_LEN: usize = 256;
/// State for a client connection /// State for a client connection
pub struct ClientConn { pub struct ClientConn {
/// Client IP (either from socket, or configured proxy header /// Client IP (either from socket, or configured proxy header
client_ip: String, client_ip_addr: String,
/// Unique client identifier generated at connection time /// Unique client identifier generated at connection time
client_id: Uuid, client_id: Uuid,
/// The current set of active client subscriptions /// The current set of active client subscriptions
@ -32,22 +32,22 @@ impl Default for ClientConn {
impl ClientConn { impl ClientConn {
/// Create a new, empty connection state. /// Create a new, empty connection state.
#[must_use] #[must_use]
pub fn new(client_ip: String) -> Self { pub fn new(client_ip_addr: String) -> Self {
let client_id = Uuid::new_v4(); let client_id = Uuid::new_v4();
ClientConn { ClientConn {
client_ip, client_ip_addr,
client_id, client_id,
subscriptions: HashMap::new(), subscriptions: HashMap::new(),
max_subs: 32, max_subs: 32,
} }
} }
pub fn subscriptions(&self) -> &HashMap<String, Subscription> { #[must_use] pub fn subscriptions(&self) -> &HashMap<String, Subscription> {
&self.subscriptions &self.subscriptions
} }
/// Check if the given subscription already exists /// Check if the given subscription already exists
pub fn has_subscription(&self, sub: &Subscription) -> bool { #[must_use] pub fn has_subscription(&self, sub: &Subscription) -> bool {
self.subscriptions.values().any(|x| x == sub) self.subscriptions.values().any(|x| x == sub)
} }
@ -60,7 +60,7 @@ impl ClientConn {
#[must_use] #[must_use]
pub fn ip(&self) -> &str { pub fn ip(&self) -> &str {
&self.client_ip &self.client_ip_addr
} }
/// Add a new subscription for this connection. /// Add a new subscription for this connection.

741
src/db.rs
View File

@ -1,32 +1,16 @@
//! Event persistence and querying //! Event persistence and querying
//use crate::config::SETTINGS;
use crate::config::Settings; use crate::config::Settings;
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::event::{single_char_tagname, Event}; use crate::event::Event;
use crate::hexrange::hex_range;
use crate::hexrange::HexSearch;
use crate::nip05;
use crate::notice::Notice; use crate::notice::Notice;
use crate::schema::{upgrade_db, STARTUP_SQL};
use crate::subscription::ReqFilter;
use crate::subscription::Subscription;
use crate::utils::{is_hex, is_lower_hex};
use governor::clock::Clock; use governor::clock::Clock;
use governor::{Quota, RateLimiter}; use governor::{Quota, RateLimiter};
use hex;
use r2d2; use r2d2;
use r2d2_sqlite::SqliteConnectionManager;
use rusqlite::params;
use rusqlite::types::ToSql;
use rusqlite::OpenFlags;
use tokio::sync::{Mutex, MutexGuard};
use std::fmt::Write as _;
use std::path::Path;
use std::sync::Arc; use std::sync::Arc;
use std::thread; use std::thread;
use std::time::Duration; use crate::repo::sqlite::SqliteRepo;
use crate::repo::NostrRepo;
use std::time::Instant; use std::time::Instant;
use tokio::task;
use tracing::{debug, info, trace, warn}; use tracing::{debug, info, trace, warn};
pub type SqlitePool = r2d2::Pool<r2d2_sqlite::SqliteConnectionManager>; pub type SqlitePool = r2d2::Pool<r2d2_sqlite::SqliteConnectionManager>;
@ -41,137 +25,39 @@ pub struct SubmittedEvent {
/// Database file /// Database file
pub const DB_FILE: &str = "nostr.db"; pub const DB_FILE: &str = "nostr.db";
/// How frequently to attempt checkpointing /// Build repo
pub const CHECKPOINT_FREQ_SEC: u64 = 60;
/// Build a database connection pool.
/// # Panics /// # Panics
/// ///
/// Will panic if the pool could not be created. /// Will panic if the pool could not be created.
#[must_use] pub async fn build_repo(settings: &Settings) -> Arc<dyn NostrRepo> {
pub fn build_pool( match settings.database.engine.as_str() {
name: &str, "sqlite" => {Arc::new(build_sqlite_pool(settings).await)},
settings: &Settings, _ => panic!("Unknown database engine"),
flags: OpenFlags,
min_size: u32,
max_size: u32,
wait_for_db: bool,
) -> SqlitePool {
let db_dir = &settings.database.data_directory;
let full_path = Path::new(db_dir).join(DB_FILE);
// small hack; if the database doesn't exist yet, that means the
// writer thread hasn't finished. Give it a chance to work. This
// is only an issue with the first time we run.
if !settings.database.in_memory {
while !full_path.exists() && wait_for_db {
debug!("Database reader pool is waiting on the database to be created...");
thread::sleep(Duration::from_millis(500));
}
}
let manager = if settings.database.in_memory {
SqliteConnectionManager::memory()
.with_flags(flags)
.with_init(|c| c.execute_batch(STARTUP_SQL))
} else {
SqliteConnectionManager::file(&full_path)
.with_flags(flags)
.with_init(|c| c.execute_batch(STARTUP_SQL))
};
let pool: SqlitePool = r2d2::Pool::builder()
.test_on_check_out(true) // no noticeable performance hit
.min_idle(Some(min_size))
.max_size(max_size)
.max_lifetime(Some(Duration::from_secs(30)))
.build(manager)
.unwrap();
info!(
"Built a connection pool {:?} (min={}, max={})",
name, min_size, max_size
);
pool
}
/// Display database pool stats every 1 minute
pub async fn monitor_pool(name: &str, pool: SqlitePool) {
let sleep_dur = Duration::from_secs(60);
loop {
log_pool_stats(name, &pool);
tokio::time::sleep(sleep_dur).await;
} }
} }
async fn build_sqlite_pool(settings: &Settings) -> SqliteRepo {
/// Perform normal maintenance let repo = SqliteRepo::new(settings);
pub fn optimize_db(conn: &mut PooledConnection) -> Result<()> { repo.start().await.ok();
let start = Instant::now(); repo.migrate_up().await.ok();
conn.execute_batch("PRAGMA optimize;")?; repo
info!("optimize ran in {:?}", start.elapsed());
Ok(())
}
#[derive(Debug)]
enum SqliteStatus {
Ok,
Busy,
Error,
Other(u64),
} }
/// Checkpoint/Truncate WAL. Returns the number of WAL pages remaining. /// Spawn a database writer that persists events to the `SQLite` store.
pub fn checkpoint_db(conn: &mut PooledConnection) -> Result<usize> {
let query = "PRAGMA wal_checkpoint(TRUNCATE);";
let start = Instant::now();
let (cp_result, wal_size, _frames_checkpointed) = conn.query_row(query, [], |row| {
let checkpoint_result: u64 = row.get(0)?;
let wal_size: u64 = row.get(1)?;
let frames_checkpointed: u64 = row.get(2)?;
Ok((checkpoint_result, wal_size, frames_checkpointed))
})?;
let result = match cp_result {
0 => SqliteStatus::Ok,
1 => SqliteStatus::Busy,
2 => SqliteStatus::Error,
x => SqliteStatus::Other(x),
};
info!(
"checkpoint ran in {:?} (result: {:?}, WAL size: {})",
start.elapsed(),
result,
wal_size
);
Ok(wal_size as usize)
}
/// Spawn a database writer that persists events to the SQLite store.
pub async fn db_writer( pub async fn db_writer(
repo: Arc<dyn NostrRepo>,
settings: Settings, settings: Settings,
mut event_rx: tokio::sync::mpsc::Receiver<SubmittedEvent>, mut event_rx: tokio::sync::mpsc::Receiver<SubmittedEvent>,
bcast_tx: tokio::sync::broadcast::Sender<Event>, bcast_tx: tokio::sync::broadcast::Sender<Event>,
metadata_tx: tokio::sync::broadcast::Sender<Event>, metadata_tx: tokio::sync::broadcast::Sender<Event>,
mut shutdown: tokio::sync::broadcast::Receiver<()>, mut shutdown: tokio::sync::broadcast::Receiver<()>,
) -> tokio::task::JoinHandle<Result<()>> { ) -> Result<()> {
// are we performing NIP-05 checking? // are we performing NIP-05 checking?
let nip05_active = settings.verified_users.is_active(); let nip05_active = settings.verified_users.is_active();
// are we requriing NIP-05 user verification? // are we requriing NIP-05 user verification?
let nip05_enabled = settings.verified_users.is_enabled(); let nip05_enabled = settings.verified_users.is_enabled();
task::spawn_blocking(move || { //upgrade_db(&mut pool.get()?)?;
let db_dir = &settings.database.data_directory;
let full_path = Path::new(db_dir).join(DB_FILE);
// create a connection pool
let pool = build_pool(
"event writer",
&settings,
OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE,
1,
2,
false,
);
if settings.database.in_memory {
info!("using in-memory database, this will not persist a restart!");
} else {
info!("opened database {:?} for writing", full_path);
}
upgrade_db(&mut pool.get()?)?;
// Make a copy of the whitelist // Make a copy of the whitelist
let whitelist = &settings.authorization.pubkey_whitelist.clone(); let whitelist = &settings.authorization.pubkey_whitelist.clone();
@ -194,7 +80,7 @@ pub async fn db_writer(
break; break;
} }
// call blocking read on channel // call blocking read on channel
let next_event = event_rx.blocking_recv(); let next_event = event_rx.recv().await;
// if the channel has closed, we will never get work // if the channel has closed, we will never get work
if next_event.is_none() { if next_event.is_none() {
break; break;
@ -254,7 +140,7 @@ pub async fn db_writer(
// check for NIP-05 verification // check for NIP-05 verification
if nip05_enabled { if nip05_enabled {
match nip05::query_latest_user_verification(pool.get()?, event.pubkey.to_owned()) { match repo.get_latest_user_verification(&event.pubkey).await {
Ok(uv) => { Ok(uv) => {
if uv.is_valid(&settings.verified_users) { if uv.is_valid(&settings.verified_users) {
info!( info!(
@ -306,9 +192,9 @@ pub async fn db_writer(
event.get_author_prefix(), event.get_author_prefix(),
start.elapsed() start.elapsed()
); );
event_write = true event_write = true;
} else { } else {
match write_event(&mut pool.get()?, &event) { match repo.write_event(&event).await {
Ok(updated) => { Ok(updated) => {
if updated == 0 { if updated == 0 {
trace!("ignoring duplicate or deleted event"); trace!("ignoring duplicate or deleted event");
@ -360,135 +246,6 @@ pub async fn db_writer(
} }
info!("database connection closed"); info!("database connection closed");
Ok(()) Ok(())
})
}
/// Persist an event to the database, returning rows added.
pub fn write_event(conn: &mut PooledConnection, e: &Event) -> Result<usize> {
// enable auto vacuum
conn.execute_batch("pragma auto_vacuum = FULL")?;
// start transaction
let tx = conn.transaction()?;
// get relevant fields from event and convert to blobs.
let id_blob = hex::decode(&e.id).ok();
let pubkey_blob: Option<Vec<u8>> = hex::decode(&e.pubkey).ok();
let delegator_blob: Option<Vec<u8>> = e.delegated_by.as_ref().and_then(|d| hex::decode(d).ok());
let event_str = serde_json::to_string(&e).ok();
// check for replaceable events that would hide this one; we won't even attempt to insert these.
if e.is_replaceable() {
let repl_count = tx.query_row(
"SELECT e.id FROM event e INDEXED BY author_index WHERE e.author=? AND e.kind=? AND e.created_at > ? LIMIT 1;",
params![pubkey_blob, e.kind, e.created_at], |row| row.get::<usize, usize>(0));
if repl_count.ok().is_some() {
return Ok(0);
}
}
// ignore if the event hash is a duplicate.
let mut ins_count = tx.execute(
"INSERT OR IGNORE INTO event (event_hash, created_at, kind, author, delegated_by, content, first_seen, hidden) VALUES (?1, ?2, ?3, ?4, ?5, ?6, strftime('%s','now'), FALSE);",
params![id_blob, e.created_at, e.kind, pubkey_blob, delegator_blob, event_str]
)?;
if ins_count == 0 {
// if the event was a duplicate, no need to insert event or
// pubkey references.
tx.rollback().ok();
return Ok(ins_count);
}
// remember primary key of the event most recently inserted.
let ev_id = tx.last_insert_rowid();
// add all tags to the tag table
for tag in e.tags.iter() {
// ensure we have 2 values.
if tag.len() >= 2 {
let tagname = &tag[0];
let tagval = &tag[1];
// only single-char tags are searchable
let tagchar_opt = single_char_tagname(tagname);
match &tagchar_opt {
Some(_) => {
// if tagvalue is lowercase hex;
if is_lower_hex(tagval) && (tagval.len() % 2 == 0) {
tx.execute(
"INSERT OR IGNORE INTO tag (event_id, name, value_hex) VALUES (?1, ?2, ?3)",
params![ev_id, &tagname, hex::decode(tagval).ok()],
)?;
} else {
tx.execute(
"INSERT OR IGNORE INTO tag (event_id, name, value) VALUES (?1, ?2, ?3)",
params![ev_id, &tagname, &tagval],
)?;
}
}
None => {}
}
}
}
// if this event is replaceable update, remove other replaceable
// event with the same kind from the same author that was issued
// earlier than this.
if e.is_replaceable() {
let author = hex::decode(&e.pubkey).ok();
// this is a backwards check - hide any events that were older.
let update_count = tx.execute(
"DELETE FROM event WHERE kind=? and author=? and id NOT IN (SELECT id FROM event INDEXED BY author_kind_index WHERE kind=? AND author=? ORDER BY created_at DESC LIMIT 1)",
params![e.kind, author, e.kind, author],
)?;
if update_count > 0 {
info!(
"removed {} older replaceable kind {} events for author: {:?}",
update_count,
e.kind,
e.get_author_prefix()
);
}
}
// if this event is a deletion, hide the referenced events from the same author.
if e.kind == 5 {
let event_candidates = e.tag_values_by_name("e");
// first parameter will be author
let mut params: Vec<Box<dyn ToSql>> = vec![Box::new(hex::decode(&e.pubkey)?)];
event_candidates
.iter()
.filter(|x| is_hex(x) && x.len() == 64)
.filter_map(|x| hex::decode(x).ok())
.for_each(|x| params.push(Box::new(x)));
let query = format!(
"UPDATE event SET hidden=TRUE WHERE kind!=5 AND author=? AND event_hash IN ({})",
repeat_vars(params.len() - 1)
);
let mut stmt = tx.prepare(&query)?;
let update_count = stmt.execute(rusqlite::params_from_iter(params))?;
info!(
"hid {} deleted events for author {:?}",
update_count,
e.get_author_prefix()
);
} else {
// check if a deletion has already been recorded for this event.
// Only relevant for non-deletion events
let del_count = tx.query_row(
"SELECT e.id FROM event e LEFT JOIN tag t ON e.id=t.event_id WHERE e.author=? AND t.name='e' AND e.kind=5 AND t.value_hex=? LIMIT 1;",
params![pubkey_blob, id_blob], |row| row.get::<usize, usize>(0));
// check if a the query returned a result, meaning we should
// hid the current event
if del_count.ok().is_some() {
// a deletion already existed, mark original event as hidden.
info!(
"hid event: {:?} due to existing deletion by author: {:?}",
e.get_event_id_prefix(),
e.get_author_prefix()
);
let _update_count =
tx.execute("UPDATE event SET hidden=TRUE WHERE id=?", params![ev_id])?;
// event was deleted, so let caller know nothing new
// arrived, preventing this from being sent to active
// subscriptions
ins_count = 0;
}
}
tx.commit()?;
Ok(ins_count)
} }
/// Serialized event associated with a specific subscription request. /// Serialized event associated with a specific subscription request.
@ -499,459 +256,3 @@ pub struct QueryResult {
/// Serialized event /// Serialized event
pub event: String, pub event: String,
} }
/// Produce a arbitrary list of '?' parameters.
fn repeat_vars(count: usize) -> String {
if count == 0 {
return "".to_owned();
}
let mut s = "?,".repeat(count);
// Remove trailing comma
s.pop();
s
}
/// Decide if there is an index that should be used explicitly
fn override_index(f: &ReqFilter) -> Option<String> {
// queries for multiple kinds default to kind_index, which is
// significantly slower than kind_created_at_index.
if let Some(ks) = &f.kinds {
if f.ids.is_none() &&
ks.len() > 1 &&
f.since.is_none() &&
f.until.is_none() &&
f.tags.is_none() &&
f.authors.is_none() {
return Some("kind_created_at_index".into());
}
}
// if there is an author, it is much better to force the authors index.
if let Some(_) = &f.authors {
if f.since.is_none() && f.until.is_none() {
if f.kinds.is_none() {
// with no use of kinds/created_at, just author
return Some("author_index".into());
} else {
// prefer author_kind if there are kinds
return Some("author_kind_index".into());
}
} else {
// finally, prefer author_created_at if time is provided
return Some("author_created_at_index".into());
}
}
None
}
/// Create a dynamic SQL subquery and params from a subscription filter (and optional explicit index used)
fn query_from_filter(f: &ReqFilter) -> (String, Vec<Box<dyn ToSql>>, Option<String>) {
// build a dynamic SQL query. all user-input is either an integer
// (sqli-safe), or a string that is filtered to only contain
// hexadecimal characters. Strings that require escaping (tag
// names/values) use parameters.
// if the filter is malformed, don't return anything.
if f.force_no_match {
let empty_query = "SELECT e.content, e.created_at FROM event e WHERE 1=0".to_owned();
// query parameters for SQLite
let empty_params: Vec<Box<dyn ToSql>> = vec![];
return (empty_query, empty_params, None);
}
// check if the index needs to be overriden
let idx_name = override_index(f);
let idx_stmt = idx_name.as_ref().map_or_else(|| "".to_owned(), |i| format!("INDEXED BY {}",i));
let mut query = format!("SELECT e.content, e.created_at FROM event e {}", idx_stmt);
// query parameters for SQLite
let mut params: Vec<Box<dyn ToSql>> = vec![];
// individual filter components (single conditions such as an author or event ID)
let mut filter_components: Vec<String> = Vec::new();
// Query for "authors", allowing prefix matches
if let Some(authvec) = &f.authors {
// take each author and convert to a hexsearch
let mut auth_searches: Vec<String> = vec![];
for auth in authvec {
match hex_range(auth) {
Some(HexSearch::Exact(ex)) => {
auth_searches.push("author=?".to_owned());
params.push(Box::new(ex));
}
Some(HexSearch::Range(lower, upper)) => {
auth_searches.push(
"(author>? AND author<?)".to_owned(),
);
params.push(Box::new(lower));
params.push(Box::new(upper));
}
Some(HexSearch::LowerOnly(lower)) => {
auth_searches.push("author>?".to_owned());
params.push(Box::new(lower));
}
None => {
info!("Could not parse hex range from author {:?}", auth);
}
}
}
if !authvec.is_empty() {
let auth_clause = format!("({})", auth_searches.join(" OR "));
filter_components.push(auth_clause);
} else {
filter_components.push("false".to_owned());
}
}
// Query for Kind
if let Some(ks) = &f.kinds {
// kind is number, no escaping needed
let str_kinds: Vec<String> = ks.iter().map(|x| x.to_string()).collect();
let kind_clause = format!("kind IN ({})", str_kinds.join(", "));
filter_components.push(kind_clause);
}
// Query for event, allowing prefix matches
if let Some(idvec) = &f.ids {
// take each author and convert to a hexsearch
let mut id_searches: Vec<String> = vec![];
for id in idvec {
match hex_range(id) {
Some(HexSearch::Exact(ex)) => {
id_searches.push("event_hash=?".to_owned());
params.push(Box::new(ex));
}
Some(HexSearch::Range(lower, upper)) => {
id_searches.push("(event_hash>? AND event_hash<?)".to_owned());
params.push(Box::new(lower));
params.push(Box::new(upper));
}
Some(HexSearch::LowerOnly(lower)) => {
id_searches.push("event_hash>?".to_owned());
params.push(Box::new(lower));
}
None => {
info!("Could not parse hex range from id {:?}", id);
}
}
}
if !idvec.is_empty() {
let id_clause = format!("({})", id_searches.join(" OR "));
filter_components.push(id_clause);
} else {
// if the ids list was empty, we should never return
// any results.
filter_components.push("false".to_owned());
}
}
// Query for tags
if let Some(map) = &f.tags {
for (key, val) in map.iter() {
let mut str_vals: Vec<Box<dyn ToSql>> = vec![];
let mut blob_vals: Vec<Box<dyn ToSql>> = vec![];
for v in val {
if (v.len() % 2 == 0) && is_lower_hex(v) {
if let Ok(h) = hex::decode(v) {
blob_vals.push(Box::new(h));
}
} else {
str_vals.push(Box::new(v.to_owned()));
}
}
// create clauses with "?" params for each tag value being searched
let str_clause = format!("value IN ({})", repeat_vars(str_vals.len()));
let blob_clause = format!("value_hex IN ({})", repeat_vars(blob_vals.len()));
// find evidence of the target tag name/value existing for this event.
let tag_clause = format!(
"e.id IN (SELECT e.id FROM event e LEFT JOIN tag t on e.id=t.event_id WHERE hidden!=TRUE and (name=? AND ({} OR {})))",
str_clause, blob_clause
);
// add the tag name as the first parameter
params.push(Box::new(key.to_string()));
// add all tag values that are plain strings as params
params.append(&mut str_vals);
// add all tag values that are blobs as params
params.append(&mut blob_vals);
filter_components.push(tag_clause);
}
}
// Query for timestamp
if f.since.is_some() {
let created_clause = format!("created_at > {}", f.since.unwrap());
filter_components.push(created_clause);
}
// Query for timestamp
if f.until.is_some() {
let until_clause = format!("created_at < {}", f.until.unwrap());
filter_components.push(until_clause);
}
// never display hidden events
query.push_str(" WHERE hidden!=TRUE");
// build filter component conditions
if !filter_components.is_empty() {
query.push_str(" AND ");
query.push_str(&filter_components.join(" AND "));
}
// Apply per-filter limit to this subquery.
// The use of a LIMIT implies a DESC order, to capture only the most recent events.
if let Some(lim) = f.limit {
let _ = write!(query, " ORDER BY e.created_at DESC LIMIT {}", lim);
} else {
query.push_str(" ORDER BY e.created_at ASC")
}
(query, params, idx_name)
}
/// Create a dynamic SQL query string and params from a subscription.
fn query_from_sub(sub: &Subscription) -> (String, Vec<Box<dyn ToSql>>, Vec<String>) {
// build a dynamic SQL query for an entire subscription, based on
// SQL subqueries for filters.
let mut subqueries: Vec<String> = Vec::new();
let mut indexes = vec![];
// subquery params
let mut params: Vec<Box<dyn ToSql>> = vec![];
// for every filter in the subscription, generate a subquery
for f in sub.filters.iter() {
let (f_subquery, mut f_params, index) = query_from_filter(f);
if let Some(i) = index {
indexes.push(i);
}
subqueries.push(f_subquery);
params.append(&mut f_params);
}
// encapsulate subqueries into select statements
let subqueries_selects: Vec<String> = subqueries
.iter()
.map(|s| format!("SELECT distinct content, created_at FROM ({})", s))
.collect();
let query: String = subqueries_selects.join(" UNION ");
(query, params,indexes)
}
/// Check if the pool is fully utilized
fn _pool_at_capacity(pool: &SqlitePool) -> bool {
let state: r2d2::State = pool.state();
state.idle_connections == 0
}
/// Log pool stats
fn log_pool_stats(name: &str, pool: &SqlitePool) {
let state: r2d2::State = pool.state();
let in_use_cxns = state.connections - state.idle_connections;
debug!(
"DB pool {:?} usage (in_use: {}, available: {}, max: {})",
name,
in_use_cxns,
state.connections,
pool.max_size()
);
}
/// Perform database maintenance on a regular basis
pub async fn db_optimize_task(pool: SqlitePool) {
tokio::task::spawn(async move {
loop {
tokio::select! {
_ = tokio::time::sleep(Duration::from_secs(60*60)) => {
if let Ok(mut conn) = pool.get() {
// the busy timer will block writers, so don't set
// this any higher than you want max latency for event
// writes.
info!("running database optimizer");
optimize_db(&mut conn).ok();
}
}
};
}
});
}
/// Perform database WAL checkpoint on a regular basis
pub async fn db_checkpoint_task(pool: SqlitePool, safe_to_read: Arc<Mutex<u64>>) {
tokio::task::spawn(async move {
// WAL size in pages.
let mut current_wal_size = 0;
// WAL threshold for more aggressive checkpointing (10,000 pages, or about 40MB)
let wal_threshold = 1000*10;
// default threshold for the busy timer
let busy_wait_default = Duration::from_secs(1);
// if the WAL file is getting too big, switch to this
let busy_wait_default_long = Duration::from_secs(10);
loop {
tokio::select! {
_ = tokio::time::sleep(Duration::from_secs(CHECKPOINT_FREQ_SEC)) => {
if let Ok(mut conn) = pool.get() {
let mut _guard:Option<MutexGuard<u64>> = None;
// the busy timer will block writers, so don't set
// this any higher than you want max latency for event
// writes.
if current_wal_size <= wal_threshold {
conn.busy_timeout(busy_wait_default).ok();
} else {
// if the wal size has exceeded a threshold, increase the busy timeout.
conn.busy_timeout(busy_wait_default_long).ok();
// take a lock that will prevent new readers.
info!("blocking new readers to perform wal_checkpoint");
_guard = Some(safe_to_read.lock().await);
}
debug!("running wal_checkpoint(TRUNCATE)");
if let Ok(new_size) = checkpoint_db(&mut conn) {
current_wal_size = new_size;
}
}
}
};
}
});
}
/// Perform a database query using a subscription.
///
/// The [`Subscription`] is converted into a SQL query. Each result
/// is published on the `query_tx` channel as it is returned. If a
/// message becomes available on the `abandon_query_rx` channel, the
/// query is immediately aborted.
pub async fn db_query(
sub: Subscription,
client_id: String,
pool: SqlitePool,
query_tx: tokio::sync::mpsc::Sender<QueryResult>,
mut abandon_query_rx: tokio::sync::oneshot::Receiver<()>,
safe_to_read: Arc<Mutex<u64>>,
) {
let pre_spawn_start = Instant::now();
task::spawn_blocking(move || {
{
// if we are waiting on a checkpoint, stop until it is complete
let _ = safe_to_read.blocking_lock();
}
let db_queue_time = pre_spawn_start.elapsed();
// if the queue time was very long (>5 seconds), spare the DB and abort.
if db_queue_time > Duration::from_secs(5) {
info!(
"shedding DB query load queued for {:?} (cid: {}, sub: {:?})",
db_queue_time, client_id, sub.id
);
return Ok(());
}
// otherwise, report queuing time if it is slow
else if db_queue_time > Duration::from_secs(1) {
debug!(
"(slow) DB query queued for {:?} (cid: {}, sub: {:?})",
db_queue_time, client_id, sub.id
);
}
let start = Instant::now();
let mut row_count: usize = 0;
// generate SQL query
let (q, p, idxs) = query_from_sub(&sub);
let sql_gen_elapsed = start.elapsed();
if sql_gen_elapsed > Duration::from_millis(10) {
debug!("SQL (slow) generated in {:?}", start.elapsed());
}
// cutoff for displaying slow queries
let slow_cutoff = Duration::from_millis(2000);
// any client that doesn't cause us to generate new rows in 5
// seconds gets dropped.
let abort_cutoff = Duration::from_secs(5);
let start = Instant::now();
let mut slow_first_event;
let mut last_successful_send = Instant::now();
if let Ok(mut conn) = pool.get() {
// execute the query.
// make the actual SQL query (with parameters inserted) available
conn.trace(Some(|x| {trace!("SQL trace: {:?}", x)}));
let mut stmt = conn.prepare_cached(&q)?;
let mut event_rows = stmt.query(rusqlite::params_from_iter(p))?;
let mut first_result = true;
while let Some(row) = event_rows.next()? {
let first_event_elapsed = start.elapsed();
slow_first_event = first_event_elapsed >= slow_cutoff;
if first_result {
debug!(
"first result in {:?} (cid: {}, sub: {:?}) [used indexes: {:?}]",
first_event_elapsed, client_id, sub.id, idxs
);
first_result = false;
}
// logging for slow queries; show sub and SQL.
// to reduce logging; only show 1/16th of clients (leading 0)
if row_count == 0 && slow_first_event && client_id.starts_with('0') {
debug!(
"query req (slow): {:?} (cid: {}, sub: {:?})",
sub, client_id, sub.id
);
}
// check if a checkpoint is trying to run, and abort
if row_count % 100 == 0 {
{
if safe_to_read.try_lock().is_err() {
// lock was held, abort this query
debug!("query aborted due to checkpoint (cid: {}, sub: {:?})", client_id, sub.id);
return Ok(());
}
}
}
// check if this is still active; every 100 rows
if row_count % 100 == 0 && abandon_query_rx.try_recv().is_ok() {
debug!("query aborted (cid: {}, sub: {:?})", client_id, sub.id);
return Ok(());
}
row_count += 1;
let event_json = row.get(0)?;
loop {
if query_tx.capacity() != 0 {
// we have capacity to add another item
break;
} else {
// the queue is full
trace!("db reader thread is stalled");
if last_successful_send + abort_cutoff < Instant::now() {
// the queue has been full for too long, abort
info!("aborting database query due to slow client (cid: {}, sub: {:?})",
client_id, sub.id);
let ok: Result<()> = Ok(());
return ok;
}
// check if a checkpoint is trying to run, and abort
if safe_to_read.try_lock().is_err() {
// lock was held, abort this query
debug!("query aborted due to checkpoint (cid: {}, sub: {:?})", client_id, sub.id);
return Ok(());
}
// give the queue a chance to clear before trying again
thread::sleep(Duration::from_millis(100));
}
}
// TODO: we could use try_send, but we'd have to juggle
// getting the query result back as part of the error
// result.
query_tx
.blocking_send(QueryResult {
sub_id: sub.get_id(),
event: event_json,
})
.ok();
last_successful_send = Instant::now();
}
query_tx
.blocking_send(QueryResult {
sub_id: sub.get_id(),
event: "EOSE".to_string(),
})
.ok();
debug!(
"query completed in {:?} (cid: {}, sub: {:?}, db_time: {:?}, rows: {})",
pre_spawn_start.elapsed(),
client_id,
sub.id,
start.elapsed(),
row_count
);
} else {
warn!("Could not get a database connection for querying");
}
let ok: Result<()> = Ok(());
ok
});
}

View File

@ -84,7 +84,7 @@ pub struct ConditionQuery {
} }
impl ConditionQuery { impl ConditionQuery {
pub fn allows_event(&self, event: &Event) -> bool { #[must_use] pub fn allows_event(&self, event: &Event) -> bool {
// check each condition, to ensure that the event complies // check each condition, to ensure that the event complies
// with the restriction. // with the restriction.
for c in &self.conditions { for c in &self.conditions {
@ -101,7 +101,7 @@ impl ConditionQuery {
} }
// Verify that the delegator approved the delegation; return a ConditionQuery if so. // Verify that the delegator approved the delegation; return a ConditionQuery if so.
pub fn validate_delegation( #[must_use] pub fn validate_delegation(
delegator: &str, delegator: &str,
delegatee: &str, delegatee: &str,
cond_query: &str, cond_query: &str,
@ -133,8 +133,8 @@ pub fn validate_delegation(
} }
/// Parsed delegation condition /// Parsed delegation condition
/// see https://github.com/nostr-protocol/nips/pull/28#pullrequestreview-1084903800 /// see <https://github.com/nostr-protocol/nips/pull/28#pullrequestreview-1084903800>
/// An example complex condition would be: kind=1,2,3&created_at<1665265999 /// An example complex condition would be: `kind=1,2,3&created_at<1665265999`
#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)] #[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)]
pub struct Condition { pub struct Condition {
pub field: Field, pub field: Field,
@ -144,7 +144,7 @@ pub struct Condition {
impl Condition { impl Condition {
/// Check if this condition allows the given event to be delegated /// Check if this condition allows the given event to be delegated
pub fn allows_event(&self, event: &Event) -> bool { #[must_use] pub fn allows_event(&self, event: &Event) -> bool {
// determine what the right-hand side of the operator is // determine what the right-hand side of the operator is
let resolved_field = match &self.field { let resolved_field = match &self.field {
Field::Kind => event.kind, Field::Kind => event.kind,
@ -323,7 +323,7 @@ mod tests {
Condition { Condition {
field: Field::CreatedAt, field: Field::CreatedAt,
operator: Operator::LessThan, operator: Operator::LessThan,
values: vec![1665867123], values: vec![1_665_867_123],
}, },
], ],
}; };

View File

@ -1,6 +1,6 @@
//! Event parsing and validation //! Event parsing and validation
use crate::delegation::validate_delegation; use crate::delegation::validate_delegation;
use crate::error::Error::*; use crate::error::Error::{CommandUnknownError, EventCouldNotCanonicalize, EventInvalidId, EventInvalidSignature, EventMalformedPubkey};
use crate::error::Result; use crate::error::Result;
use crate::nip05; use crate::nip05;
use crate::utils::unix_time; use crate::utils::unix_time;
@ -28,7 +28,7 @@ pub struct EventCmd {
} }
impl EventCmd { impl EventCmd {
pub fn event_id(&self) -> &str { #[must_use] pub fn event_id(&self) -> &str {
&self.event.id &self.event.id
} }
} }
@ -65,7 +65,7 @@ where
} }
/// Attempt to form a single-char tag name. /// Attempt to form a single-char tag name.
pub fn single_char_tagname(tagname: &str) -> Option<char> { #[must_use] pub fn single_char_tagname(tagname: &str) -> Option<char> {
// We return the tag character if and only if the tagname consists // We return the tag character if and only if the tagname consists
// of a single char. // of a single char.
let mut tagnamechars = tagname.chars(); let mut tagnamechars = tagname.chars();
@ -87,22 +87,22 @@ pub fn single_char_tagname(tagname: &str) -> Option<char> {
impl From<EventCmd> for Result<Event> { impl From<EventCmd> for Result<Event> {
fn from(ec: EventCmd) -> Result<Event> { fn from(ec: EventCmd) -> Result<Event> {
// ensure command is correct // ensure command is correct
if ec.cmd != "EVENT" { if ec.cmd == "EVENT" {
Err(CommandUnknownError)
} else {
ec.event.validate().map(|_| { ec.event.validate().map(|_| {
let mut e = ec.event; let mut e = ec.event;
e.build_index(); e.build_index();
e.update_delegation(); e.update_delegation();
e e
}) })
} else {
Err(CommandUnknownError)
} }
} }
} }
impl Event { impl Event {
#[cfg(test)] #[cfg(test)]
pub fn simple_event() -> Event { #[must_use] pub fn simple_event() -> Event {
Event { Event {
id: "0".to_owned(), id: "0".to_owned(),
pubkey: "0".to_owned(), pubkey: "0".to_owned(),
@ -116,17 +116,17 @@ impl Event {
} }
} }
pub fn is_kind_metadata(&self) -> bool { #[must_use] pub fn is_kind_metadata(&self) -> bool {
self.kind == 0 self.kind == 0
} }
/// Should this event be replaced with newer timestamps from same author? /// Should this event be replaced with newer timestamps from same author?
pub fn is_replaceable(&self) -> bool { #[must_use] pub fn is_replaceable(&self) -> bool {
self.kind == 0 || self.kind == 3 || self.kind == 41 || (self.kind >= 10000 && self.kind < 20000) self.kind == 0 || self.kind == 3 || self.kind == 41 || (self.kind >= 10000 && self.kind < 20000)
} }
/// Pull a NIP-05 Name out of the event, if one exists /// Pull a NIP-05 Name out of the event, if one exists
pub fn get_nip05_addr(&self) -> Option<nip05::Nip05Name> { #[must_use] pub fn get_nip05_addr(&self) -> Option<nip05::Nip05Name> {
if self.is_kind_metadata() { if self.is_kind_metadata() {
// very quick check if we should attempt to parse this json // very quick check if we should attempt to parse this json
if self.content.contains("\"nip05\"") { if self.content.contains("\"nip05\"") {
@ -143,7 +143,7 @@ impl Event {
// is this event delegated (properly)? // is this event delegated (properly)?
// does the signature match, and are conditions valid? // does the signature match, and are conditions valid?
// if so, return an alternate author for the event // if so, return an alternate author for the event
pub fn delegated_author(&self) -> Option<String> { #[must_use] pub fn delegated_author(&self) -> Option<String> {
// is there a delegation tag? // is there a delegation tag?
let delegation_tag: Vec<String> = self let delegation_tag: Vec<String> = self
.tags .tags
@ -151,8 +151,7 @@ impl Event {
.filter(|x| x.len() == 4) .filter(|x| x.len() == 4)
.filter(|x| x.get(0).unwrap() == "delegation") .filter(|x| x.get(0).unwrap() == "delegation")
.take(1) .take(1)
.next()? .next()?.clone(); // get first tag
.to_vec(); // get first tag
//let delegation_tag = self.tag_values_by_name("delegation"); //let delegation_tag = self.tag_values_by_name("delegation");
// delegation tags should have exactly 3 elements after the name (pubkey, condition, sig) // delegation tags should have exactly 3 elements after the name (pubkey, condition, sig)
@ -212,24 +211,24 @@ impl Event {
} }
/// Create a short event identifier, suitable for logging. /// Create a short event identifier, suitable for logging.
pub fn get_event_id_prefix(&self) -> String { #[must_use] pub fn get_event_id_prefix(&self) -> String {
self.id.chars().take(8).collect() self.id.chars().take(8).collect()
} }
pub fn get_author_prefix(&self) -> String { #[must_use] pub fn get_author_prefix(&self) -> String {
self.pubkey.chars().take(8).collect() self.pubkey.chars().take(8).collect()
} }
/// Retrieve tag initial values across all tags matching the name /// Retrieve tag initial values across all tags matching the name
pub fn tag_values_by_name(&self, tag_name: &str) -> Vec<String> { #[must_use] pub fn tag_values_by_name(&self, tag_name: &str) -> Vec<String> {
self.tags self.tags
.iter() .iter()
.filter(|x| x.len() > 1) .filter(|x| x.len() > 1)
.filter(|x| x.get(0).unwrap() == tag_name) .filter(|x| x.get(0).unwrap() == tag_name)
.map(|x| x.get(1).unwrap().to_owned()) .map(|x| x.get(1).unwrap().clone())
.collect() .collect()
} }
pub fn is_valid_timestamp(&self, reject_future_seconds: Option<usize>) -> bool { #[must_use] pub fn is_valid_timestamp(&self, reject_future_seconds: Option<usize>) -> bool {
if let Some(allowable_future) = reject_future_seconds { if let Some(allowable_future) = reject_future_seconds {
let curr_time = unix_time(); let curr_time = unix_time();
// calculate difference, plus how far future we allow // calculate difference, plus how far future we allow
@ -291,7 +290,7 @@ impl Event {
let id = Number::from(0_u64); let id = Number::from(0_u64);
c.push(serde_json::Value::Number(id)); c.push(serde_json::Value::Number(id));
// public key // public key
c.push(Value::String(self.pubkey.to_owned())); c.push(Value::String(self.pubkey.clone()));
// creation time // creation time
let created_at = Number::from(self.created_at); let created_at = Number::from(self.created_at);
c.push(serde_json::Value::Number(created_at)); c.push(serde_json::Value::Number(created_at));
@ -301,7 +300,7 @@ impl Event {
// tags // tags
c.push(self.tags_to_canonical()); c.push(self.tags_to_canonical());
// content // content
c.push(Value::String(self.content.to_owned())); c.push(Value::String(self.content.clone()));
serde_json::to_string(&Value::Array(c)).ok() serde_json::to_string(&Value::Array(c)).ok()
} }
@ -309,11 +308,11 @@ impl Event {
fn tags_to_canonical(&self) -> Value { fn tags_to_canonical(&self) -> Value {
let mut tags = Vec::<Value>::new(); let mut tags = Vec::<Value>::new();
// iterate over self tags, // iterate over self tags,
for t in self.tags.iter() { for t in &self.tags {
// each tag is a vec of strings // each tag is a vec of strings
let mut a = Vec::<Value>::new(); let mut a = Vec::<Value>::new();
for v in t.iter() { for v in t.iter() {
a.push(serde_json::Value::String(v.to_owned())); a.push(serde_json::Value::String(v.clone()));
} }
tags.push(serde_json::Value::Array(a)); tags.push(serde_json::Value::Array(a));
} }
@ -321,7 +320,7 @@ impl Event {
} }
/// Determine if the given tag and value set intersect with tags in this event. /// Determine if the given tag and value set intersect with tags in this event.
pub fn generic_tag_val_intersect(&self, tagname: char, check: &HashSet<String>) -> bool { #[must_use] pub fn generic_tag_val_intersect(&self, tagname: char, check: &HashSet<String>) -> bool {
match &self.tagidx { match &self.tagidx {
// check if this is indexable tagname // check if this is indexable tagname
Some(idx) => match idx.get(&tagname) { Some(idx) => match idx.get(&tagname) {
@ -413,7 +412,7 @@ mod tests {
id: "999".to_owned(), id: "999".to_owned(),
pubkey: "012345".to_owned(), pubkey: "012345".to_owned(),
delegated_by: None, delegated_by: None,
created_at: 501234, created_at: 501_234,
kind: 1, kind: 1,
tags: vec![], tags: vec![],
content: "this is a test".to_owned(), content: "this is a test".to_owned(),
@ -431,7 +430,7 @@ mod tests {
id: "999".to_owned(), id: "999".to_owned(),
pubkey: "012345".to_owned(), pubkey: "012345".to_owned(),
delegated_by: None, delegated_by: None,
created_at: 501234, created_at: 501_234,
kind: 1, kind: 1,
tags: vec![ tags: vec![
vec!["j".to_owned(), "abc".to_owned()], vec!["j".to_owned(), "abc".to_owned()],
@ -458,7 +457,7 @@ mod tests {
id: "999".to_owned(), id: "999".to_owned(),
pubkey: "012345".to_owned(), pubkey: "012345".to_owned(),
delegated_by: None, delegated_by: None,
created_at: 501234, created_at: 501_234,
kind: 1, kind: 1,
tags: vec![ tags: vec![
vec!["j".to_owned(), "abc".to_owned()], vec!["j".to_owned(), "abc".to_owned()],
@ -485,7 +484,7 @@ mod tests {
id: "999".to_owned(), id: "999".to_owned(),
pubkey: "012345".to_owned(), pubkey: "012345".to_owned(),
delegated_by: None, delegated_by: None,
created_at: 501234, created_at: 501_234,
kind: 1, kind: 1,
tags: vec![ tags: vec![
vec!["#e".to_owned(), "aoeu".to_owned()], vec!["#e".to_owned(), "aoeu".to_owned()],

View File

@ -19,7 +19,7 @@ fn is_all_fs(s: &str) -> bool {
} }
/// Find the next hex sequence greater than the argument. /// Find the next hex sequence greater than the argument.
pub fn hex_range(s: &str) -> Option<HexSearch> { #[must_use] pub fn hex_range(s: &str) -> Option<HexSearch> {
// handle special cases // handle special cases
if !is_hex(s) || s.len() > 64 { if !is_hex(s) || s.len() > 64 {
return None; return None;

View File

@ -37,7 +37,7 @@ impl From<config::Info> for RelayInfo {
contact: i.contact, contact: i.contact,
supported_nips: Some(vec![1, 2, 9, 11, 12, 15, 16, 20, 22]), supported_nips: Some(vec![1, 2, 9, 11, 12, 15, 16, 20, 22]),
software: Some("https://git.sr.ht/~gheartsfield/nostr-rs-relay".to_owned()), software: Some("https://git.sr.ht/~gheartsfield/nostr-rs-relay".to_owned()),
version: CARGO_PKG_VERSION.map(|x| x.to_owned()), version: CARGO_PKG_VERSION.map(std::borrow::ToOwned::to_owned),
} }
} }
} }

View File

@ -10,7 +10,7 @@ pub mod hexrange;
pub mod info; pub mod info;
pub mod nip05; pub mod nip05;
pub mod notice; pub mod notice;
pub mod schema; pub mod repo;
pub mod subscription; pub mod subscription;
pub mod utils; pub mod utils;
// Public API for creating relays programatically // Public API for creating relays programatically

View File

@ -1,6 +1,6 @@
//! Server process //! Server process
use clap::Parser; use clap::Parser;
use nostr_rs_relay::cli::*; use nostr_rs_relay::cli::CLIArgs;
use nostr_rs_relay::config; use nostr_rs_relay::config;
use nostr_rs_relay::server::start_server; use nostr_rs_relay::server::start_server;
use std::sync::mpsc as syncmpsc; use std::sync::mpsc as syncmpsc;
@ -37,12 +37,15 @@ fn main() {
if let Some(db_dir) = db_dir_arg { if let Some(db_dir) = db_dir_arg {
settings.database.data_directory = db_dir; settings.database.data_directory = db_dir;
} }
// we should have a 'control plane' channel to monitor and bump
// the server. this will let us do stuff like clear the database,
// shutdown, etc.; for now all this does is initiate shutdown if
// `()` is sent. This will change in the future, this is just a
// stopgap to shutdown the relay when it is used as a library.
let (_, ctrl_rx): (MpscSender<()>, MpscReceiver<()>) = syncmpsc::channel(); let (_, ctrl_rx): (MpscSender<()>, MpscReceiver<()>) = syncmpsc::channel();
// run this in a new thread // run this in a new thread
let handle = thread::spawn(|| { let handle = thread::spawn(move || {
// we should have a 'control plane' channel to monitor and bump the server. let _svr = start_server(&settings, ctrl_rx);
// this will let us do stuff like clear the database, shutdown, etc.
let _svr = start_server(settings, ctrl_rx);
}); });
// block on nostr thread to finish. // block on nostr thread to finish.
handle.join().unwrap(); handle.join().unwrap();

View File

@ -5,16 +5,14 @@
//! consumes a stream of metadata events, and keeps a database table //! consumes a stream of metadata events, and keeps a database table
//! updated with the current NIP-05 verification status. //! updated with the current NIP-05 verification status.
use crate::config::VerifiedUsers; use crate::config::VerifiedUsers;
use crate::db;
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::event::Event; use crate::event::Event;
use crate::utils::unix_time; use crate::repo::NostrRepo;
use std::sync::Arc;
use hyper::body::HttpBody; use hyper::body::HttpBody;
use hyper::client::connect::HttpConnector; use hyper::client::connect::HttpConnector;
use hyper::Client; use hyper::Client;
use hyper_tls::HttpsConnector; use hyper_tls::HttpsConnector;
use rand::Rng;
use rusqlite::params;
use std::time::Duration; use std::time::Duration;
use std::time::Instant; use std::time::Instant;
use std::time::SystemTime; use std::time::SystemTime;
@ -23,14 +21,12 @@ use tracing::{debug, info, warn};
/// NIP-05 verifier state /// NIP-05 verifier state
pub struct Verifier { pub struct Verifier {
/// Repository for saving/retrieving events and records
repo: Arc<dyn NostrRepo>,
/// Metadata events for us to inspect /// Metadata events for us to inspect
metadata_rx: tokio::sync::broadcast::Receiver<Event>, metadata_rx: tokio::sync::broadcast::Receiver<Event>,
/// Newly validated events get written and then broadcast on this channel to subscribers /// Newly validated events get written and then broadcast on this channel to subscribers
event_tx: tokio::sync::broadcast::Sender<Event>, event_tx: tokio::sync::broadcast::Sender<Event>,
/// SQLite read query pool
read_pool: db::SqlitePool,
/// SQLite write query pool
write_pool: db::SqlitePool,
/// Settings /// Settings
settings: crate::config::Settings, settings: crate::config::Settings,
/// HTTP client /// HTTP client
@ -52,7 +48,7 @@ pub struct Nip05Name {
impl Nip05Name { impl Nip05Name {
/// Does this name represent the entire domain? /// Does this name represent the entire domain?
pub fn is_domain_only(&self) -> bool { #[must_use] pub fn is_domain_only(&self) -> bool {
self.local == "_" self.local == "_"
} }
@ -73,16 +69,11 @@ impl std::convert::TryFrom<&str> for Nip05Name {
fn try_from(inet: &str) -> Result<Self, Self::Error> { fn try_from(inet: &str) -> Result<Self, Self::Error> {
// break full name at the @ boundary. // break full name at the @ boundary.
let components: Vec<&str> = inet.split('@').collect(); let components: Vec<&str> = inet.split('@').collect();
if components.len() != 2 { if components.len() == 2 {
Err(Error::CustomError("too many/few components".to_owned()))
} else {
// check if local name is valid // check if local name is valid
let local = components[0]; let local = components[0];
let domain = components[1]; let domain = components[1];
if local if local.chars().all(|x| x.is_alphanumeric() || x == '_' || x == '-' || x == '.') {
.chars()
.all(|x| x.is_alphanumeric() || x == '_' || x == '-' || x == '.')
{
if domain if domain
.chars() .chars()
.all(|x| x.is_alphanumeric() || x == '-' || x == '.') .all(|x| x.is_alphanumeric() || x == '-' || x == '.')
@ -101,6 +92,8 @@ impl std::convert::TryFrom<&str> for Nip05Name {
"invalid character in local part".to_owned(), "invalid character in local part".to_owned(),
)) ))
} }
} else {
Err(Error::CustomError("too many/few components".to_owned()))
} }
} }
} }
@ -111,55 +104,30 @@ impl std::fmt::Display for Nip05Name {
} }
} }
// Current time, with a slight foward jitter in seconds
fn now_jitter(sec: u64) -> u64 {
// random time between now, and 10min in future.
let mut rng = rand::thread_rng();
let jitter_amount = rng.gen_range(0..sec);
let now = unix_time();
now.saturating_add(jitter_amount)
}
/// Check if the specified username and address are present and match in this response body /// Check if the specified username and address are present and match in this response body
fn body_contains_user(username: &str, address: &str, bytes: hyper::body::Bytes) -> Result<bool> { fn body_contains_user(username: &str, address: &str, bytes: &hyper::body::Bytes) -> Result<bool> {
// convert the body into json // convert the body into json
let body: serde_json::Value = serde_json::from_slice(&bytes)?; let body: serde_json::Value = serde_json::from_slice(&bytes)?;
// ensure we have a names object. // ensure we have a names object.
let names_map = body let names_map = body
.as_object() .as_object()
.and_then(|x| x.get("names")) .and_then(|x| x.get("names"))
.and_then(|x| x.as_object()) .and_then(serde_json::Value::as_object)
.ok_or_else(|| Error::CustomError("not a map".to_owned()))?; .ok_or_else(|| Error::CustomError("not a map".to_owned()))?;
// get the pubkey for the requested user // get the pubkey for the requested user
let check_name = names_map.get(username).and_then(|x| x.as_str()); let check_name = names_map.get(username).and_then(serde_json::Value::as_str);
// ensure the address is a match // ensure the address is a match
Ok(check_name.map(|x| x == address).unwrap_or(false)) Ok(check_name.map_or(false, |x| x == address))
} }
impl Verifier { impl Verifier {
pub fn new( pub fn new(
repo: Arc<dyn NostrRepo>,
metadata_rx: tokio::sync::broadcast::Receiver<Event>, metadata_rx: tokio::sync::broadcast::Receiver<Event>,
event_tx: tokio::sync::broadcast::Sender<Event>, event_tx: tokio::sync::broadcast::Sender<Event>,
settings: crate::config::Settings, settings: crate::config::Settings,
) -> Result<Self> { ) -> Result<Self> {
info!("creating NIP-05 verifier"); info!("creating NIP-05 verifier");
// build a database connection for reading and writing.
let write_pool = db::build_pool(
"nip05 writer",
&settings,
rusqlite::OpenFlags::SQLITE_OPEN_READ_WRITE,
1, // min conns
4, // max conns
true, // wait for DB
);
let read_pool = db::build_pool(
"nip05 reader",
&settings,
rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY,
1, // min conns
8, // max conns
true, // wait for DB
);
// setup hyper client // setup hyper client
let https = HttpsConnector::new(); let https = HttpsConnector::new();
let client = Client::builder().build::<_, hyper::Body>(https); let client = Client::builder().build::<_, hyper::Body>(https);
@ -175,10 +143,9 @@ impl Verifier {
// duration. // duration.
let reverify_interval = tokio::time::interval(http_wait_duration); let reverify_interval = tokio::time::interval(http_wait_duration);
Ok(Verifier { Ok(Verifier {
repo,
metadata_rx, metadata_rx,
event_tx, event_tx,
read_pool,
write_pool,
settings, settings,
client, client,
wait_after_finish, wait_after_finish,
@ -246,9 +213,7 @@ impl Verifier {
let response_fut = self.client.request(req); let response_fut = self.client.request(req);
// HTTP request with timeout if let Ok(response_res) = tokio::time::timeout(Duration::from_secs(5), response_fut).await {
match tokio::time::timeout(Duration::from_secs(5), response_fut).await {
Ok(response_res) => {
// limit size of verification document to 1MB. // limit size of verification document to 1MB.
const MAX_ALLOWED_RESPONSE_SIZE: u64 = 1024 * 1024; const MAX_ALLOWED_RESPONSE_SIZE: u64 = 1024 * 1024;
let response = response_res?; let response = response_res?;
@ -264,7 +229,7 @@ impl Verifier {
if parts.status == http::StatusCode::OK { if parts.status == http::StatusCode::OK {
// parse body, determine if the username / key / address is present // parse body, determine if the username / key / address is present
let body_bytes = hyper::body::to_bytes(body).await?; let body_bytes = hyper::body::to_bytes(body).await?;
let body_matches = body_contains_user(&nip.local, pubkey, body_bytes)?; let body_matches = body_contains_user(&nip.local, pubkey, &body_bytes)?;
if body_matches { if body_matches {
return Ok(UserWebVerificationStatus::Verified); return Ok(UserWebVerificationStatus::Verified);
} }
@ -279,12 +244,10 @@ impl Verifier {
nip.to_string() nip.to_string()
); );
} }
} } else {
Err(_) => {
info!("timeout verifying account {:?}", nip); info!("timeout verifying account {:?}", nip);
return Ok(UserWebVerificationStatus::Unknown); return Ok(UserWebVerificationStatus::Unknown);
} }
}
Ok(UserWebVerificationStatus::Unknown) Ok(UserWebVerificationStatus::Unknown)
} }
@ -309,7 +272,7 @@ impl Verifier {
if let Some(naddr) = e.get_nip05_addr() { if let Some(naddr) = e.get_nip05_addr() {
info!("got metadata event for ({:?},{:?})", naddr.to_string() ,e.get_author_prefix()); info!("got metadata event for ({:?},{:?})", naddr.to_string() ,e.get_author_prefix());
// Process a new author, checking if they are verified: // Process a new author, checking if they are verified:
let check_verified = get_latest_user_verification(self.read_pool.get().expect("could not get connection"), &e.pubkey).await; let check_verified = self.repo.get_latest_user_verification(&e.pubkey).await;
// ensure the event we got is more recent than the one we have, otherwise we can ignore it. // ensure the event we got is more recent than the one we have, otherwise we can ignore it.
if let Ok(last_check) = check_verified { if let Ok(last_check) = check_verified {
if e.created_at <= last_check.event_created { if e.created_at <= last_check.event_created {
@ -370,7 +333,7 @@ impl Verifier {
.duration_since(SystemTime::UNIX_EPOCH) .duration_since(SystemTime::UNIX_EPOCH)
.map(|x| x.as_secs()) .map(|x| x.as_secs())
.unwrap_or(0); .unwrap_or(0);
let vr = get_oldest_user_verification(self.read_pool.get()?, earliest_epoch).await; let vr = self.repo.get_oldest_user_verification(earliest_epoch).await;
match vr { match vr {
Ok(ref v) => { Ok(ref v) => {
let new_status = self.get_web_verification(&v.name, &v.address).await; let new_status = self.get_web_verification(&v.name, &v.address).await;
@ -378,8 +341,10 @@ impl Verifier {
UserWebVerificationStatus::Verified => { UserWebVerificationStatus::Verified => {
// freshly verified account, update the // freshly verified account, update the
// timestamp. // timestamp.
self.update_verification_record(self.write_pool.get()?, v) self.repo.update_verification_timestamp(v.rowid)
.await?; .await?;
info!("verification updated for {}", v.to_string());
} }
UserWebVerificationStatus::DomainNotAllowed UserWebVerificationStatus::DomainNotAllowed
| UserWebVerificationStatus::Unknown => { | UserWebVerificationStatus::Unknown => {
@ -394,18 +359,19 @@ impl Verifier {
"giving up on verifying {:?} after {} failures", "giving up on verifying {:?} after {} failures",
v.name, v.failure_count v.name, v.failure_count
); );
self.delete_verification_record(self.write_pool.get()?, v) self.repo.delete_verification(v.rowid)
.await?; .await?;
} else { } else {
// record normal failure, incrementing failure count // record normal failure, incrementing failure count
self.fail_verification_record(self.write_pool.get()?, v) info!("verification failed for {}", v.to_string());
.await?; self.repo.fail_verification(v.rowid).await?;
} }
} }
UserWebVerificationStatus::Unverified => { UserWebVerificationStatus::Unverified => {
// domain has removed the verification, drop // domain has removed the verification, drop
// the record on our side. // the record on our side.
self.delete_verification_record(self.write_pool.get()?, v) info!("verification rescinded for {}", v.to_string());
self.repo.delete_verification(v.rowid)
.await?; .await?;
} }
} }
@ -426,80 +392,6 @@ impl Verifier {
Ok(()) Ok(())
} }
/// Reset the verification timestamp on a VerificationRecord
pub async fn update_verification_record(
&mut self,
mut conn: db::PooledConnection,
vr: &VerificationRecord,
) -> Result<()> {
let vr_id = vr.rowid;
let vr_str = vr.to_string();
tokio::task::spawn_blocking(move || {
// add some jitter to the verification to prevent everything from stacking up together.
let verif_time = now_jitter(600);
let tx = conn.transaction()?;
{
// update verification time and reset any failure count
let query =
"UPDATE user_verification SET verified_at=?, failure_count=0 WHERE id=?";
let mut stmt = tx.prepare(query)?;
stmt.execute(params![verif_time, vr_id])?;
}
tx.commit()?;
info!("verification updated for {}", vr_str);
let ok: Result<()> = Ok(());
ok
})
.await?
}
/// Reset the failure timestamp on a VerificationRecord
pub async fn fail_verification_record(
&mut self,
mut conn: db::PooledConnection,
vr: &VerificationRecord,
) -> Result<()> {
let vr_id = vr.rowid;
let vr_str = vr.to_string();
let fail_count = vr.failure_count.saturating_add(1);
tokio::task::spawn_blocking(move || {
// add some jitter to the verification to prevent everything from stacking up together.
let fail_time = now_jitter(600);
let tx = conn.transaction()?;
{
let query = "UPDATE user_verification SET failed_at=?, failure_count=? WHERE id=?";
let mut stmt = tx.prepare(query)?;
stmt.execute(params![fail_time, fail_count, vr_id])?;
}
tx.commit()?;
info!("verification failed for {}", vr_str);
let ok: Result<()> = Ok(());
ok
})
.await?
}
/// Delete a VerificationRecord that is no longer valid
pub async fn delete_verification_record(
&mut self,
mut conn: db::PooledConnection,
vr: &VerificationRecord,
) -> Result<()> {
let vr_id = vr.rowid;
let vr_str = vr.to_string();
tokio::task::spawn_blocking(move || {
let tx = conn.transaction()?;
{
let query = "DELETE FROM user_verification WHERE id=?;";
let mut stmt = tx.prepare(query)?;
stmt.execute(params![vr_id])?;
}
tx.commit()?;
info!("verification rescinded for {}", vr_str);
let ok: Result<()> = Ok(());
ok
})
.await?
}
/// Persist an event, create a verification record, and broadcast. /// Persist an event, create a verification record, and broadcast.
// TODO: have more event-writing logic handled in the db module. // TODO: have more event-writing logic handled in the db module.
// Right now, these events avoid the rate limit. That is // Right now, these events avoid the rate limit. That is
@ -513,7 +405,7 @@ impl Verifier {
// disabled/passive, the event has already been persisted. // disabled/passive, the event has already been persisted.
let should_write_event = self.settings.verified_users.is_enabled(); let should_write_event = self.settings.verified_users.is_enabled();
if should_write_event { if should_write_event {
match db::write_event(&mut self.write_pool.get()?, event) { match self.repo.write_event(event).await {
Ok(updated) => { Ok(updated) => {
if updated != 0 { if updated != 0 {
info!( info!(
@ -533,7 +425,7 @@ impl Verifier {
} }
} }
// write the verification record // write the verification record
save_verification_record(self.write_pool.get()?, event, name).await?; self.repo.create_verification_record(&event.id, name).await?;
Ok(()) Ok(())
} }
} }
@ -563,7 +455,7 @@ pub struct VerificationRecord {
/// Check with settings to determine if a given domain is allowed to /// Check with settings to determine if a given domain is allowed to
/// publish. /// publish.
pub fn is_domain_allowed( #[must_use] pub fn is_domain_allowed(
domain: &str, domain: &str,
whitelist: &Option<Vec<String>>, whitelist: &Option<Vec<String>>,
blacklist: &Option<Vec<String>>, blacklist: &Option<Vec<String>>,
@ -583,7 +475,7 @@ pub fn is_domain_allowed(
impl VerificationRecord { impl VerificationRecord {
/// Check if the record is recent enough to be considered valid, /// Check if the record is recent enough to be considered valid,
/// and the domain is allowed. /// and the domain is allowed.
pub fn is_valid(&self, verified_users_settings: &VerifiedUsers) -> bool { #[must_use] pub fn is_valid(&self, verified_users_settings: &VerifiedUsers) -> bool {
//let settings = SETTINGS.read().unwrap(); //let settings = SETTINGS.read().unwrap();
// how long a verification record is good for // how long a verification record is good for
let nip05_expiration = &verified_users_settings.verify_expiration_duration; let nip05_expiration = &verified_users_settings.verify_expiration_duration;
@ -630,130 +522,6 @@ impl std::fmt::Display for VerificationRecord {
} }
} }
/// Create a new verification record based on an event
pub async fn save_verification_record(
mut conn: db::PooledConnection,
event: &Event,
name: &str,
) -> Result<()> {
let e = hex::decode(&event.id).ok();
let n = name.to_owned();
let a_prefix = event.get_author_prefix();
tokio::task::spawn_blocking(move || {
let tx = conn.transaction()?;
{
// if we create a /new/ one, we should get rid of any old ones. or group the new ones by name and only consider the latest.
let query = "INSERT INTO user_verification (metadata_event, name, verified_at) VALUES ((SELECT id from event WHERE event_hash=?), ?, strftime('%s','now'));";
let mut stmt = tx.prepare(query)?;
stmt.execute(params![e, n])?;
// get the row ID
let v_id = tx.last_insert_rowid();
// delete everything else by this name
let del_query = "DELETE FROM user_verification WHERE name = ? AND id != ?;";
let mut del_stmt = tx.prepare(del_query)?;
let count = del_stmt.execute(params![n,v_id])?;
if count > 0 {
info!("removed {} old verification records for ({:?},{:?})", count, n, a_prefix);
}
}
tx.commit()?;
info!("saved new verification record for ({:?},{:?})", n, a_prefix);
let ok: Result<()> = Ok(());
ok
}).await?
}
/// Retrieve the most recent verification record for a given pubkey (async).
pub async fn get_latest_user_verification(
conn: db::PooledConnection,
pubkey: &str,
) -> Result<VerificationRecord> {
let p = pubkey.to_owned();
tokio::task::spawn_blocking(move || query_latest_user_verification(conn, p)).await?
}
/// Query database for the latest verification record for a given pubkey.
pub fn query_latest_user_verification(
mut conn: db::PooledConnection,
pubkey: String,
) -> Result<VerificationRecord> {
let tx = conn.transaction()?;
let query = "SELECT v.id, v.name, e.event_hash, e.created_at, v.verified_at, v.failed_at, v.failure_count FROM user_verification v LEFT JOIN event e ON e.id=v.metadata_event WHERE e.author=? ORDER BY e.created_at DESC, v.verified_at DESC, v.failed_at DESC LIMIT 1;";
let mut stmt = tx.prepare_cached(query)?;
let fields = stmt.query_row(params![hex::decode(&pubkey).ok()], |r| {
let rowid: u64 = r.get(0)?;
let rowname: String = r.get(1)?;
let eventid: Vec<u8> = r.get(2)?;
let created_at: u64 = r.get(3)?;
// create a tuple since we can't throw non-rusqlite errors in this closure
Ok((
rowid,
rowname,
eventid,
created_at,
r.get(4).ok(),
r.get(5).ok(),
r.get(6)?,
))
})?;
Ok(VerificationRecord {
rowid: fields.0,
name: Nip05Name::try_from(&fields.1[..])?,
address: pubkey,
event: hex::encode(fields.2),
event_created: fields.3,
last_success: fields.4,
last_failure: fields.5,
failure_count: fields.6,
})
}
/// Retrieve the oldest user verification (async)
pub async fn get_oldest_user_verification(
conn: db::PooledConnection,
earliest: u64,
) -> Result<VerificationRecord> {
tokio::task::spawn_blocking(move || query_oldest_user_verification(conn, earliest)).await?
}
pub fn query_oldest_user_verification(
mut conn: db::PooledConnection,
earliest: u64,
) -> Result<VerificationRecord> {
let tx = conn.transaction()?;
let query = "SELECT v.id, v.name, e.event_hash, e.author, e.created_at, v.verified_at, v.failed_at, v.failure_count FROM user_verification v INNER JOIN event e ON e.id=v.metadata_event WHERE (v.verified_at < ? OR v.verified_at IS NULL) AND (v.failed_at < ? OR v.failed_at IS NULL) ORDER BY v.verified_at ASC, v.failed_at ASC LIMIT 1;";
let mut stmt = tx.prepare_cached(query)?;
let fields = stmt.query_row(params![earliest, earliest], |r| {
let rowid: u64 = r.get(0)?;
let rowname: String = r.get(1)?;
let eventid: Vec<u8> = r.get(2)?;
let pubkey: Vec<u8> = r.get(3)?;
let created_at: u64 = r.get(4)?;
// create a tuple since we can't throw non-rusqlite errors in this closure
Ok((
rowid,
rowname,
eventid,
pubkey,
created_at,
r.get(5).ok(),
r.get(6).ok(),
r.get(7)?,
))
})?;
let vr = VerificationRecord {
rowid: fields.0,
name: Nip05Name::try_from(&fields.1[..])?,
address: hex::encode(fields.3),
event: hex::encode(fields.2),
event_created: fields.4,
last_success: fields.5,
last_failure: fields.6,
failure_count: fields.7,
};
Ok(vr)
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -762,7 +530,7 @@ mod tests {
fn local_from_inet() { fn local_from_inet() {
let addr = "bob@example.com"; let addr = "bob@example.com";
let parsed = Nip05Name::try_from(addr); let parsed = Nip05Name::try_from(addr);
assert!(!parsed.is_err()); assert!(parsed.is_ok());
let v = parsed.unwrap(); let v = parsed.unwrap();
assert_eq!(v.local, "bob"); assert_eq!(v.local, "bob");
assert_eq!(v.domain, "example.com"); assert_eq!(v.domain, "example.com");

View File

@ -19,18 +19,14 @@ pub enum Notice {
} }
impl EventResultStatus { impl EventResultStatus {
pub fn to_bool(&self) -> bool { #[must_use] pub fn to_bool(&self) -> bool {
match self { match self {
Self::Saved => true, Self::Duplicate | Self::Saved => true,
Self::Duplicate => true, Self::Invalid |Self::Blocked | Self::RateLimited | Self::Error => false,
Self::Invalid => false,
Self::Blocked => false,
Self::RateLimited => false,
Self::Error => false,
} }
} }
pub fn prefix(&self) -> &'static str { #[must_use] pub fn prefix(&self) -> &'static str {
match self { match self {
Self::Saved => "saved", Self::Saved => "saved",
Self::Duplicate => "duplicate", Self::Duplicate => "duplicate",
@ -47,7 +43,7 @@ impl Notice {
// Notice::err_msg(format!("{}", err), id) // Notice::err_msg(format!("{}", err), id)
//} //}
pub fn message(msg: String) -> Notice { #[must_use] pub fn message(msg: String) -> Notice {
Notice::Message(msg) Notice::Message(msg)
} }
@ -56,27 +52,27 @@ impl Notice {
Notice::EventResult(EventResult { id, msg, status }) Notice::EventResult(EventResult { id, msg, status })
} }
pub fn invalid(id: String, msg: &str) -> Notice { #[must_use] pub fn invalid(id: String, msg: &str) -> Notice {
Notice::prefixed(id, msg, EventResultStatus::Invalid) Notice::prefixed(id, msg, EventResultStatus::Invalid)
} }
pub fn blocked(id: String, msg: &str) -> Notice { #[must_use] pub fn blocked(id: String, msg: &str) -> Notice {
Notice::prefixed(id, msg, EventResultStatus::Blocked) Notice::prefixed(id, msg, EventResultStatus::Blocked)
} }
pub fn rate_limited(id: String, msg: &str) -> Notice { #[must_use] pub fn rate_limited(id: String, msg: &str) -> Notice {
Notice::prefixed(id, msg, EventResultStatus::RateLimited) Notice::prefixed(id, msg, EventResultStatus::RateLimited)
} }
pub fn duplicate(id: String) -> Notice { #[must_use] pub fn duplicate(id: String) -> Notice {
Notice::prefixed(id, "", EventResultStatus::Duplicate) Notice::prefixed(id, "", EventResultStatus::Duplicate)
} }
pub fn error(id: String, msg: &str) -> Notice { #[must_use] pub fn error(id: String, msg: &str) -> Notice {
Notice::prefixed(id, msg, EventResultStatus::Error) Notice::prefixed(id, msg, EventResultStatus::Error)
} }
pub fn saved(id: String) -> Notice { #[must_use] pub fn saved(id: String) -> Notice {
Notice::EventResult(EventResult { Notice::EventResult(EventResult {
id, id,
msg: "".into(), msg: "".into(),

67
src/repo/mod.rs Normal file
View File

@ -0,0 +1,67 @@
use crate::db::QueryResult;
use crate::error::Result;
use crate::event::Event;
use crate::nip05::VerificationRecord;
use crate::subscription::Subscription;
use crate::utils::unix_time;
use async_trait::async_trait;
use rand::Rng;
pub mod sqlite;
pub mod sqlite_migration;
#[async_trait]
pub trait NostrRepo: Send + Sync {
/// Start the repository (any initialization or maintenance tasks can be kicked off here)
async fn start(&self) -> Result<()>;
/// Run migrations and return current version
async fn migrate_up(&self) -> Result<usize>;
/// Persist event to database
async fn write_event(&self, e: &Event) -> Result<u64>;
/// Perform a database query using a subscription.
///
/// The [`Subscription`] is converted into a SQL query. Each result
/// is published on the `query_tx` channel as it is returned. If a
/// message becomes available on the `abandon_query_rx` channel, the
/// query is immediately aborted.
async fn query_subscription(
&self,
sub: Subscription,
client_id: String,
query_tx: tokio::sync::mpsc::Sender<QueryResult>,
mut abandon_query_rx: tokio::sync::oneshot::Receiver<()>,
) -> Result<()>;
/// Perform normal maintenance
async fn optimize_db(&self) -> Result<()>;
/// Create a new verification record connected to a specific event
async fn create_verification_record(&self, event_id: &str, name: &str) -> Result<()>;
/// Update verification timestamp
async fn update_verification_timestamp(&self, id: u64) -> Result<()>;
/// Update verification record as failed
async fn fail_verification(&self, id: u64) -> Result<()>;
/// Delete verification record
async fn delete_verification(&self, id: u64) -> Result<()>;
/// Get the latest verification record for a given pubkey.
async fn get_latest_user_verification(&self, pub_key: &str) -> Result<VerificationRecord>;
/// Get oldest verification before timestamp
async fn get_oldest_user_verification(&self, before: u64) -> Result<VerificationRecord>;
}
// Current time, with a slight forward jitter in seconds
pub(crate) fn now_jitter(sec: u64) -> u64 {
// random time between now, and 10min in future.
let mut rng = rand::thread_rng();
let jitter_amount = rng.gen_range(0..sec);
let now = unix_time();
now.saturating_add(jitter_amount)
}

948
src/repo/sqlite.rs Normal file
View File

@ -0,0 +1,948 @@
//! Event persistence and querying
//use crate::config::SETTINGS;
use crate::config::Settings;
use crate::error::Result;
use crate::event::{single_char_tagname, Event};
use crate::hexrange::hex_range;
use crate::hexrange::HexSearch;
use crate::repo::sqlite_migration::{STARTUP_SQL,upgrade_db};
use crate::utils::{is_hex, is_lower_hex};
use crate::nip05::{Nip05Name, VerificationRecord};
use crate::subscription::{ReqFilter, Subscription};
use hex;
use r2d2;
use r2d2_sqlite::SqliteConnectionManager;
use rusqlite::params;
use rusqlite::types::ToSql;
use rusqlite::OpenFlags;
use tokio::sync::{Mutex, MutexGuard};
use std::fmt::Write as _;
use std::path::Path;
use std::sync::Arc;
use std::thread;
use std::time::Duration;
use std::time::Instant;
use tokio::task;
use tracing::{debug, info, trace, warn};
use async_trait::async_trait;
use crate::db::QueryResult;
use crate::repo::{now_jitter, NostrRepo};
pub type SqlitePool = r2d2::Pool<r2d2_sqlite::SqliteConnectionManager>;
pub type PooledConnection = r2d2::PooledConnection<r2d2_sqlite::SqliteConnectionManager>;
pub const DB_FILE: &str = "nostr.db";
#[derive(Clone)]
pub struct SqliteRepo {
/// Pool for reading events and NIP-05 status
read_pool: SqlitePool,
/// Pool for writing events and NIP-05 verification
write_pool: SqlitePool,
/// Pool for performing checkpoints/optimization
maint_pool: SqlitePool,
/// Flag to indicate a checkpoint is underway
checkpoint_in_progress: Arc<Mutex<u64>>,
/// Flag to limit writer concurrency
write_in_progress: Arc<Mutex<u64>>,
}
impl SqliteRepo {
// build all the pools needed
#[must_use] pub fn new(settings: &Settings) -> SqliteRepo {
let maint_pool = build_pool(
"maintenance",
settings,
OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE,
1,
2,
true,
);
let read_pool = build_pool(
"reader",
settings,
OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE,
settings.database.min_conn,
settings.database.max_conn,
true,
);
let write_pool = build_pool(
"writer",
settings,
OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE,
1,
2,
false,
);
// this is used to block new reads during critical checkpoints
let checkpoint_in_progress = Arc::new(Mutex::new(0));
// SQLite can only effectively write single threaded, so don't
// block multiple worker threads unnecessarily.
let write_in_progress = Arc::new(Mutex::new(0));
SqliteRepo {
read_pool,
write_pool,
maint_pool,
checkpoint_in_progress,
write_in_progress,
}
}
/// Persist an event to the database, returning rows added.
pub fn persist_event(conn: &mut PooledConnection, e: &Event) -> Result<u64> {
// enable auto vacuum
conn.execute_batch("pragma auto_vacuum = FULL")?;
// start transaction
let tx = conn.transaction()?;
// get relevant fields from event and convert to blobs.
let id_blob = hex::decode(&e.id).ok();
let pubkey_blob: Option<Vec<u8>> = hex::decode(&e.pubkey).ok();
let delegator_blob: Option<Vec<u8>> = e.delegated_by.as_ref().and_then(|d| hex::decode(d).ok());
let event_str = serde_json::to_string(&e).ok();
// check for replaceable events that would hide this one; we won't even attempt to insert these.
if e.is_replaceable() {
let repl_count = tx.query_row(
"SELECT e.id FROM event e INDEXED BY author_index WHERE e.author=? AND e.kind=? AND e.created_at > ? LIMIT 1;",
params![pubkey_blob, e.kind, e.created_at], |row| row.get::<usize, usize>(0));
if repl_count.ok().is_some() {
return Ok(0);
}
}
// ignore if the event hash is a duplicate.
let mut ins_count = tx.execute(
"INSERT OR IGNORE INTO event (event_hash, created_at, kind, author, delegated_by, content, first_seen, hidden) VALUES (?1, ?2, ?3, ?4, ?5, ?6, strftime('%s','now'), FALSE);",
params![id_blob, e.created_at, e.kind, pubkey_blob, delegator_blob, event_str]
)? as u64;
if ins_count == 0 {
// if the event was a duplicate, no need to insert event or
// pubkey references.
tx.rollback().ok();
return Ok(ins_count);
}
// remember primary key of the event most recently inserted.
let ev_id = tx.last_insert_rowid();
// add all tags to the tag table
for tag in &e.tags {
// ensure we have 2 values.
if tag.len() >= 2 {
let tagname = &tag[0];
let tagval = &tag[1];
// only single-char tags are searchable
let tagchar_opt = single_char_tagname(tagname);
match &tagchar_opt {
Some(_) => {
// if tagvalue is lowercase hex;
if is_lower_hex(tagval) && (tagval.len() % 2 == 0) {
tx.execute(
"INSERT OR IGNORE INTO tag (event_id, name, value_hex) VALUES (?1, ?2, ?3)",
params![ev_id, &tagname, hex::decode(tagval).ok()],
)?;
} else {
tx.execute(
"INSERT OR IGNORE INTO tag (event_id, name, value) VALUES (?1, ?2, ?3)",
params![ev_id, &tagname, &tagval],
)?;
}
}
None => {}
}
}
}
// if this event is replaceable update, remove other replaceable
// event with the same kind from the same author that was issued
// earlier than this.
if e.is_replaceable() {
let author = hex::decode(&e.pubkey).ok();
// this is a backwards check - hide any events that were older.
let update_count = tx.execute(
"DELETE FROM event WHERE kind=? and author=? and id NOT IN (SELECT id FROM event INDEXED BY author_kind_index WHERE kind=? AND author=? ORDER BY created_at DESC LIMIT 1)",
params![e.kind, author, e.kind, author],
)?;
if update_count > 0 {
info!(
"removed {} older replaceable kind {} events for author: {:?}",
update_count,
e.kind,
e.get_author_prefix()
);
}
}
// if this event is a deletion, hide the referenced events from the same author.
if e.kind == 5 {
let event_candidates = e.tag_values_by_name("e");
// first parameter will be author
let mut params: Vec<Box<dyn ToSql>> = vec![Box::new(hex::decode(&e.pubkey)?)];
event_candidates
.iter()
.filter(|x| is_hex(x) && x.len() == 64)
.filter_map(|x| hex::decode(x).ok())
.for_each(|x| params.push(Box::new(x)));
let query = format!(
"UPDATE event SET hidden=TRUE WHERE kind!=5 AND author=? AND event_hash IN ({})",
repeat_vars(params.len() - 1)
);
let mut stmt = tx.prepare(&query)?;
let update_count = stmt.execute(rusqlite::params_from_iter(params))?;
info!(
"hid {} deleted events for author {:?}",
update_count,
e.get_author_prefix()
);
} else {
// check if a deletion has already been recorded for this event.
// Only relevant for non-deletion events
let del_count = tx.query_row(
"SELECT e.id FROM event e LEFT JOIN tag t ON e.id=t.event_id WHERE e.author=? AND t.name='e' AND e.kind=5 AND t.value_hex=? LIMIT 1;",
params![pubkey_blob, id_blob], |row| row.get::<usize, usize>(0));
// check if a the query returned a result, meaning we should
// hid the current event
if del_count.ok().is_some() {
// a deletion already existed, mark original event as hidden.
info!(
"hid event: {:?} due to existing deletion by author: {:?}",
e.get_event_id_prefix(),
e.get_author_prefix()
);
let _update_count =
tx.execute("UPDATE event SET hidden=TRUE WHERE id=?", params![ev_id])?;
// event was deleted, so let caller know nothing new
// arrived, preventing this from being sent to active
// subscriptions
ins_count = 0;
}
}
tx.commit()?;
Ok(ins_count)
}
}
#[async_trait]
impl NostrRepo for SqliteRepo {
async fn start(&self) -> Result<()> {
db_checkpoint_task(self.maint_pool.clone(), Duration::from_secs(60), self.checkpoint_in_progress.clone()).await
}
async fn migrate_up(&self) -> Result<usize> {
let _write_guard = self.write_in_progress.lock().await;
let mut conn = self.write_pool.get()?;
task::spawn_blocking(move || {
upgrade_db(&mut conn)
}).await?
}
/// Persist event to database
async fn write_event(&self, e: &Event) -> Result<u64> {
let _write_guard = self.write_in_progress.lock().await;
// spawn a blocking thread
let mut conn = self.write_pool.get()?;
let e = e.clone();
task::spawn_blocking(move || {
SqliteRepo::persist_event(&mut conn, &e)
}).await?
}
/// Perform a database query using a subscription.
///
/// The [`Subscription`] is converted into a SQL query. Each result
/// is published on the `query_tx` channel as it is returned. If a
/// message becomes available on the `abandon_query_rx` channel, the
/// query is immediately aborted.
async fn query_subscription(
&self,
sub: Subscription,
client_id: String,
query_tx: tokio::sync::mpsc::Sender<QueryResult>,
mut abandon_query_rx: tokio::sync::oneshot::Receiver<()>,
) -> Result<()> {
let pre_spawn_start = Instant::now();
let self=self.clone();
task::spawn_blocking(move || {
{
// if we are waiting on a checkpoint, stop until it is complete
let _x = self.checkpoint_in_progress.blocking_lock();
}
let db_queue_time = pre_spawn_start.elapsed();
// if the queue time was very long (>5 seconds), spare the DB and abort.
if db_queue_time > Duration::from_secs(5) {
info!(
"shedding DB query load queued for {:?} (cid: {}, sub: {:?})",
db_queue_time, client_id, sub.id
);
return Ok(());
}
// otherwise, report queuing time if it is slow
else if db_queue_time > Duration::from_secs(1) {
debug!(
"(slow) DB query queued for {:?} (cid: {}, sub: {:?})",
db_queue_time, client_id, sub.id
);
}
let start = Instant::now();
let mut row_count: usize = 0;
// generate SQL query
let (q, p, idxs) = query_from_sub(&sub);
let sql_gen_elapsed = start.elapsed();
if sql_gen_elapsed > Duration::from_millis(10) {
debug!("SQL (slow) generated in {:?}", start.elapsed());
}
// cutoff for displaying slow queries
let slow_cutoff = Duration::from_millis(2000);
// any client that doesn't cause us to generate new rows in 5
// seconds gets dropped.
let abort_cutoff = Duration::from_secs(5);
let start = Instant::now();
let mut slow_first_event;
let mut last_successful_send = Instant::now();
if let Ok(mut conn) = self.read_pool.get() {
// execute the query.
// make the actual SQL query (with parameters inserted) available
conn.trace(Some(|x| {trace!("SQL trace: {:?}", x)}));
let mut stmt = conn.prepare_cached(&q)?;
let mut event_rows = stmt.query(rusqlite::params_from_iter(p))?;
let mut first_result = true;
while let Some(row) = event_rows.next()? {
let first_event_elapsed = start.elapsed();
slow_first_event = first_event_elapsed >= slow_cutoff;
if first_result {
debug!(
"first result in {:?} (cid: {}, sub: {:?}) [used indexes: {:?}]",
first_event_elapsed, client_id, sub.id, idxs
);
first_result = false;
}
// logging for slow queries; show sub and SQL.
// to reduce logging; only show 1/16th of clients (leading 0)
if row_count == 0 && slow_first_event && client_id.starts_with('0') {
debug!(
"query req (slow): {:?} (cid: {}, sub: {:?})",
sub, client_id, sub.id
);
}
// check if a checkpoint is trying to run, and abort
if row_count % 100 == 0 {
{
if self.checkpoint_in_progress.try_lock().is_err() {
// lock was held, abort this query
debug!("query aborted due to checkpoint (cid: {}, sub: {:?})", client_id, sub.id);
return Ok(());
}
}
}
// check if this is still active; every 100 rows
if row_count % 100 == 0 && abandon_query_rx.try_recv().is_ok() {
debug!("query aborted (cid: {}, sub: {:?})", client_id, sub.id);
return Ok(());
}
row_count += 1;
let event_json = row.get(0)?;
loop {
if query_tx.capacity() != 0 {
// we have capacity to add another item
break;
}
// the queue is full
trace!("db reader thread is stalled");
if last_successful_send + abort_cutoff < Instant::now() {
// the queue has been full for too long, abort
info!("aborting database query due to slow client (cid: {}, sub: {:?})",
client_id, sub.id);
let ok: Result<()> = Ok(());
return ok;
}
// check if a checkpoint is trying to run, and abort
if self.checkpoint_in_progress.try_lock().is_err() {
// lock was held, abort this query
debug!("query aborted due to checkpoint (cid: {}, sub: {:?})", client_id, sub.id);
return Ok(());
}
// give the queue a chance to clear before trying again
thread::sleep(Duration::from_millis(100));
}
// TODO: we could use try_send, but we'd have to juggle
// getting the query result back as part of the error
// result.
query_tx
.blocking_send(QueryResult {
sub_id: sub.get_id(),
event: event_json,
})
.ok();
last_successful_send = Instant::now();
}
query_tx
.blocking_send(QueryResult {
sub_id: sub.get_id(),
event: "EOSE".to_string(),
})
.ok();
debug!(
"query completed in {:?} (cid: {}, sub: {:?}, db_time: {:?}, rows: {})",
pre_spawn_start.elapsed(),
client_id,
sub.id,
start.elapsed(),
row_count
);
} else {
warn!("Could not get a database connection for querying");
}
let ok: Result<()> = Ok(());
ok
});
Ok(())
}
/// Perform normal maintenance
async fn optimize_db(&self) -> Result<()> {
let conn = self.write_pool.get()?;
task::spawn_blocking(move || {
let start = Instant::now();
conn.execute_batch("PRAGMA optimize;").ok();
info!("optimize ran in {:?}", start.elapsed());
}).await?;
Ok(())
}
/// Create a new verification record connected to a specific event
async fn create_verification_record(&self, event_id: &str, name: &str) -> Result<()> {
let e = hex::decode(event_id).ok();
let n = name.to_owned();
let mut conn = self.write_pool.get()?;
tokio::task::spawn_blocking(move || {
let tx = conn.transaction()?;
{
// if we create a /new/ one, we should get rid of any old ones. or group the new ones by name and only consider the latest.
let query = "INSERT INTO user_verification (metadata_event, name, verified_at) VALUES ((SELECT id from event WHERE event_hash=?), ?, strftime('%s','now'));";
let mut stmt = tx.prepare(query)?;
stmt.execute(params![e, n])?;
// get the row ID
let v_id = tx.last_insert_rowid();
// delete everything else by this name
let del_query = "DELETE FROM user_verification WHERE name = ? AND id != ?;";
let mut del_stmt = tx.prepare(del_query)?;
let count = del_stmt.execute(params![n,v_id])?;
if count > 0 {
info!("removed {} old verification records for ({:?})", count, n);
}
}
tx.commit()?;
info!("saved new verification record for ({:?})", n);
let ok: Result<()> = Ok(());
ok
}).await?
}
/// Update verification timestamp
async fn update_verification_timestamp(&self, id: u64) -> Result<()> {
let mut conn = self.write_pool.get()?;
tokio::task::spawn_blocking(move || {
// add some jitter to the verification to prevent everything from stacking up together.
let verif_time = now_jitter(600);
let tx = conn.transaction()?;
{
// update verification time and reset any failure count
let query =
"UPDATE user_verification SET verified_at=?, failure_count=0 WHERE id=?";
let mut stmt = tx.prepare(query)?;
stmt.execute(params![verif_time, id])?;
}
tx.commit()?;
let ok: Result<()> = Ok(());
ok
})
.await?
}
/// Update verification record as failed
async fn fail_verification(&self, id: u64) -> Result<()> {
let mut conn = self.write_pool.get()?;
tokio::task::spawn_blocking(move || {
// add some jitter to the verification to prevent everything from stacking up together.
let fail_time = now_jitter(600);
let tx = conn.transaction()?;
{
let query = "UPDATE user_verification SET failed_at=?, failure_count=failure_count+1 WHERE id=?";
let mut stmt = tx.prepare(query)?;
stmt.execute(params![fail_time, id])?;
}
tx.commit()?;
let ok: Result<()> = Ok(());
ok
})
.await?
}
/// Delete verification record
async fn delete_verification(&self, id: u64) -> Result<()> {
let mut conn = self.write_pool.get()?;
tokio::task::spawn_blocking(move || {
let tx = conn.transaction()?;
{
let query = "DELETE FROM user_verification WHERE id=?;";
let mut stmt = tx.prepare(query)?;
stmt.execute(params![id])?;
}
tx.commit()?;
let ok: Result<()> = Ok(());
ok
})
.await?
}
/// Get the latest verification record for a given pubkey.
async fn get_latest_user_verification(&self, pub_key: &str) -> Result<VerificationRecord> {
let mut conn = self.read_pool.get()?;
let pub_key = pub_key.to_owned();
tokio::task::spawn_blocking(move || {
let tx = conn.transaction()?;
let query = "SELECT v.id, v.name, e.event_hash, e.created_at, v.verified_at, v.failed_at, v.failure_count FROM user_verification v LEFT JOIN event e ON e.id=v.metadata_event WHERE e.author=? ORDER BY e.created_at DESC, v.verified_at DESC, v.failed_at DESC LIMIT 1;";
let mut stmt = tx.prepare_cached(query)?;
let fields = stmt.query_row(params![hex::decode(&pub_key).ok()], |r| {
let rowid: u64 = r.get(0)?;
let rowname: String = r.get(1)?;
let eventid: Vec<u8> = r.get(2)?;
let created_at: u64 = r.get(3)?;
// create a tuple since we can't throw non-rusqlite errors in this closure
Ok((
rowid,
rowname,
eventid,
created_at,
r.get(4).ok(),
r.get(5).ok(),
r.get(6)?,
))
})?;
Ok(VerificationRecord {
rowid: fields.0,
name: Nip05Name::try_from(&fields.1[..])?,
address: pub_key,
event: hex::encode(fields.2),
event_created: fields.3,
last_success: fields.4,
last_failure: fields.5,
failure_count: fields.6,
})
}).await?
}
/// Get oldest verification before timestamp
async fn get_oldest_user_verification(&self, before: u64) -> Result<VerificationRecord> {
let mut conn = self.read_pool.get()?;
tokio::task::spawn_blocking(move || {
let tx = conn.transaction()?;
let query = "SELECT v.id, v.name, e.event_hash, e.author, e.created_at, v.verified_at, v.failed_at, v.failure_count FROM user_verification v INNER JOIN event e ON e.id=v.metadata_event WHERE (v.verified_at < ? OR v.verified_at IS NULL) AND (v.failed_at < ? OR v.failed_at IS NULL) ORDER BY v.verified_at ASC, v.failed_at ASC LIMIT 1;";
let mut stmt = tx.prepare_cached(query)?;
let fields = stmt.query_row(params![before, before], |r| {
let rowid: u64 = r.get(0)?;
let rowname: String = r.get(1)?;
let eventid: Vec<u8> = r.get(2)?;
let pubkey: Vec<u8> = r.get(3)?;
let created_at: u64 = r.get(4)?;
// create a tuple since we can't throw non-rusqlite errors in this closure
Ok((
rowid,
rowname,
eventid,
pubkey,
created_at,
r.get(5).ok(),
r.get(6).ok(),
r.get(7)?,
))
})?;
let vr = VerificationRecord {
rowid: fields.0,
name: Nip05Name::try_from(&fields.1[..])?,
address: hex::encode(fields.3),
event: hex::encode(fields.2),
event_created: fields.4,
last_success: fields.5,
last_failure: fields.6,
failure_count: fields.7,
};
Ok(vr)
}).await?
}
}
/// Decide if there is an index that should be used explicitly
fn override_index(f: &ReqFilter) -> Option<String> {
// queries for multiple kinds default to kind_index, which is
// significantly slower than kind_created_at_index.
if let Some(ks) = &f.kinds {
if f.ids.is_none() &&
ks.len() > 1 &&
f.since.is_none() &&
f.until.is_none() &&
f.tags.is_none() &&
f.authors.is_none() {
return Some("kind_created_at_index".into());
}
}
// if there is an author, it is much better to force the authors index.
if f.authors.is_some() {
if f.since.is_none() && f.until.is_none() {
if f.kinds.is_none() {
// with no use of kinds/created_at, just author
return Some("author_index".into());
}
// prefer author_kind if there are kinds
return Some("author_kind_index".into());
}
// finally, prefer author_created_at if time is provided
return Some("author_created_at_index".into());
}
None
}
/// Create a dynamic SQL subquery and params from a subscription filter (and optional explicit index used)
fn query_from_filter(f: &ReqFilter) -> (String, Vec<Box<dyn ToSql>>, Option<String>) {
// build a dynamic SQL query. all user-input is either an integer
// (sqli-safe), or a string that is filtered to only contain
// hexadecimal characters. Strings that require escaping (tag
// names/values) use parameters.
// if the filter is malformed, don't return anything.
if f.force_no_match {
let empty_query = "SELECT e.content, e.created_at FROM event e WHERE 1=0".to_owned();
// query parameters for SQLite
let empty_params: Vec<Box<dyn ToSql>> = vec![];
return (empty_query, empty_params, None);
}
// check if the index needs to be overriden
let idx_name = override_index(f);
let idx_stmt = idx_name.as_ref().map_or_else(|| "".to_owned(), |i| format!("INDEXED BY {}",i));
let mut query = format!("SELECT e.content, e.created_at FROM event e {}", idx_stmt);
// query parameters for SQLite
let mut params: Vec<Box<dyn ToSql>> = vec![];
// individual filter components (single conditions such as an author or event ID)
let mut filter_components: Vec<String> = Vec::new();
// Query for "authors", allowing prefix matches
if let Some(authvec) = &f.authors {
// take each author and convert to a hexsearch
let mut auth_searches: Vec<String> = vec![];
for auth in authvec {
match hex_range(auth) {
Some(HexSearch::Exact(ex)) => {
auth_searches.push("author=?".to_owned());
params.push(Box::new(ex));
}
Some(HexSearch::Range(lower, upper)) => {
auth_searches.push(
"(author>? AND author<?)".to_owned(),
);
params.push(Box::new(lower));
params.push(Box::new(upper));
}
Some(HexSearch::LowerOnly(lower)) => {
auth_searches.push("author>?".to_owned());
params.push(Box::new(lower));
}
None => {
info!("Could not parse hex range from author {:?}", auth);
}
}
}
if !authvec.is_empty() {
let auth_clause = format!("({})", auth_searches.join(" OR "));
filter_components.push(auth_clause);
} else {
filter_components.push("false".to_owned());
}
}
// Query for Kind
if let Some(ks) = &f.kinds {
// kind is number, no escaping needed
let str_kinds: Vec<String> = ks.iter().map(std::string::ToString::to_string).collect();
let kind_clause = format!("kind IN ({})", str_kinds.join(", "));
filter_components.push(kind_clause);
}
// Query for event, allowing prefix matches
if let Some(idvec) = &f.ids {
// take each author and convert to a hexsearch
let mut id_searches: Vec<String> = vec![];
for id in idvec {
match hex_range(id) {
Some(HexSearch::Exact(ex)) => {
id_searches.push("event_hash=?".to_owned());
params.push(Box::new(ex));
}
Some(HexSearch::Range(lower, upper)) => {
id_searches.push("(event_hash>? AND event_hash<?)".to_owned());
params.push(Box::new(lower));
params.push(Box::new(upper));
}
Some(HexSearch::LowerOnly(lower)) => {
id_searches.push("event_hash>?".to_owned());
params.push(Box::new(lower));
}
None => {
info!("Could not parse hex range from id {:?}", id);
}
}
}
if idvec.is_empty() {
// if the ids list was empty, we should never return
// any results.
filter_components.push("false".to_owned());
} else {
let id_clause = format!("({})", id_searches.join(" OR "));
filter_components.push(id_clause);
}
}
// Query for tags
if let Some(map) = &f.tags {
for (key, val) in map.iter() {
let mut str_vals: Vec<Box<dyn ToSql>> = vec![];
let mut blob_vals: Vec<Box<dyn ToSql>> = vec![];
for v in val {
if (v.len() % 2 == 0) && is_lower_hex(v) {
if let Ok(h) = hex::decode(v) {
blob_vals.push(Box::new(h));
}
} else {
str_vals.push(Box::new(v.clone()));
}
}
// create clauses with "?" params for each tag value being searched
let str_clause = format!("value IN ({})", repeat_vars(str_vals.len()));
let blob_clause = format!("value_hex IN ({})", repeat_vars(blob_vals.len()));
// find evidence of the target tag name/value existing for this event.
let tag_clause = format!(
"e.id IN (SELECT e.id FROM event e LEFT JOIN tag t on e.id=t.event_id WHERE hidden!=TRUE and (name=? AND ({} OR {})))",
str_clause, blob_clause
);
// add the tag name as the first parameter
params.push(Box::new(key.to_string()));
// add all tag values that are plain strings as params
params.append(&mut str_vals);
// add all tag values that are blobs as params
params.append(&mut blob_vals);
filter_components.push(tag_clause);
}
}
// Query for timestamp
if f.since.is_some() {
let created_clause = format!("created_at > {}", f.since.unwrap());
filter_components.push(created_clause);
}
// Query for timestamp
if f.until.is_some() {
let until_clause = format!("created_at < {}", f.until.unwrap());
filter_components.push(until_clause);
}
// never display hidden events
query.push_str(" WHERE hidden!=TRUE");
// build filter component conditions
if !filter_components.is_empty() {
query.push_str(" AND ");
query.push_str(&filter_components.join(" AND "));
}
// Apply per-filter limit to this subquery.
// The use of a LIMIT implies a DESC order, to capture only the most recent events.
if let Some(lim) = f.limit {
let _ = write!(query, " ORDER BY e.created_at DESC LIMIT {}", lim);
} else {
query.push_str(" ORDER BY e.created_at ASC");
}
(query, params, idx_name)
}
/// Create a dynamic SQL query string and params from a subscription.
fn query_from_sub(sub: &Subscription) -> (String, Vec<Box<dyn ToSql>>, Vec<String>) {
// build a dynamic SQL query for an entire subscription, based on
// SQL subqueries for filters.
let mut subqueries: Vec<String> = Vec::new();
let mut indexes = vec![];
// subquery params
let mut params: Vec<Box<dyn ToSql>> = vec![];
// for every filter in the subscription, generate a subquery
for f in &sub.filters {
let (f_subquery, mut f_params, index) = query_from_filter(f);
if let Some(i) = index {
indexes.push(i);
}
subqueries.push(f_subquery);
params.append(&mut f_params);
}
// encapsulate subqueries into select statements
let subqueries_selects: Vec<String> = subqueries
.iter()
.map(|s| format!("SELECT distinct content, created_at FROM ({})", s))
.collect();
let query: String = subqueries_selects.join(" UNION ");
(query, params,indexes)
}
/// Build a database connection pool.
/// # Panics
///
/// Will panic if the pool could not be created.
#[must_use]
pub fn build_pool(
name: &str,
settings: &Settings,
flags: OpenFlags,
min_size: u32,
max_size: u32,
wait_for_db: bool,
) -> SqlitePool {
let db_dir = &settings.database.data_directory;
let full_path = Path::new(db_dir).join(DB_FILE);
// small hack; if the database doesn't exist yet, that means the
// writer thread hasn't finished. Give it a chance to work. This
// is only an issue with the first time we run.
if !settings.database.in_memory {
while !full_path.exists() && wait_for_db {
debug!("Database reader pool is waiting on the database to be created...");
thread::sleep(Duration::from_millis(500));
}
}
let manager = if settings.database.in_memory {
SqliteConnectionManager::memory()
.with_flags(flags)
.with_init(|c| c.execute_batch(STARTUP_SQL))
} else {
SqliteConnectionManager::file(&full_path)
.with_flags(flags)
.with_init(|c| c.execute_batch(STARTUP_SQL))
};
let pool: SqlitePool = r2d2::Pool::builder()
.test_on_check_out(true) // no noticeable performance hit
.min_idle(Some(min_size))
.max_size(max_size)
.max_lifetime(Some(Duration::from_secs(30)))
.build(manager)
.unwrap();
info!(
"Built a connection pool {:?} (min={}, max={})",
name, min_size, max_size
);
pool
}
/// Perform database WAL checkpoint on a regular basis
pub async fn db_checkpoint_task(pool: SqlitePool, frequency: Duration, checkpoint_in_progress: Arc<Mutex<u64>>) -> Result<()> {
tokio::task::spawn(async move {
// WAL size in pages.
let mut current_wal_size = 0;
// WAL threshold for more aggressive checkpointing (10,000 pages, or about 40MB)
let wal_threshold = 1000*10;
// default threshold for the busy timer
let busy_wait_default = Duration::from_secs(1);
// if the WAL file is getting too big, switch to this
let busy_wait_default_long = Duration::from_secs(10);
loop {
tokio::select! {
_ = tokio::time::sleep(frequency) => {
if let Ok(mut conn) = pool.get() {
let mut _guard:Option<MutexGuard<u64>> = None;
// the busy timer will block writers, so don't set
// this any higher than you want max latency for event
// writes.
if current_wal_size <= wal_threshold {
conn.busy_timeout(busy_wait_default).ok();
} else {
// if the wal size has exceeded a threshold, increase the busy timeout.
conn.busy_timeout(busy_wait_default_long).ok();
// take a lock that will prevent new readers.
info!("blocking new readers to perform wal_checkpoint");
_guard = Some(checkpoint_in_progress.lock().await);
}
debug!("running wal_checkpoint(TRUNCATE)");
if let Ok(new_size) = checkpoint_db(&mut conn) {
current_wal_size = new_size;
}
}
}
};
}
});
Ok(())
}
#[derive(Debug)]
enum SqliteStatus {
Ok,
Busy,
Error,
Other(u64),
}
/// Checkpoint/Truncate WAL. Returns the number of WAL pages remaining.
pub fn checkpoint_db(conn: &mut PooledConnection) -> Result<usize> {
let query = "PRAGMA wal_checkpoint(TRUNCATE);";
let start = Instant::now();
let (cp_result, wal_size, _frames_checkpointed) = conn.query_row(query, [], |row| {
let checkpoint_result: u64 = row.get(0)?;
let wal_size: u64 = row.get(1)?;
let frames_checkpointed: u64 = row.get(2)?;
Ok((checkpoint_result, wal_size, frames_checkpointed))
})?;
let result = match cp_result {
0 => SqliteStatus::Ok,
1 => SqliteStatus::Busy,
2 => SqliteStatus::Error,
x => SqliteStatus::Other(x),
};
info!(
"checkpoint ran in {:?} (result: {:?}, WAL size: {})",
start.elapsed(),
result,
wal_size
);
Ok(wal_size as usize)
}
/// Produce a arbitrary list of '?' parameters.
fn repeat_vars(count: usize) -> String {
if count == 0 {
return "".to_owned();
}
let mut s = "?,".repeat(count);
// Remove trailing comma
s.pop();
s
}
/// Display database pool stats every 1 minute
pub async fn monitor_pool(name: &str, pool: SqlitePool) {
let sleep_dur = Duration::from_secs(60);
loop {
log_pool_stats(name, &pool);
tokio::time::sleep(sleep_dur).await;
}
}
/// Log pool stats
fn log_pool_stats(name: &str, pool: &SqlitePool) {
let state: r2d2::State = pool.state();
let in_use_cxns = state.connections - state.idle_connections;
debug!(
"DB pool {:?} usage (in_use: {}, available: {}, max: {})",
name,
in_use_cxns,
state.connections,
pool.max_size()
);
}
/// Check if the pool is fully utilized
fn _pool_at_capacity(pool: &SqlitePool) -> bool {
let state: r2d2::State = pool.state();
state.idle_connections == 0
}

View File

@ -113,7 +113,7 @@ pub fn db_tag_count(conn: &mut Connection) -> Result<usize> {
Ok(count) Ok(count)
} }
fn mig_init(conn: &mut PooledConnection) -> Result<usize> { fn mig_init(conn: &mut PooledConnection) -> usize {
match conn.execute_batch(INIT_SQL) { match conn.execute_batch(INIT_SQL) {
Ok(()) => { Ok(()) => {
info!( info!(
@ -126,11 +126,11 @@ fn mig_init(conn: &mut PooledConnection) -> Result<usize> {
panic!("database could not be initialized"); panic!("database could not be initialized");
} }
} }
Ok(DB_VERSION) DB_VERSION
} }
/// Upgrade DB to latest version, and execute pragma settings /// Upgrade DB to latest version, and execute pragma settings
pub fn upgrade_db(conn: &mut PooledConnection) -> Result<()> { pub fn upgrade_db(conn: &mut PooledConnection) -> Result<usize> {
// check the version. // check the version.
let mut curr_version = curr_db_version(conn)?; let mut curr_version = curr_db_version(conn)?;
info!("DB version = {:?}", curr_version); info!("DB version = {:?}", curr_version);
@ -141,11 +141,11 @@ pub fn upgrade_db(conn: &mut PooledConnection) -> Result<()> {
); );
debug!( debug!(
"SQLite max table/blob/text length: {} MB", "SQLite max table/blob/text length: {} MB",
(conn.limit(Limit::SQLITE_LIMIT_LENGTH) as f64 / (1024 * 1024) as f64).floor() (f64::from(conn.limit(Limit::SQLITE_LIMIT_LENGTH)) / f64::from(1024 * 1024)).floor()
); );
debug!( debug!(
"SQLite max SQL length: {} MB", "SQLite max SQL length: {} MB",
(conn.limit(Limit::SQLITE_LIMIT_SQL_LENGTH) as f64 / (1024 * 1024) as f64).floor() (f64::from(conn.limit(Limit::SQLITE_LIMIT_SQL_LENGTH)) / f64::from(1024 * 1024)).floor()
); );
match curr_version.cmp(&DB_VERSION) { match curr_version.cmp(&DB_VERSION) {
@ -153,7 +153,7 @@ pub fn upgrade_db(conn: &mut PooledConnection) -> Result<()> {
Ordering::Less => { Ordering::Less => {
// initialize from scratch // initialize from scratch
if curr_version == 0 { if curr_version == 0 {
curr_version = mig_init(conn)?; curr_version = mig_init(conn);
} }
// for initialized but out-of-date schemas, proceed to // for initialized but out-of-date schemas, proceed to
// upgrade sequentially until we are current. // upgrade sequentially until we are current.
@ -223,7 +223,7 @@ pub fn upgrade_db(conn: &mut PooledConnection) -> Result<()> {
// Setup PRAGMA // Setup PRAGMA
conn.execute_batch(STARTUP_SQL)?; conn.execute_batch(STARTUP_SQL)?;
debug!("SQLite PRAGMA startup completed"); debug!("SQLite PRAGMA startup completed");
Ok(()) Ok(DB_VERSION)
} }
pub fn rebuild_tags(conn: &mut PooledConnection) -> Result<()> { pub fn rebuild_tags(conn: &mut PooledConnection) -> Result<()> {

View File

@ -3,6 +3,7 @@ use crate::close::Close;
use crate::close::CloseCmd; use crate::close::CloseCmd;
use crate::config::{Settings, VerifiedUsersMode}; use crate::config::{Settings, VerifiedUsersMode};
use crate::conn; use crate::conn;
use crate::repo::NostrRepo;
use crate::db; use crate::db;
use crate::db::SubmittedEvent; use crate::db::SubmittedEvent;
use crate::error::{Error, Result}; use crate::error::{Error, Result};
@ -22,10 +23,8 @@ use hyper::upgrade::Upgraded;
use hyper::{ use hyper::{
header, server::conn::AddrStream, upgrade, Body, Request, Response, Server, StatusCode, header, server::conn::AddrStream, upgrade, Body, Request, Response, Server, StatusCode,
}; };
use rusqlite::OpenFlags;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::json; use serde_json::json;
use tokio::sync::Mutex;
use std::collections::HashMap; use std::collections::HashMap;
use std::convert::Infallible; use std::convert::Infallible;
use std::net::SocketAddr; use std::net::SocketAddr;
@ -40,23 +39,22 @@ use tokio::sync::broadcast::{self, Receiver, Sender};
use tokio::sync::mpsc; use tokio::sync::mpsc;
use tokio::sync::oneshot; use tokio::sync::oneshot;
use tokio_tungstenite::WebSocketStream; use tokio_tungstenite::WebSocketStream;
use tracing::*; use tracing::{debug, error, info, trace, warn};
use tungstenite::error::CapacityError::MessageTooLong; use tungstenite::error::CapacityError::MessageTooLong;
use tungstenite::error::Error as WsError; use tungstenite::error::Error as WsError;
use tungstenite::handshake; use tungstenite::handshake;
use tungstenite::protocol::Message; use tungstenite::protocol::Message;
use tungstenite::protocol::WebSocketConfig; use tungstenite::protocol::WebSocketConfig;
/// Handle arbitrary HTTP requests, including for WebSocket upgrades. /// Handle arbitrary HTTP requests, including for `WebSocket` upgrades.
async fn handle_web_request( async fn handle_web_request(
mut request: Request<Body>, mut request: Request<Body>,
pool: db::SqlitePool, repo: Arc<dyn NostrRepo>,
settings: Settings, settings: Settings,
remote_addr: SocketAddr, remote_addr: SocketAddr,
broadcast: Sender<Event>, broadcast: Sender<Event>,
event_tx: tokio::sync::mpsc::Sender<SubmittedEvent>, event_tx: tokio::sync::mpsc::Sender<SubmittedEvent>,
shutdown: Receiver<()>, shutdown: Receiver<()>,
safe_to_read: Arc<Mutex<u64>>,
) -> Result<Response<Body>, Infallible> { ) -> Result<Response<Body>, Infallible> {
match ( match (
request.uri().path(), request.uri().path(),
@ -111,14 +109,13 @@ async fn handle_web_request(
}; };
// spawn a nostr server with our websocket // spawn a nostr server with our websocket
tokio::spawn(nostr_server( tokio::spawn(nostr_server(
pool, repo,
client_info, client_info,
settings, settings,
ws_stream, ws_stream,
broadcast, broadcast,
event_tx, event_tx,
shutdown, shutdown,
safe_to_read,
)); ));
} }
// todo: trace, don't print... // todo: trace, don't print...
@ -184,7 +181,7 @@ async fn handle_web_request(
fn get_header_string(header: &str, headers: &HeaderMap) -> Option<String> { fn get_header_string(header: &str, headers: &HeaderMap) -> Option<String> {
headers headers
.get(header) .get(header)
.and_then(|x| x.to_str().ok().map(|x| x.to_string())) .and_then(|x| x.to_str().ok().map(std::string::ToString::to_string))
} }
// return on a control-c or internally requested shutdown signal // return on a control-c or internally requested shutdown signal
@ -211,7 +208,7 @@ async fn ctrl_c_or_signal(mut shutdown_signal: Receiver<()>) {
} }
/// Start running a Nostr relay server. /// Start running a Nostr relay server.
pub fn start_server(settings: Settings, shutdown_rx: MpscReceiver<()>) -> Result<(), Error> { pub fn start_server(settings: &Settings, shutdown_rx: MpscReceiver<()>) -> Result<(), Error> {
trace!("Config: {:?}", settings); trace!("Config: {:?}", settings);
// do some config validation. // do some config validation.
if !Path::new(&settings.database.data_directory).is_dir() { if !Path::new(&settings.database.data_directory).is_dir() {
@ -274,8 +271,6 @@ pub fn start_server(settings: Settings, shutdown_rx: MpscReceiver<()>) -> Result
let broadcast_buffer_limit = settings.limits.broadcast_buffer; let broadcast_buffer_limit = settings.limits.broadcast_buffer;
let persist_buffer_limit = settings.limits.event_persist_buffer; let persist_buffer_limit = settings.limits.event_persist_buffer;
let verified_users_active = settings.verified_users.is_active(); let verified_users_active = settings.verified_users.is_active();
let db_min_conn = settings.database.min_conn;
let db_max_conn = settings.database.max_conn;
let settings = settings.clone(); let settings = settings.clone();
info!("listening on: {}", socket_addr); info!("listening on: {}", socket_addr);
// all client-submitted valid events are broadcast to every // all client-submitted valid events are broadcast to every
@ -298,23 +293,26 @@ pub fn start_server(settings: Settings, shutdown_rx: MpscReceiver<()>) -> Result
// overwhelming this will drop events and won't register // overwhelming this will drop events and won't register
// metadata events. // metadata events.
let (metadata_tx, metadata_rx) = broadcast::channel::<Event>(4096); let (metadata_tx, metadata_rx) = broadcast::channel::<Event>(4096);
// start the database writer thread. Give it a channel for // build a repository for events
let repo = db::build_repo(&settings).await;
// start the database writer task. Give it a channel for
// writing events, and for publishing events that have been // writing events, and for publishing events that have been
// written (to all connected clients). // written (to all connected clients).
tokio::task::spawn(
db::db_writer( db::db_writer(
repo.clone(),
settings.clone(), settings.clone(),
event_rx, event_rx,
bcast_tx.clone(), bcast_tx.clone(),
metadata_tx.clone(), metadata_tx.clone(),
shutdown_listen, shutdown_listen,
) ));
.await;
info!("db writer created"); info!("db writer created");
// create a nip-05 verifier thread; if enabled. // create a nip-05 verifier thread; if enabled.
if settings.verified_users.mode != VerifiedUsersMode::Disabled { if settings.verified_users.mode != VerifiedUsersMode::Disabled {
let verifier_opt = let verifier_opt =
nip05::Verifier::new(metadata_rx, bcast_tx.clone(), settings.clone()); nip05::Verifier::new(repo.clone(), metadata_rx, bcast_tx.clone(), settings.clone());
if let Ok(mut v) = verifier_opt { if let Ok(mut v) = verifier_opt {
if verified_users_active { if verified_users_active {
tokio::task::spawn(async move { tokio::task::spawn(async move {
@ -324,35 +322,19 @@ pub fn start_server(settings: Settings, shutdown_rx: MpscReceiver<()>) -> Result
} }
} }
} }
// build a connection pool for DB maintenance
let maintenance_pool = db::build_pool(
"maintenance writer",
&settings,
OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE,
1,
2,
false,
);
// Create a mutex that will block readers, so that a
// checkpoint can be performed quickly.
let safe_to_read = Arc::new(Mutex::new(0));
db::db_optimize_task(maintenance_pool.clone()).await;
db::db_checkpoint_task(maintenance_pool, safe_to_read.clone()).await;
// listen for (external to tokio) shutdown request // listen for (external to tokio) shutdown request
let controlled_shutdown = invoke_shutdown.clone(); let controlled_shutdown = invoke_shutdown.clone();
tokio::spawn(async move { tokio::spawn(async move {
info!("control message listener started"); info!("control message listener started");
// we only have good "shutdown" messages propagation from this-> controlled shutdown. Not from controlled_shutdown-> this. Which means we have a task that is stuck waiting on a sync receive. recv is blocking, and this is async.
match shutdown_rx.recv() { match shutdown_rx.recv() {
Ok(()) => { Ok(()) => {
info!("control message requesting shutdown"); info!("control message requesting shutdown");
controlled_shutdown.send(()).ok(); controlled_shutdown.send(()).ok();
} },
Err(std::sync::mpsc::RecvError) => { Err(std::sync::mpsc::RecvError) => {
// FIXME: spurious error on startup? trace!("shutdown requestor is disconnected (this is normal)");
debug!("shutdown requestor is disconnected");
} }
}; };
}); });
@ -366,41 +348,30 @@ pub fn start_server(settings: Settings, shutdown_rx: MpscReceiver<()>) -> Result
info!("shutting down due to SIGINT (main)"); info!("shutting down due to SIGINT (main)");
ctrl_c_shutdown.send(()).ok(); ctrl_c_shutdown.send(()).ok();
}); });
// build a connection pool for sqlite connections
let pool = db::build_pool(
"client query",
&settings,
rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY,
db_min_conn,
db_max_conn,
true,
);
// spawn a task to check the pool size. // spawn a task to check the pool size.
let pool_monitor = pool.clone(); //let pool_monitor = pool.clone();
tokio::spawn(async move {db::monitor_pool("reader", pool_monitor).await;}); //tokio::spawn(async move {db::monitor_pool("reader", pool_monitor).await;});
// A `Service` is needed for every connection, so this // A `Service` is needed for every connection, so this
// creates one from our `handle_request` function. // creates one from our `handle_request` function.
let make_svc = make_service_fn(|conn: &AddrStream| { let make_svc = make_service_fn(|conn: &AddrStream| {
let svc_pool = pool.clone(); let repo = repo.clone();
let remote_addr = conn.remote_addr(); let remote_addr = conn.remote_addr();
let bcast = bcast_tx.clone(); let bcast = bcast_tx.clone();
let event = event_tx.clone(); let event = event_tx.clone();
let stop = invoke_shutdown.clone(); let stop = invoke_shutdown.clone();
let settings = settings.clone(); let settings = settings.clone();
let safe_to_read = safe_to_read.clone();
async move { async move {
// service_fn converts our function into a `Service` // service_fn converts our function into a `Service`
Ok::<_, Infallible>(service_fn(move |request: Request<Body>| { Ok::<_, Infallible>(service_fn(move |request: Request<Body>| {
handle_web_request( handle_web_request(
request, request,
svc_pool.clone(), repo.clone(),
settings.clone(), settings.clone(),
remote_addr, remote_addr,
bcast.clone(), bcast.clone(),
event.clone(), event.clone(),
stop.subscribe(), stop.subscribe(),
safe_to_read.clone(),
) )
})) }))
} }
@ -428,9 +399,9 @@ pub enum NostrMessage {
CloseMsg(CloseCmd), CloseMsg(CloseCmd),
} }
/// Convert Message to NostrMessage /// Convert Message to `NostrMessage`
fn convert_to_msg(msg: String, max_bytes: Option<usize>) -> Result<NostrMessage> { fn convert_to_msg(msg: &str, max_bytes: Option<usize>) -> Result<NostrMessage> {
let parsed_res: Result<NostrMessage> = serde_json::from_str(&msg).map_err(|e| e.into()); let parsed_res: Result<NostrMessage> = serde_json::from_str(msg).map_err(std::convert::Into::into);
match parsed_res { match parsed_res {
Ok(m) => { Ok(m) => {
if let NostrMessage::SubMsg(_) = m { if let NostrMessage::SubMsg(_) = m {
@ -455,8 +426,8 @@ fn convert_to_msg(msg: String, max_bytes: Option<usize>) -> Result<NostrMessage>
} }
} }
/// Turn a string into a NOTICE message ready to send over a WebSocket /// Turn a string into a NOTICE message ready to send over a `WebSocket`
fn make_notice_message(notice: Notice) -> Message { fn make_notice_message(notice: &Notice) -> Message {
let json = match notice { let json = match notice {
Notice::Message(ref msg) => json!(["NOTICE", msg]), Notice::Message(ref msg) => json!(["NOTICE", msg]),
Notice::EventResult(ref res) => json!(["OK", res.id, res.status.to_bool(), res.msg]), Notice::EventResult(ref res) => json!(["OK", res.id, res.status.to_bool(), res.msg]),
@ -474,14 +445,13 @@ struct ClientInfo {
/// Handle new client connections. This runs through an event loop /// Handle new client connections. This runs through an event loop
/// for all client communication. /// for all client communication.
async fn nostr_server( async fn nostr_server(
pool: db::SqlitePool, repo: Arc<dyn NostrRepo>,
client_info: ClientInfo, client_info: ClientInfo,
settings: Settings, settings: Settings,
mut ws_stream: WebSocketStream<Upgraded>, mut ws_stream: WebSocketStream<Upgraded>,
broadcast: Sender<Event>, broadcast: Sender<Event>,
event_tx: mpsc::Sender<SubmittedEvent>, event_tx: mpsc::Sender<SubmittedEvent>,
mut shutdown: Receiver<()>, mut shutdown: Receiver<()>,
safe_to_read: Arc<Mutex<u64>>,
) { ) {
// the time this websocket nostr server started // the time this websocket nostr server started
let orig_start = Instant::now(); let orig_start = Instant::now();
@ -559,7 +529,7 @@ async fn nostr_server(
ws_stream.send(Message::Ping(Vec::new())).await.ok(); ws_stream.send(Message::Ping(Vec::new())).await.ok();
}, },
Some(notice_msg) = notice_rx.recv() => { Some(notice_msg) = notice_rx.recv() => {
ws_stream.send(make_notice_message(notice_msg)).await.ok(); ws_stream.send(make_notice_message(&notice_msg)).await.ok();
}, },
Some(query_result) = query_rx.recv() => { Some(query_result) = query_rx.recv() => {
// database informed us of a query result we asked for // database informed us of a query result we asked for
@ -603,11 +573,11 @@ async fn nostr_server(
// Consume text messages from the client, parse into Nostr messages. // Consume text messages from the client, parse into Nostr messages.
let nostr_msg = match ws_next { let nostr_msg = match ws_next {
Some(Ok(Message::Text(m))) => { Some(Ok(Message::Text(m))) => {
convert_to_msg(m,settings.limits.max_event_bytes) convert_to_msg(&m,settings.limits.max_event_bytes)
}, },
Some(Ok(Message::Binary(_))) => { Some(Ok(Message::Binary(_))) => {
ws_stream.send( ws_stream.send(
make_notice_message(Notice::message("binary messages are not accepted".into()))).await.ok(); make_notice_message(&Notice::message("binary messages are not accepted".into()))).await.ok();
continue; continue;
}, },
Some(Ok(Message::Ping(_) | Message::Pong(_))) => { Some(Ok(Message::Ping(_) | Message::Pong(_))) => {
@ -617,7 +587,7 @@ async fn nostr_server(
}, },
Some(Err(WsError::Capacity(MessageTooLong{size, max_size}))) => { Some(Err(WsError::Capacity(MessageTooLong{size, max_size}))) => {
ws_stream.send( ws_stream.send(
make_notice_message(Notice::message(format!("message too large ({} > {})",size, max_size)))).await.ok(); make_notice_message(&Notice::message(format!("message too large ({} > {})",size, max_size)))).await.ok();
continue; continue;
}, },
None | None |
@ -662,13 +632,13 @@ async fn nostr_server(
if let Some(fut_sec) = settings.options.reject_future_seconds { if let Some(fut_sec) = settings.options.reject_future_seconds {
let msg = format!("The event created_at field is out of the acceptable range (+{}sec) for this relay.",fut_sec); let msg = format!("The event created_at field is out of the acceptable range (+{}sec) for this relay.",fut_sec);
let notice = Notice::invalid(e.id, &msg); let notice = Notice::invalid(e.id, &msg);
ws_stream.send(make_notice_message(notice)).await.ok(); ws_stream.send(make_notice_message(&notice)).await.ok();
} }
} }
}, },
Err(e) => { Err(e) => {
info!("client sent an invalid event (cid: {})", cid); info!("client sent an invalid event (cid: {})", cid);
ws_stream.send(make_notice_message(Notice::invalid(evid, &format!("{}", e)))).await.ok(); ws_stream.send(make_notice_message(&Notice::invalid(evid, &format!("{}", e)))).await.ok();
} }
} }
}, },
@ -680,7 +650,9 @@ async fn nostr_server(
// * making a channel to cancel to request later // * making a channel to cancel to request later
// * sending a request for a SQL query // * sending a request for a SQL query
// Do nothing if the sub already exists. // Do nothing if the sub already exists.
if !conn.has_subscription(&s) { if conn.has_subscription(&s) {
info!("client sent duplicate subscription, ignoring (cid: {}, sub: {:?})", cid, s.id);
} else {
if let Some(ref lim) = sub_lim_opt { if let Some(ref lim) = sub_lim_opt {
lim.until_ready_with_jitter(jitter).await; lim.until_ready_with_jitter(jitter).await;
} }
@ -688,21 +660,19 @@ async fn nostr_server(
match conn.subscribe(s.clone()) { match conn.subscribe(s.clone()) {
Ok(()) => { Ok(()) => {
// when we insert, if there was a previous query running with the same name, cancel it. // when we insert, if there was a previous query running with the same name, cancel it.
if let Some(previous_query) = running_queries.insert(s.id.to_owned(), abandon_query_tx) { if let Some(previous_query) = running_queries.insert(s.id.clone(), abandon_query_tx) {
previous_query.send(()).ok(); previous_query.send(()).ok();
} }
if s.needs_historical_events() { if s.needs_historical_events() {
// start a database query. this spawns a blocking database query on a worker thread. // start a database query. this spawns a blocking database query on a worker thread.
db::db_query(s, cid.to_owned(), pool.clone(), query_tx.clone(), abandon_query_rx,safe_to_read.clone()).await; repo.query_subscription(s, cid.clone(), query_tx.clone(), abandon_query_rx).await.ok();
} }
}, },
Err(e) => { Err(e) => {
info!("Subscription error: {} (cid: {}, sub: {:?})", e, cid, s.id); info!("Subscription error: {} (cid: {}, sub: {:?})", e, cid, s.id);
ws_stream.send(make_notice_message(Notice::message(format!("Subscription error: {}", e)))).await.ok(); ws_stream.send(make_notice_message(&Notice::message(format!("Subscription error: {}", e)))).await.ok();
} }
} }
} else {
info!("client sent duplicate subscription, ignoring (cid: {}, sub: {:?})", cid, s.id);
} }
}, },
Ok(NostrMessage::CloseMsg(cc)) => { Ok(NostrMessage::CloseMsg(cc)) => {
@ -720,7 +690,7 @@ async fn nostr_server(
conn.unsubscribe(&c); conn.unsubscribe(&c);
} else { } else {
info!("invalid command ignored"); info!("invalid command ignored");
ws_stream.send(make_notice_message(Notice::message("could not parse command".into()))).await.ok(); ws_stream.send(make_notice_message(&Notice::message("could not parse command".into()))).await.ok();
} }
}, },
Err(Error::ConnError) => { Err(Error::ConnError) => {
@ -729,11 +699,11 @@ async fn nostr_server(
} }
Err(Error::EventMaxLengthError(s)) => { Err(Error::EventMaxLengthError(s)) => {
info!("client sent event larger ({} bytes) than max size (cid: {})", s, cid); info!("client sent event larger ({} bytes) than max size (cid: {})", s, cid);
ws_stream.send(make_notice_message(Notice::message("event exceeded max size".into()))).await.ok(); ws_stream.send(make_notice_message(&Notice::message("event exceeded max size".into()))).await.ok();
}, },
Err(Error::ProtoParseError) => { Err(Error::ProtoParseError) => {
info!("client sent event that could not be parsed (cid: {})", cid); info!("client sent event that could not be parsed (cid: {})", cid);
ws_stream.send(make_notice_message(Notice::message("could not parse command".into()))).await.ok(); ws_stream.send(make_notice_message(&Notice::message("could not parse command".into()))).await.ok();
}, },
Err(e) => { Err(e) => {
info!("got non-fatal error from client (cid: {}, error: {:?}", cid, e); info!("got non-fatal error from client (cid: {}, error: {:?}", cid, e);

View File

@ -68,7 +68,7 @@ impl<'de> Deserialize<'de> for ReqFilter {
let empty_string = "".into(); let empty_string = "".into();
let mut ts = None; let mut ts = None;
// iterate through each key, and assign values that exist // iterate through each key, and assign values that exist
for (key, val) in filter.into_iter() { for (key, val) in filter {
// ids // ids
if key == "ids" { if key == "ids" {
let raw_ids: Option<Vec<String>>= Deserialize::deserialize(val).ok(); let raw_ids: Option<Vec<String>>= Deserialize::deserialize(val).ok();
@ -107,7 +107,7 @@ impl<'de> Deserialize<'de> for ReqFilter {
if let Some(m) = ts.as_mut() { if let Some(m) = ts.as_mut() {
let tag_vals: Option<Vec<String>> = Deserialize::deserialize(val).ok(); let tag_vals: Option<Vec<String>> = Deserialize::deserialize(val).ok();
if let Some(v) = tag_vals { if let Some(v) = tag_vals {
let hs = HashSet::from_iter(v.into_iter()); let hs = v.into_iter().collect::<HashSet<_>>();
m.insert(tag_search.to_owned(), hs); m.insert(tag_search.to_owned(), hs);
} }
}; };
@ -197,20 +197,20 @@ impl<'de> Deserialize<'de> for Subscription {
impl Subscription { impl Subscription {
/// Get a copy of the subscription identifier. /// Get a copy of the subscription identifier.
pub fn get_id(&self) -> String { #[must_use] pub fn get_id(&self) -> String {
self.id.clone() self.id.clone()
} }
/// Determine if any filter is requesting historical (database) /// Determine if any filter is requesting historical (database)
/// queries. If every filter has limit:0, we do not need to query the DB. /// queries. If every filter has limit:0, we do not need to query the DB.
pub fn needs_historical_events(&self) -> bool { #[must_use] pub fn needs_historical_events(&self) -> bool {
self.filters.iter().any(|f| f.limit!=Some(0)) self.filters.iter().any(|f| f.limit!=Some(0))
} }
/// Determine if this subscription matches a given [`Event`]. Any /// Determine if this subscription matches a given [`Event`]. Any
/// individual filter match is sufficient. /// individual filter match is sufficient.
pub fn interested_in_event(&self, event: &Event) -> bool { #[must_use] pub fn interested_in_event(&self, event: &Event) -> bool {
for f in self.filters.iter() { for f in &self.filters {
if f.interested_in_event(event) { if f.interested_in_event(event) {
return true; return true;
} }
@ -233,23 +233,20 @@ impl ReqFilter {
fn ids_match(&self, event: &Event) -> bool { fn ids_match(&self, event: &Event) -> bool {
self.ids self.ids
.as_ref() .as_ref()
.map(|vs| prefix_match(vs, &event.id)) .map_or(true, |vs| prefix_match(vs, &event.id))
.unwrap_or(true)
} }
fn authors_match(&self, event: &Event) -> bool { fn authors_match(&self, event: &Event) -> bool {
self.authors self.authors
.as_ref() .as_ref()
.map(|vs| prefix_match(vs, &event.pubkey)) .map_or(true, |vs| prefix_match(vs, &event.pubkey))
.unwrap_or(true)
} }
fn delegated_authors_match(&self, event: &Event) -> bool { fn delegated_authors_match(&self, event: &Event) -> bool {
if let Some(delegated_pubkey) = &event.delegated_by { if let Some(delegated_pubkey) = &event.delegated_by {
self.authors self.authors
.as_ref() .as_ref()
.map(|vs| prefix_match(vs, delegated_pubkey)) .map_or(true, |vs| prefix_match(vs, delegated_pubkey))
.unwrap_or(true)
} else { } else {
false false
} }
@ -275,16 +272,15 @@ impl ReqFilter {
fn kind_match(&self, kind: u64) -> bool { fn kind_match(&self, kind: u64) -> bool {
self.kinds self.kinds
.as_ref() .as_ref()
.map(|ks| ks.contains(&kind)) .map_or(true, |ks| ks.contains(&kind))
.unwrap_or(true)
} }
/// Determine if all populated fields in this filter match the provided event. /// Determine if all populated fields in this filter match the provided event.
pub fn interested_in_event(&self, event: &Event) -> bool { #[must_use] pub fn interested_in_event(&self, event: &Event) -> bool {
// self.id.as_ref().map(|v| v == &event.id).unwrap_or(true) // self.id.as_ref().map(|v| v == &event.id).unwrap_or(true)
self.ids_match(event) self.ids_match(event)
&& self.since.map(|t| event.created_at > t).unwrap_or(true) && self.since.map_or(true, |t| event.created_at > t)
&& self.until.map(|t| event.created_at < t).unwrap_or(true) && self.until.map_or(true, |t| event.created_at < t)
&& self.kind_match(event.kind) && self.kind_match(event.kind)
&& (self.authors_match(event) || self.delegated_authors_match(event)) && (self.authors_match(event) || self.delegated_authors_match(event))
&& self.tag_match(event) && self.tag_match(event)

View File

@ -2,7 +2,7 @@
use std::time::SystemTime; use std::time::SystemTime;
/// Seconds since 1970. /// Seconds since 1970.
pub fn unix_time() -> u64 { #[must_use] pub fn unix_time() -> u64 {
SystemTime::now() SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH) .duration_since(SystemTime::UNIX_EPOCH)
.map(|x| x.as_secs()) .map(|x| x.as_secs())
@ -10,12 +10,12 @@ pub fn unix_time() -> u64 {
} }
/// Check if a string contains only hex characters. /// Check if a string contains only hex characters.
pub fn is_hex(s: &str) -> bool { #[must_use] pub fn is_hex(s: &str) -> bool {
s.chars().all(|x| char::is_ascii_hexdigit(&x)) s.chars().all(|x| char::is_ascii_hexdigit(&x))
} }
/// Check if a string contains only lower-case hex chars. /// Check if a string contains only lower-case hex chars.
pub fn is_lower_hex(s: &str) -> bool { #[must_use] pub fn is_lower_hex(s: &str) -> bool {
s.chars().all(|x| { s.chars().all(|x| {
(char::is_ascii_lowercase(&x) || char::is_ascii_digit(&x)) && char::is_ascii_hexdigit(&x) (char::is_ascii_lowercase(&x) || char::is_ascii_digit(&x)) && char::is_ascii_hexdigit(&x)
}) })