diff --git a/.gitignore b/.gitignore index ea8c4bf..1596390 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ /target +nostr.db diff --git a/Cargo.lock b/Cargo.lock index 481e2e4..1c21c01 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -310,6 +310,12 @@ dependencies = [ "libc", ] +[[package]] +name = "hex" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" + [[package]] name = "http" version = "0.2.5" @@ -435,6 +441,7 @@ dependencies = [ "env_logger", "futures", "futures-util", + "hex", "log", "rusqlite", "secp256k1", diff --git a/Cargo.toml b/Cargo.toml index 0015e0c..fce27f0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,4 +18,5 @@ bitcoin_hashes = { version = "0.9.7", features = ["serde"] } secp256k1 = { version = "0.20.3", features = ["rand", "rand-std", "serde", "bitcoin_hashes"] } serde = { version = "1.0.130", features = ["derive"] } serde_json = "1.0.72" +hex = "0.4.3" rusqlite = "0.26.1" diff --git a/src/close.rs b/src/close.rs index 516df50..5392622 100644 --- a/src/close.rs +++ b/src/close.rs @@ -9,7 +9,7 @@ pub struct CloseCmd { #[derive(Serialize, Deserialize, PartialEq, Debug, Clone)] pub struct Close { - id: String, + pub id: String, } impl From for Result { diff --git a/src/db.rs b/src/db.rs new file mode 100644 index 0000000..e9f73fb --- /dev/null +++ b/src/db.rs @@ -0,0 +1,274 @@ +use crate::error::Result; +use crate::event::Event; +use crate::subscription::Subscription; +use hex; +use log::*; +use rusqlite::params; +use rusqlite::Connection; +use rusqlite::OpenFlags; +use std::path::Path; +use tokio::task; + +const DB_FILE: &str = "nostr.db"; + +// schema +const INIT_SQL: &str = r##" +PRAGMA encoding = "UTF-8"; +PRAGMA journal_mode=WAL; +PRAGMA main.synchronous=NORMAL; +PRAGMA foreign_keys = ON; +PRAGMA application_id = 1654008667; +PRAGMA user_version = 1; +pragma mmap_size = 536870912; -- 512MB of mmap +CREATE TABLE IF NOT EXISTS event ( +id INTEGER PRIMARY KEY, +event_hash BLOB NOT NULL, -- 4-byte hash +first_seen INTEGER NOT NULL, -- when the event was first seen (not authored!) (seconds since 1970) +created_at INTEGER NOT NULL, -- when the event was authored +author BLOB NOT NULL, -- author pubkey +kind INTEGER NOT NULL, -- event kind +content TEXT NOT NULL -- serialized json of event object +); +CREATE UNIQUE INDEX IF NOT EXISTS event_hash_index ON event(event_hash); +CREATE INDEX IF NOT EXISTS created_at_index ON event(created_at); +CREATE INDEX IF NOT EXISTS author_index ON event(author); +CREATE INDEX IF NOT EXISTS kind_index ON event(kind); +CREATE TABLE IF NOT EXISTS event_ref ( +id INTEGER PRIMARY KEY, +event_id INTEGER NOT NULL, -- an event ID that contains an #e tag. +referenced_event BLOB NOT NULL, -- the event that is referenced. +FOREIGN KEY(event_id) REFERENCES event(id) ON UPDATE CASCADE ON DELETE CASCADE +); +CREATE INDEX IF NOT EXISTS event_ref_index ON event_ref(referenced_event); +CREATE TABLE IF NOT EXISTS pubkey_ref ( +id INTEGER PRIMARY KEY, +event_id INTEGER NOT NULL, -- an event ID that contains an #p tag. +referenced_pubkey BLOB NOT NULL, -- the pubkey that is referenced. +FOREIGN KEY(event_id) REFERENCES event(id) ON UPDATE RESTRICT ON DELETE CASCADE +); +CREATE INDEX IF NOT EXISTS pubkey_ref_index ON pubkey_ref(referenced_pubkey); +"##; + +/// Spawn a database writer that persists events to the SQLite store. +pub async fn db_writer( + mut event_rx: tokio::sync::mpsc::Receiver, +) -> tokio::task::JoinHandle> { + task::spawn_blocking(move || { + let mut conn = Connection::open_with_flags( + Path::new(DB_FILE), + OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE, + )?; + info!("Opened database for writing"); + // TODO: determine if we need to execute the init script. + // TODO: check database app id / version before proceeding. + match conn.execute_batch(INIT_SQL) { + Ok(()) => info!("init completed"), + Err(err) => info!("update failed: {}", err), + } + loop { + // call blocking read on channel + let next_event = event_rx.blocking_recv(); + // if the channel has closed, we will never get work + if next_event.is_none() { + info!("No more event senders for DB, shutting down."); + break; + } + let event = next_event.unwrap(); + info!("Got event to write: {}", event.get_event_id_prefix()); + match write_event(&mut conn, &event) { + Ok(updated) => { + if updated == 0 { + info!("nothing inserted (dupe?)"); + } else { + info!("persisted new event"); + } + } + Err(err) => { + info!("event insert failed: {}", err); + } + } + } + conn.close().ok(); + info!("database connection closed"); + Ok(()) + }) +} + +pub fn write_event(conn: &mut Connection, e: &Event) -> Result { + // start transaction + let tx = conn.transaction()?; + // get relevant fields from event and convert to blobs. + let id_blob = hex::decode(&e.id).ok(); + let pubkey_blob = hex::decode(&e.pubkey).ok(); + let event_str = serde_json::to_string(&e).ok(); + // ignore if the event hash is a duplicate.x + let ins_count = tx.execute( + "INSERT OR IGNORE INTO event (event_hash, created_at, kind, author, content, first_seen) VALUES (?1, ?2, ?3, ?4, ?5, strftime('%s','now'));", + params![id_blob, e.created_at, e.kind, pubkey_blob, event_str] + )?; + let ev_id = tx.last_insert_rowid(); + let etags = e.get_event_tags(); + if etags.len() > 0 { + // this will need to + for etag in etags.iter() { + tx.execute( + "INSERT OR IGNORE INTO event_ref (event_id, referenced_event) VALUES (?1, ?2)", + params![ev_id, hex::decode(&etag).ok()], + )?; + } + } + let ptags = e.get_pubkey_tags(); + if ptags.len() > 0 { + for ptag in ptags.iter() { + tx.execute( + "INSERT OR IGNORE INTO event_ref (event_id, referenced_pubkey) VALUES (?1, ?2)", + params![ev_id, hex::decode(&ptag).ok()], + )?; + } + } + tx.commit()?; + Ok(ins_count) +} + +// Queries return a subscription identifier and the serialized event. +#[derive(PartialEq, Debug, Clone)] +pub struct QueryResult { + pub sub_id: String, + pub event: String, +} + +// TODO: make this hex +fn is_alphanum(s: &str) -> bool { + s.chars().all(|x| char::is_ascii_hexdigit(&x)) +} + +fn query_from_sub(sub: &Subscription) -> String { + // build a dynamic SQL query. all user-input is either an integer + // (sqli-safe), or a string that is filtered to only contain + // hexadecimal characters. + let mut query = + "SELECT DISTINCT(e.content) FROM event e LEFT JOIN event_ref er ON e.id=er.event_id LEFT JOIN pubkey_ref pr ON e.id=pr.event_id " + .to_owned(); + // for every filter in the subscription, generate a where clause + // all individual filter clause strings for this subscription + let mut filter_clauses: Vec = Vec::new(); + for f in sub.filters.iter() { + // individual filter components + let mut filter_components: Vec = Vec::new(); + // Query for "author" + // https://github.com/fiatjaf/nostr/issues/34 + // I believe the author & authors fields are redundant. + if f.author.is_some() { + let author_str = f.author.as_ref().unwrap(); + if is_alphanum(author_str) { + let author_clause = format!("author = x'{}'", author_str); + filter_components.push(author_clause); + } + } + // Query for "authors" + if f.authors.is_some() { + let authors_escaped: Vec = f + .authors + .as_ref() + .unwrap() + .iter() + .filter(|&x| is_alphanum(x)) + .map(|x| format!("x'{}'", x)) + .collect(); + let authors_clause = format!("author IN ({})", authors_escaped.join(", ")); + filter_components.push(authors_clause); + } + // Query for Kind + if f.kind.is_some() { + // kind is number, no escaping needed + let kind_clause = format!("kind = {}", f.kind.unwrap()); + filter_components.push(kind_clause); + } + // Query for event + if f.id.is_some() { + // whitelist characters + let id_str = f.id.as_ref().unwrap(); + if is_alphanum(id_str) { + let id_clause = format!("event_hash = x'{}'", id_str); + filter_components.push(id_clause); + } + } + // Query for referenced event + if f.event.is_some() { + // whitelist characters + let ev_str = f.event.as_ref().unwrap(); + if is_alphanum(ev_str) { + let ev_clause = format!("referenced_event = x'{}'", ev_str); + filter_components.push(ev_clause); + } + } + // Query for referenced pet name pubkey + if f.pubkey.is_some() { + // whitelist characters + let pet_str = f.pubkey.as_ref().unwrap(); + if is_alphanum(pet_str) { + let pet_clause = format!("referenced_pubkey = x'{}'", pet_str); + filter_components.push(pet_clause); + } + } + // Query for timestamp + if f.since.is_some() { + // timestamp is number, no escaping needed + let created_clause = format!("created_at > {}", f.since.unwrap()); + filter_components.push(created_clause); + } + // combine all clauses, and add to filter_clauses + if filter_components.len() > 0 { + let mut fc = "( ".to_owned(); + fc.push_str(&filter_components.join(" AND ")); + fc.push_str(" )"); + filter_clauses.push(fc); + } + } + + // combine all filters with OR clauses, if any exist + if filter_clauses.len() > 0 { + query.push_str(" WHERE "); + query.push_str(&filter_clauses.join(" OR ")); + } + info!("Query: {}", query); + return query; +} + +pub async fn db_query( + sub: Subscription, + query_tx: tokio::sync::mpsc::Sender, + mut abandon_query_rx: tokio::sync::oneshot::Receiver<()>, +) { + task::spawn_blocking(move || { + let conn = + Connection::open_with_flags(Path::new(DB_FILE), OpenFlags::SQLITE_OPEN_READ_ONLY) + .unwrap(); + info!("Opened database for reading"); + info!("Going to query for: {:?}", sub); + // generate query + let q = query_from_sub(&sub); + + let mut stmt = conn.prepare(&q).unwrap(); + let mut event_rows = stmt.query([]).unwrap(); + let mut i: usize = 0; + while let Some(row) = event_rows.next().unwrap() { + // check if this is still active (we could do this every N rows) + if abandon_query_rx.try_recv().is_ok() { + info!("Abandoning query..."); + // we have received a request to abandon the query + return; + } + let event_json = row.get(0).unwrap(); + i += 1; + info!("Sending event #{}", i); + query_tx + .blocking_send(QueryResult { + sub_id: sub.get_id(), + event: event_json, + }) + .ok(); + } + info!("Finished reading"); + }); +} diff --git a/src/error.rs b/src/error.rs index 8476303..5abe1f1 100644 --- a/src/error.rs +++ b/src/error.rs @@ -27,6 +27,14 @@ pub enum Error { WebsocketError(WsError), #[error("Command unknown")] CommandUnknownError, + #[error("SQL error")] + SqlError(rusqlite::Error), +} + +impl From for Error { + fn from(r: rusqlite::Error) -> Self { + Error::SqlError(r) + } } impl From for Error { diff --git a/src/event.rs b/src/event.rs index a7d029c..1ad239c 100644 --- a/src/event.rs +++ b/src/event.rs @@ -124,18 +124,35 @@ impl Event { serde_json::Value::Array(tags) } - // check if given event is referenced in a tag - pub fn event_tag_match(&self, eventid: &str) -> bool { + // get set of event tags + pub fn get_event_tags(&self) -> Vec<&str> { + let mut etags = vec![]; for t in self.tags.iter() { - if t.len() == 2 { - if t.get(0).unwrap() == "#e" { - if t.get(1).unwrap() == eventid { - return true; - } + if t.len() >= 2 { + if t.get(0).unwrap() == "e" { + etags.push(&t.get(1).unwrap()[..]); } } } - return false; + etags + } + + // get set of pubkey tags + pub fn get_pubkey_tags(&self) -> Vec<&str> { + let mut ptags = vec![]; + for t in self.tags.iter() { + if t.len() >= 2 { + if t.get(0).unwrap() == "p" { + ptags.push(&t.get(1).unwrap()[..]); + } + } + } + ptags + } + + // check if given event is referenced in a tag + pub fn event_tag_match(&self, eventid: &str) -> bool { + self.get_event_tags().contains(&eventid) } } @@ -170,6 +187,21 @@ mod tests { Ok(()) } + #[test] + fn empty_event_tag_match() -> Result<()> { + let event = simple_event(); + assert!(!event.event_tag_match("foo")); + Ok(()) + } + + #[test] + fn single_event_tag_match() -> Result<()> { + let mut event = simple_event(); + event.tags = vec![vec!["e".to_owned(), "foo".to_owned()]]; + assert!(event.event_tag_match("foo")); + Ok(()) + } + #[test] fn event_tags_serialize() -> Result<()> { // serialize an event with tags to JSON string diff --git a/src/lib.rs b/src/lib.rs index b247749..e464664 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,6 +1,7 @@ -pub mod conn; -pub mod protostream; -pub mod event; -pub mod subscription; pub mod close; +pub mod conn; +pub mod db; pub mod error; +pub mod event; +pub mod protostream; +pub mod subscription; diff --git a/src/main.rs b/src/main.rs index e8c8d44..67954fc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,18 +3,20 @@ use futures::StreamExt; use log::*; use nostr_rs_relay::close::Close; use nostr_rs_relay::conn; +use nostr_rs_relay::db; use nostr_rs_relay::error::{Error, Result}; use nostr_rs_relay::event::Event; use nostr_rs_relay::protostream; use nostr_rs_relay::protostream::NostrMessage::*; use nostr_rs_relay::protostream::NostrResponse::*; -use rusqlite::Result as SQLResult; +use std::collections::HashMap; use std::env; use tokio::net::{TcpListener, TcpStream}; use tokio::runtime::Builder; use tokio::sync::broadcast; -use tokio::sync::broadcast::Sender; +use tokio::sync::broadcast::{Receiver, Sender}; use tokio::sync::mpsc; +use tokio::sync::oneshot; /// Start running a Nostr relay server. fn main() -> Result<(), Error> { @@ -35,60 +37,103 @@ fn main() -> Result<(), Error> { info!("Listening on: {}", addr); // Establish global broadcast channel. This is where all // accepted events will be distributed for other connected clients. - let (bcast_tx, _) = broadcast::channel::(64); + + // this needs to be large enough to accomodate any slow + // readers - otherwise messages will be dropped before they + // can be processed. Since this is global to all connections, + // we can tolerate this being rather large (for 4096, the + // buffer itself is about 1MB). + let (bcast_tx, _) = broadcast::channel::(4096); // Establish database writer channel. This needs to be // accessible from sync code, which is why the broadcast // cannot be reused. - let (event_tx, _) = mpsc::channel::(64); + let (event_tx, event_rx) = mpsc::channel::(16); // start the database writer. - // TODO: manage program termination, to close the DB. - //let _db_handle = db_writer(event_rx).await; - while let Ok((stream, _)) = listener.accept().await { - tokio::spawn(nostr_server(stream, bcast_tx.clone(), event_tx.clone())); + db::db_writer(event_rx).await; + // setup a broadcast channel for invoking a process shutdown + let (invoke_shutdown, _) = broadcast::channel::<()>(1); + let shutdown_handler = invoke_shutdown.clone(); + // listen for ctrl-c interruupts + tokio::spawn(async move { + tokio::signal::ctrl_c().await.unwrap(); + // Your handler here + info!("got ctrl-c"); + shutdown_handler.send(()).ok(); + }); + let mut stop_listening = invoke_shutdown.subscribe(); + // shutdown on Ctrl-C, or accept a new connection + loop { + tokio::select! { + _ = stop_listening.recv() => { + break; + } + Ok((stream, _)) = listener.accept() => { + tokio::spawn(nostr_server( + stream, + bcast_tx.clone(), + event_tx.clone(), + invoke_shutdown.subscribe(), + )); + } + } } }); Ok(()) } -async fn _db_writer(_event_rx: tokio::sync::mpsc::Receiver) -> SQLResult<()> { - unimplemented!(); -} - async fn nostr_server( stream: TcpStream, broadcast: Sender, - _event_tx: tokio::sync::mpsc::Sender, + event_tx: tokio::sync::mpsc::Sender, + mut shutdown: Receiver<()>, ) { // get a broadcast channel for clients to communicate on // wrap the TCP stream in a websocket. let mut bcast_rx = broadcast.subscribe(); + // upgrade the TCP connection to WebSocket let conn = tokio_tungstenite::accept_async(stream).await; let ws_stream = conn.expect("websocket handshake error"); - // a stream & sink of Nostr protocol messages + // wrap websocket into a stream & sink of Nostr protocol messages let mut nostr_stream = protostream::wrap_ws_in_nostr(ws_stream); - //let task_queue = mpsc::channel::(16); - // track connection state so we can break when it fails // Track internal client state let mut conn = conn::ClientConn::new(); - let mut conn_good = true; + let cid = conn.get_client_prefix(); + // Create a channel for receiving query results from the database. + // we will send out the tx handle to any query we generate. + let (query_tx, mut query_rx) = mpsc::channel::(256); + // maintain a hashmap of a oneshot channel for active subscriptions. + // when these subscriptions are cancelled, make a message + // available to the executing query so it knows to stop. + //let (abandon_query_tx, _) = oneshot::channel::<()>(); + let mut running_queries: HashMap> = HashMap::new(); + loop { tokio::select! { + _ = shutdown.recv() => { + // server shutting down, exit loop + break; + }, + Some(query_result) = query_rx.recv() => { + info!("Got query result"); + let res = EventRes(query_result.sub_id,query_result.event); + nostr_stream.send(res).await.ok(); + }, Ok(global_event) = bcast_rx.recv() => { // ignoring closed broadcast errors, there will always be one sender available. // Is there a subscription for this event? let sub_name_opt = conn.get_matching_subscription(&global_event); - if sub_name_opt.is_none() { - return; - } else { + if sub_name_opt.is_some() { let sub_name = sub_name_opt.unwrap(); let event_str = serde_json::to_string(&global_event); if event_str.is_ok() { info!("sub match: client: {}, sub: {}, event: {}", - conn.get_client_prefix(), sub_name, + cid, sub_name, global_event.get_event_id_prefix()); // create an event response and send it let res = EventRes(sub_name.to_owned(),event_str.unwrap()); nostr_stream.send(res).await.ok(); + } else { + warn!("could not convert event to string"); } } }, @@ -102,49 +147,68 @@ async fn nostr_server( match parsed { Ok(e) => { let id_prefix:String = e.id.chars().take(8).collect(); - info!("Successfully parsed/validated event: {}", id_prefix); + info!("Successfully parsed/validated event: {} from client: {}", id_prefix, cid); + // Write this to the database + event_tx.send(e.clone()).await.ok(); // send this event to everyone listening. let bcast_res = broadcast.send(e); if bcast_res.is_err() { warn!("Could not send broadcast message: {:?}", bcast_res); } }, - Err(_) => {info!("Invalid event ignored")} + Err(_) => {info!("Client {} sent an invalid event", cid)} } }, Some(Ok(SubMsg(s))) => { + info!("Client {} requesting a subscription", cid); + // subscription handling consists of: - // adding new subscriptions to the client conn: - conn.subscribe(s).ok(); - // TODO: sending a request for a SQL query + // * registering the subscription so future events can be matched + // * making a channel to cancel to request later + // * sending a request for a SQL query + let (abandon_query_tx, abandon_query_rx) = oneshot::channel::<()>(); + running_queries.insert(s.id.to_owned(), abandon_query_tx); + // register this connection + conn.subscribe(s.clone()).ok(); + // start a database query + db::db_query(s, query_tx.clone(), abandon_query_rx).await; }, Some(Ok(CloseMsg(cc))) => { // closing a request simply removes the subscription. let parsed : Result = Result::::from(cc); match parsed { - Ok(c) => {conn.unsubscribe(c);}, + Ok(c) => { + let stop_tx = running_queries.remove(&c.id); + match stop_tx { + Some(tx) => { + info!("Removing query, telling DB to abandon query"); + tx.send(()).ok(); + }, + None => {} + } + conn.unsubscribe(c); + }, Err(_) => {info!("Invalid command ignored");} } }, None => { - info!("stream ended"); + info!("normal websocket close from client: {}",cid); + break; }, Some(Err(Error::ConnError)) => { - debug!("got connection error, disconnecting"); - conn_good = false; - if conn_good { - info!("Lint bug?, https://github.com/rust-lang/rust/pull/57302"); - } - return + info!("got connection close/error, disconnecting client: {}",cid); + break; } Some(Err(e)) => { - info!("got error, continuing: {:?}", e); + info!("got non-fatal error from client: {}, error: {:?}", cid, e); }, } - } - } - if !conn_good { - break; + }, } } + // connection cleanup - ensure any still running queries are terminated. + for (_, stop_tx) in running_queries.into_iter() { + stop_tx.send(()).ok(); + } + info!("stopping client connection: {}", cid); } diff --git a/src/protostream.rs b/src/protostream.rs index 722eee2..3b76191 100644 --- a/src/protostream.rs +++ b/src/protostream.rs @@ -47,7 +47,6 @@ impl Stream for NostrStream { fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { // convert Message to NostrMessage fn convert(msg: String) -> Result { - debug!("Input message: {}", &msg); let parsed_res: Result = serde_json::from_str(&msg).map_err(|e| e.into()); match parsed_res { Ok(m) => Ok(m), diff --git a/src/subscription.rs b/src/subscription.rs index f086ffa..22710b1 100644 --- a/src/subscription.rs +++ b/src/subscription.rs @@ -16,9 +16,9 @@ pub struct ReqFilter { pub id: Option, pub author: Option, pub kind: Option, - #[serde(rename = "e#")] + #[serde(rename = "#e")] pub event: Option, - #[serde(rename = "p#")] + #[serde(rename = "#p")] pub pubkey: Option, pub since: Option, pub authors: Option>,