improvement: add NostrRepo trait, with sqlite implementation

This is inspired by the work of
v0l (https://github.com/v0l/nostr-rs-relay/).

A new trait abstracts the storage layer behind an async API.  Rusqlite
is still used with worker threads, but the abstraction makes it
possible to add PostgreSQL or other backends.
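
For illustration, a minimal sketch of how a storage engine could be selected at startup under this design (the `repo_from_settings` helper is hypothetical; `SqliteRepo::new` and the `database.engine` setting both appear in the diff below):

use std::sync::Arc;
use nostr_rs_relay::config::Settings;
use nostr_rs_relay::repo::{sqlite::SqliteRepo, NostrRepo};

// Hypothetical helper: choose a NostrRepo implementation from configuration.
fn repo_from_settings(settings: &Settings) -> Arc<dyn NostrRepo> {
    match settings.database.engine.as_str() {
        "sqlite" => Arc::new(SqliteRepo::new(settings)),
        // a future PostgreSQL backend would be matched here
        other => panic!("unknown database engine: {other}"),
    }
}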

There may be bugs; this has not been rigorously tested.
Greg Heartsfield 2023-01-22 09:49:49 -06:00
parent e996d4c009
commit 6800c2e39d
20 changed files with 1396 additions and 1345 deletions

Cargo.lock (generated; 1 line changed)

@@ -1191,6 +1191,7 @@ name = "nostr-rs-relay"
version = "0.7.17"
dependencies = [
"anyhow",
"async-trait",
"bitcoin_hashes",
"clap",
"config",


@@ -42,6 +42,7 @@ parse_duration = "2"
rand = "0.8"
const_format = "0.2.28"
regex = "1"
async-trait = "0.1.60"
[dev-dependencies]
anyhow = "1"


@@ -1,14 +1,13 @@
use std::io;
use std::path::Path;
use nostr_rs_relay::utils::is_lower_hex;
use tracing::*;
use tracing::info;
use nostr_rs_relay::config;
use nostr_rs_relay::event::{Event,single_char_tagname};
use nostr_rs_relay::error::{Error, Result};
use nostr_rs_relay::db::build_pool;
use nostr_rs_relay::schema::{curr_db_version, DB_VERSION};
use nostr_rs_relay::repo::sqlite::{PooledConnection, build_pool};
use nostr_rs_relay::repo::sqlite_migration::{curr_db_version, DB_VERSION};
use rusqlite::{OpenFlags, Transaction};
use nostr_rs_relay::db::PooledConnection;
use std::sync::mpsc;
use std::thread;
use rusqlite::params;
@@ -67,7 +66,7 @@ pub fn main() -> Result<()> {
info!("finished parsing events");
event_tx.send(None).ok();
let ok: Result<()> = Ok(());
return ok;
ok
});
let mut conn: PooledConnection = pool.get()?;
let mut events_read = 0;


@@ -18,6 +18,7 @@ pub struct Info {
#[allow(unused)]
pub struct Database {
pub data_directory: String,
pub engine: String,
pub in_memory: bool,
pub min_conn: u32,
pub max_conn: u32,
@@ -206,6 +207,7 @@ impl Default for Settings {
diagnostics: Diagnostics { tracing: false },
database: Database {
data_directory: ".".to_owned(),
engine: "sqlite".to_owned(),
in_memory: false,
min_conn: 4,
max_conn: 8,
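
Assuming the project's usual config.toml layout, the new field would be set like this (a sketch; "sqlite" is the default, as shown above):

[database]
engine = "sqlite"
data_directory = "."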


@@ -14,7 +14,7 @@ const MAX_SUBSCRIPTION_ID_LEN: usize = 256;
/// State for a client connection
pub struct ClientConn {
/// Client IP (either from socket, or configured proxy header)
client_ip: String,
client_ip_addr: String,
/// Unique client identifier generated at connection time
client_id: Uuid,
/// The current set of active client subscriptions
@@ -32,22 +32,22 @@ impl Default for ClientConn {
impl ClientConn {
/// Create a new, empty connection state.
#[must_use]
pub fn new(client_ip: String) -> Self {
pub fn new(client_ip_addr: String) -> Self {
let client_id = Uuid::new_v4();
ClientConn {
client_ip,
client_ip_addr,
client_id,
subscriptions: HashMap::new(),
max_subs: 32,
}
}
pub fn subscriptions(&self) -> &HashMap<String, Subscription> {
#[must_use] pub fn subscriptions(&self) -> &HashMap<String, Subscription> {
&self.subscriptions
}
/// Check if the given subscription already exists
pub fn has_subscription(&self, sub: &Subscription) -> bool {
#[must_use] pub fn has_subscription(&self, sub: &Subscription) -> bool {
self.subscriptions.values().any(|x| x == sub)
}
@@ -60,7 +60,7 @@ impl ClientConn {
#[must_use]
pub fn ip(&self) -> &str {
&self.client_ip
&self.client_ip_addr
}
/// Add a new subscription for this connection.
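
Many getters above gain `#[must_use]`; the attribute makes the compiler warn whenever a caller silently discards the return value. A tiny illustration (not from this commit):

struct Conn { addr: String }

impl Conn {
    #[must_use]
    fn ip(&self) -> &str { &self.addr }
}

fn main() {
    let c = Conn { addr: "127.0.0.1".into() };
    c.ip(); // warning: unused return value of `Conn::ip` that must be used
}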

src/db.rs (997 lines changed; file diff suppressed because it is too large)


@@ -84,7 +84,7 @@ pub struct ConditionQuery {
}
impl ConditionQuery {
pub fn allows_event(&self, event: &Event) -> bool {
#[must_use] pub fn allows_event(&self, event: &Event) -> bool {
// check each condition, to ensure that the event complies
// with the restriction.
for c in &self.conditions {
@@ -101,7 +101,7 @@ impl ConditionQuery {
}
// Verify that the delegator approved the delegation; return a ConditionQuery if so.
pub fn validate_delegation(
#[must_use] pub fn validate_delegation(
delegator: &str,
delegatee: &str,
cond_query: &str,
@@ -133,8 +133,8 @@ pub fn validate_delegation(
}
/// Parsed delegation condition
/// see https://github.com/nostr-protocol/nips/pull/28#pullrequestreview-1084903800
/// An example complex condition would be: kind=1,2,3&created_at<1665265999
/// see <https://github.com/nostr-protocol/nips/pull/28#pullrequestreview-1084903800>
/// An example complex condition would be: `kind=1,2,3&created_at<1665265999`
#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)]
pub struct Condition {
pub field: Field,
@@ -144,7 +144,7 @@ pub struct Condition {
impl Condition {
/// Check if this condition allows the given event to be delegated
pub fn allows_event(&self, event: &Event) -> bool {
#[must_use] pub fn allows_event(&self, event: &Event) -> bool {
// determine what the right-hand side of the operator is
let resolved_field = match &self.field {
Field::Kind => event.kind,
@@ -323,7 +323,7 @@ mod tests {
Condition {
field: Field::CreatedAt,
operator: Operator::LessThan,
values: vec![1665867123],
values: vec![1_665_867_123],
},
],
};
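
For context (not part of this diff): under NIP-26, the inputs that `validate_delegation` receives come from a four-element `delegation` event tag, roughly of the form:

["delegation",
 "<delegator pubkey, 64 hex chars>",
 "kind=1&created_at<1675721813",
 "<hex schnorr signature over the delegation token>"]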


@@ -1,6 +1,6 @@
//! Event parsing and validation
use crate::delegation::validate_delegation;
use crate::error::Error::*;
use crate::error::Error::{CommandUnknownError, EventCouldNotCanonicalize, EventInvalidId, EventInvalidSignature, EventMalformedPubkey};
use crate::error::Result;
use crate::nip05;
use crate::utils::unix_time;
@@ -28,7 +28,7 @@ pub struct EventCmd {
}
impl EventCmd {
pub fn event_id(&self) -> &str {
#[must_use] pub fn event_id(&self) -> &str {
&self.event.id
}
}
@@ -65,7 +65,7 @@ where
}
/// Attempt to form a single-char tag name.
pub fn single_char_tagname(tagname: &str) -> Option<char> {
#[must_use] pub fn single_char_tagname(tagname: &str) -> Option<char> {
// We return the tag character if and only if the tagname consists
// of a single char.
let mut tagnamechars = tagname.chars();
@@ -87,22 +87,22 @@ pub fn single_char_tagname(tagname: &str) -> Option<char> {
impl From<EventCmd> for Result<Event> {
fn from(ec: EventCmd) -> Result<Event> {
// ensure command is correct
if ec.cmd != "EVENT" {
Err(CommandUnknownError)
} else {
ec.event.validate().map(|_| {
if ec.cmd == "EVENT" {
ec.event.validate().map(|_| {
let mut e = ec.event;
e.build_index();
e.update_delegation();
e
})
} else {
Err(CommandUnknownError)
}
}
}
impl Event {
#[cfg(test)]
pub fn simple_event() -> Event {
#[must_use] pub fn simple_event() -> Event {
Event {
id: "0".to_owned(),
pubkey: "0".to_owned(),
@@ -116,17 +116,17 @@ impl Event {
}
}
pub fn is_kind_metadata(&self) -> bool {
#[must_use] pub fn is_kind_metadata(&self) -> bool {
self.kind == 0
}
/// Should this event be replaced with newer timestamps from same author?
pub fn is_replaceable(&self) -> bool {
#[must_use] pub fn is_replaceable(&self) -> bool {
self.kind == 0 || self.kind == 3 || self.kind == 41 || (self.kind >= 10000 && self.kind < 20000)
}
/// Pull a NIP-05 Name out of the event, if one exists
pub fn get_nip05_addr(&self) -> Option<nip05::Nip05Name> {
#[must_use] pub fn get_nip05_addr(&self) -> Option<nip05::Nip05Name> {
if self.is_kind_metadata() {
// very quick check if we should attempt to parse this json
if self.content.contains("\"nip05\"") {
@@ -143,7 +143,7 @@ impl Event {
// is this event delegated (properly)?
// does the signature match, and are conditions valid?
// if so, return an alternate author for the event
pub fn delegated_author(&self) -> Option<String> {
#[must_use] pub fn delegated_author(&self) -> Option<String> {
// is there a delegation tag?
let delegation_tag: Vec<String> = self
.tags
@@ -151,8 +151,7 @@ impl Event {
.filter(|x| x.len() == 4)
.filter(|x| x.get(0).unwrap() == "delegation")
.take(1)
.next()?
.to_vec(); // get first tag
.next()?.clone(); // get first tag
//let delegation_tag = self.tag_values_by_name("delegation");
// delegation tags should have exactly 3 elements after the name (pubkey, condition, sig)
@@ -212,24 +211,24 @@ impl Event {
}
/// Create a short event identifier, suitable for logging.
pub fn get_event_id_prefix(&self) -> String {
#[must_use] pub fn get_event_id_prefix(&self) -> String {
self.id.chars().take(8).collect()
}
pub fn get_author_prefix(&self) -> String {
#[must_use] pub fn get_author_prefix(&self) -> String {
self.pubkey.chars().take(8).collect()
}
/// Retrieve tag initial values across all tags matching the name
pub fn tag_values_by_name(&self, tag_name: &str) -> Vec<String> {
#[must_use] pub fn tag_values_by_name(&self, tag_name: &str) -> Vec<String> {
self.tags
.iter()
.filter(|x| x.len() > 1)
.filter(|x| x.get(0).unwrap() == tag_name)
.map(|x| x.get(1).unwrap().to_owned())
.map(|x| x.get(1).unwrap().clone())
.collect()
}
pub fn is_valid_timestamp(&self, reject_future_seconds: Option<usize>) -> bool {
#[must_use] pub fn is_valid_timestamp(&self, reject_future_seconds: Option<usize>) -> bool {
if let Some(allowable_future) = reject_future_seconds {
let curr_time = unix_time();
// calculate difference, plus how far future we allow
@@ -291,7 +290,7 @@ impl Event {
let id = Number::from(0_u64);
c.push(serde_json::Value::Number(id));
// public key
c.push(Value::String(self.pubkey.to_owned()));
c.push(Value::String(self.pubkey.clone()));
// creation time
let created_at = Number::from(self.created_at);
c.push(serde_json::Value::Number(created_at));
@@ -301,7 +300,7 @@ impl Event {
// tags
c.push(self.tags_to_canonical());
// content
c.push(Value::String(self.content.to_owned()));
c.push(Value::String(self.content.clone()));
serde_json::to_string(&Value::Array(c)).ok()
}
@@ -309,11 +308,11 @@ impl Event {
fn tags_to_canonical(&self) -> Value {
let mut tags = Vec::<Value>::new();
// iterate over self tags,
for t in self.tags.iter() {
for t in &self.tags {
// each tag is a vec of strings
let mut a = Vec::<Value>::new();
for v in t.iter() {
a.push(serde_json::Value::String(v.to_owned()));
a.push(serde_json::Value::String(v.clone()));
}
tags.push(serde_json::Value::Array(a));
}
@@ -321,7 +320,7 @@ impl Event {
}
/// Determine if the given tag and value set intersect with tags in this event.
pub fn generic_tag_val_intersect(&self, tagname: char, check: &HashSet<String>) -> bool {
#[must_use] pub fn generic_tag_val_intersect(&self, tagname: char, check: &HashSet<String>) -> bool {
match &self.tagidx {
// check if this is indexable tagname
Some(idx) => match idx.get(&tagname) {
@@ -413,7 +412,7 @@ mod tests {
id: "999".to_owned(),
pubkey: "012345".to_owned(),
delegated_by: None,
created_at: 501234,
created_at: 501_234,
kind: 1,
tags: vec![],
content: "this is a test".to_owned(),
@@ -431,7 +430,7 @@ mod tests {
id: "999".to_owned(),
pubkey: "012345".to_owned(),
delegated_by: None,
created_at: 501234,
created_at: 501_234,
kind: 1,
tags: vec![
vec!["j".to_owned(), "abc".to_owned()],
@@ -458,7 +457,7 @@ mod tests {
id: "999".to_owned(),
pubkey: "012345".to_owned(),
delegated_by: None,
created_at: 501234,
created_at: 501_234,
kind: 1,
tags: vec![
vec!["j".to_owned(), "abc".to_owned()],
@@ -485,7 +484,7 @@ mod tests {
id: "999".to_owned(),
pubkey: "012345".to_owned(),
delegated_by: None,
created_at: 501234,
created_at: 501_234,
kind: 1,
tags: vec![
vec!["#e".to_owned(), "aoeu".to_owned()],


@@ -19,7 +19,7 @@ fn is_all_fs(s: &str) -> bool {
}
/// Find the next hex sequence greater than the argument.
pub fn hex_range(s: &str) -> Option<HexSearch> {
#[must_use] pub fn hex_range(s: &str) -> Option<HexSearch> {
// handle special cases
if !is_hex(s) || s.len() > 64 {
return None;


@@ -37,7 +37,7 @@ impl From<config::Info> for RelayInfo {
contact: i.contact,
supported_nips: Some(vec![1, 2, 9, 11, 12, 15, 16, 20, 22]),
software: Some("https://git.sr.ht/~gheartsfield/nostr-rs-relay".to_owned()),
version: CARGO_PKG_VERSION.map(|x| x.to_owned()),
version: CARGO_PKG_VERSION.map(std::borrow::ToOwned::to_owned),
}
}
}


@@ -10,7 +10,7 @@ pub mod hexrange;
pub mod info;
pub mod nip05;
pub mod notice;
pub mod schema;
pub mod repo;
pub mod subscription;
pub mod utils;
// Public API for creating relays programmatically


@@ -1,6 +1,6 @@
//! Server process
use clap::Parser;
use nostr_rs_relay::cli::*;
use nostr_rs_relay::cli::CLIArgs;
use nostr_rs_relay::config;
use nostr_rs_relay::server::start_server;
use std::sync::mpsc as syncmpsc;
@@ -37,12 +37,15 @@ fn main() {
if let Some(db_dir) = db_dir_arg {
settings.database.data_directory = db_dir;
}
// we should have a 'control plane' channel to monitor and bump
// the server. this will let us do stuff like clear the database,
// shutdown, etc.; for now all this does is initiate shutdown if
// `()` is sent. This will change in the future, this is just a
// stopgap to shutdown the relay when it is used as a library.
let (_, ctrl_rx): (MpscSender<()>, MpscReceiver<()>) = syncmpsc::channel();
// run this in a new thread
let handle = thread::spawn(|| {
// we should have a 'control plane' channel to monitor and bump the server.
// this will let us do stuff like clear the database, shutdown, etc.
let _svr = start_server(settings, ctrl_rx);
let handle = thread::spawn(move || {
let _svr = start_server(&settings, ctrl_rx);
});
// block on nostr thread to finish.
handle.join().unwrap();
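
A minimal sketch of driving the control-plane channel described above (names as in the diff; the server initiates shutdown when `()` is received):

let (ctrl_tx, ctrl_rx): (MpscSender<()>, MpscReceiver<()>) = syncmpsc::channel();
let handle = thread::spawn(move || {
    let _svr = start_server(&settings, ctrl_rx);
});
// ... later, from the embedding application:
ctrl_tx.send(()).ok(); // request shutdown
handle.join().unwrap();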


@@ -5,16 +5,14 @@
//! consumes a stream of metadata events, and keeps a database table
//! updated with the current NIP-05 verification status.
use crate::config::VerifiedUsers;
use crate::db;
use crate::error::{Error, Result};
use crate::event::Event;
use crate::utils::unix_time;
use crate::repo::NostrRepo;
use std::sync::Arc;
use hyper::body::HttpBody;
use hyper::client::connect::HttpConnector;
use hyper::Client;
use hyper_tls::HttpsConnector;
use rand::Rng;
use rusqlite::params;
use std::time::Duration;
use std::time::Instant;
use std::time::SystemTime;
@@ -23,14 +21,12 @@ use tracing::{debug, info, warn};
/// NIP-05 verifier state
pub struct Verifier {
/// Repository for saving/retrieving events and records
repo: Arc<dyn NostrRepo>,
/// Metadata events for us to inspect
metadata_rx: tokio::sync::broadcast::Receiver<Event>,
/// Newly validated events get written and then broadcast on this channel to subscribers
event_tx: tokio::sync::broadcast::Sender<Event>,
/// SQLite read query pool
read_pool: db::SqlitePool,
/// SQLite write query pool
write_pool: db::SqlitePool,
/// Settings
settings: crate::config::Settings,
/// HTTP client
@@ -52,7 +48,7 @@ pub struct Nip05Name {
impl Nip05Name {
/// Does this name represent the entire domain?
pub fn is_domain_only(&self) -> bool {
#[must_use] pub fn is_domain_only(&self) -> bool {
self.local == "_"
}
@@ -73,16 +69,11 @@ impl std::convert::TryFrom<&str> for Nip05Name {
fn try_from(inet: &str) -> Result<Self, Self::Error> {
// break full name at the @ boundary.
let components: Vec<&str> = inet.split('@').collect();
if components.len() != 2 {
Err(Error::CustomError("too many/few components".to_owned()))
} else {
// check if local name is valid
if components.len() == 2 {
// check if local name is valid
let local = components[0];
let domain = components[1];
if local
.chars()
.all(|x| x.is_alphanumeric() || x == '_' || x == '-' || x == '.')
{
if local.chars().all(|x| x.is_alphanumeric() || x == '_' || x == '-' || x == '.') {
if domain
.chars()
.all(|x| x.is_alphanumeric() || x == '-' || x == '.')
@@ -101,6 +92,8 @@ impl std::convert::TryFrom<&str> for Nip05Name {
"invalid character in local part".to_owned(),
))
}
} else {
Err(Error::CustomError("too many/few components".to_owned()))
}
}
}
@@ -111,55 +104,30 @@
}
}
// Current time, with a slight forward jitter in seconds
fn now_jitter(sec: u64) -> u64 {
// random time between now, and 10min in future.
let mut rng = rand::thread_rng();
let jitter_amount = rng.gen_range(0..sec);
let now = unix_time();
now.saturating_add(jitter_amount)
}
/// Check if the specified username and address are present and match in this response body
fn body_contains_user(username: &str, address: &str, bytes: hyper::body::Bytes) -> Result<bool> {
fn body_contains_user(username: &str, address: &str, bytes: &hyper::body::Bytes) -> Result<bool> {
// convert the body into json
let body: serde_json::Value = serde_json::from_slice(&bytes)?;
// ensure we have a names object.
let names_map = body
.as_object()
.and_then(|x| x.get("names"))
.and_then(|x| x.as_object())
.and_then(serde_json::Value::as_object)
.ok_or_else(|| Error::CustomError("not a map".to_owned()))?;
// get the pubkey for the requested user
let check_name = names_map.get(username).and_then(|x| x.as_str());
let check_name = names_map.get(username).and_then(serde_json::Value::as_str);
// ensure the address is a match
Ok(check_name.map(|x| x == address).unwrap_or(false))
Ok(check_name.map_or(false, |x| x == address))
}
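// Illustration (not part of this commit): `body_contains_user` expects the
// NIP-05 `.well-known/nostr.json` shape, e.g.
//   {"names": {"bob": "b0635d6a9851d3aed0cd6c495b282167acf761729078d975fc341b22650b07b9"}}
// and returns true only when `names[username]` equals `address`.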
impl Verifier {
pub fn new(
repo: Arc<dyn NostrRepo>,
metadata_rx: tokio::sync::broadcast::Receiver<Event>,
event_tx: tokio::sync::broadcast::Sender<Event>,
settings: crate::config::Settings,
) -> Result<Self> {
info!("creating NIP-05 verifier");
// build a database connection for reading and writing.
let write_pool = db::build_pool(
"nip05 writer",
&settings,
rusqlite::OpenFlags::SQLITE_OPEN_READ_WRITE,
1, // min conns
4, // max conns
true, // wait for DB
);
let read_pool = db::build_pool(
"nip05 reader",
&settings,
rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY,
1, // min conns
8, // max conns
true, // wait for DB
);
// setup hyper client
let https = HttpsConnector::new();
let client = Client::builder().build::<_, hyper::Body>(https);
@@ -175,10 +143,9 @@ impl Verifier {
// duration.
let reverify_interval = tokio::time::interval(http_wait_duration);
Ok(Verifier {
repo,
metadata_rx,
event_tx,
read_pool,
write_pool,
settings,
client,
wait_after_finish,
@@ -246,44 +213,40 @@ impl Verifier {
let response_fut = self.client.request(req);
// HTTP request with timeout
match tokio::time::timeout(Duration::from_secs(5), response_fut).await {
Ok(response_res) => {
// limit size of verification document to 1MB.
const MAX_ALLOWED_RESPONSE_SIZE: u64 = 1024 * 1024;
let response = response_res?;
// determine content length from response
let response_content_length = match response.body().size_hint().upper() {
Some(v) => v,
None => MAX_ALLOWED_RESPONSE_SIZE + 1, // reject missing content length
};
// TODO: test how hyper handles the client providing an inaccurate content-length.
if response_content_length <= MAX_ALLOWED_RESPONSE_SIZE {
let (parts, body) = response.into_parts();
// TODO: consider redirects
if parts.status == http::StatusCode::OK {
// parse body, determine if the username / key / address is present
let body_bytes = hyper::body::to_bytes(body).await?;
let body_matches = body_contains_user(&nip.local, pubkey, body_bytes)?;
if body_matches {
return Ok(UserWebVerificationStatus::Verified);
}
// successful response, parsed as a nip-05
// document, but this name/pubkey was not
// present.
return Ok(UserWebVerificationStatus::Unverified);
if let Ok(response_res) = tokio::time::timeout(Duration::from_secs(5), response_fut).await {
// limit size of verification document to 1MB.
const MAX_ALLOWED_RESPONSE_SIZE: u64 = 1024 * 1024;
let response = response_res?;
// determine content length from response
let response_content_length = match response.body().size_hint().upper() {
Some(v) => v,
None => MAX_ALLOWED_RESPONSE_SIZE + 1, // reject missing content length
};
// TODO: test how hyper handles the client providing an inaccurate content-length.
if response_content_length <= MAX_ALLOWED_RESPONSE_SIZE {
let (parts, body) = response.into_parts();
// TODO: consider redirects
if parts.status == http::StatusCode::OK {
// parse body, determine if the username / key / address is present
let body_bytes = hyper::body::to_bytes(body).await?;
let body_matches = body_contains_user(&nip.local, pubkey, &body_bytes)?;
if body_matches {
return Ok(UserWebVerificationStatus::Verified);
}
} else {
info!(
"content length missing or exceeded limits for account: {:?}",
nip.to_string()
);
// successful response, parsed as a nip-05
// document, but this name/pubkey was not
// present.
return Ok(UserWebVerificationStatus::Unverified);
}
} else {
info!(
"content length missing or exceeded limits for account: {:?}",
nip.to_string()
);
}
Err(_) => {
info!("timeout verifying account {:?}", nip);
return Ok(UserWebVerificationStatus::Unknown);
}
} else {
info!("timeout verifying account {:?}", nip);
return Ok(UserWebVerificationStatus::Unknown);
}
Ok(UserWebVerificationStatus::Unknown)
}
@@ -309,7 +272,7 @@ impl Verifier {
if let Some(naddr) = e.get_nip05_addr() {
info!("got metadata event for ({:?},{:?})", naddr.to_string() ,e.get_author_prefix());
// Process a new author, checking if they are verified:
let check_verified = get_latest_user_verification(self.read_pool.get().expect("could not get connection"), &e.pubkey).await;
let check_verified = self.repo.get_latest_user_verification(&e.pubkey).await;
// ensure the event we got is more recent than the one we have, otherwise we can ignore it.
if let Ok(last_check) = check_verified {
if e.created_at <= last_check.event_created {
@@ -370,7 +333,7 @@ impl Verifier {
.duration_since(SystemTime::UNIX_EPOCH)
.map(|x| x.as_secs())
.unwrap_or(0);
let vr = get_oldest_user_verification(self.read_pool.get()?, earliest_epoch).await;
let vr = self.repo.get_oldest_user_verification(earliest_epoch).await;
match vr {
Ok(ref v) => {
let new_status = self.get_web_verification(&v.name, &v.address).await;
@@ -378,8 +341,10 @@ impl Verifier {
UserWebVerificationStatus::Verified => {
// freshly verified account, update the
// timestamp.
self.update_verification_record(self.write_pool.get()?, v)
self.repo.update_verification_timestamp(v.rowid)
.await?;
info!("verification updated for {}", v.to_string());
}
UserWebVerificationStatus::DomainNotAllowed
| UserWebVerificationStatus::Unknown => {
@@ -394,18 +359,19 @@ impl Verifier {
"giving up on verifying {:?} after {} failures",
v.name, v.failure_count
);
self.delete_verification_record(self.write_pool.get()?, v)
self.repo.delete_verification(v.rowid)
.await?;
} else {
// record normal failure, incrementing failure count
self.fail_verification_record(self.write_pool.get()?, v)
.await?;
info!("verification failed for {}", v.to_string());
self.repo.fail_verification(v.rowid).await?;
}
}
UserWebVerificationStatus::Unverified => {
// domain has removed the verification, drop
// the record on our side.
self.delete_verification_record(self.write_pool.get()?, v)
info!("verification rescinded for {}", v.to_string());
self.repo.delete_verification(v.rowid)
.await?;
}
}
@@ -426,80 +392,6 @@ impl Verifier {
Ok(())
}
/// Reset the verification timestamp on a VerificationRecord
pub async fn update_verification_record(
&mut self,
mut conn: db::PooledConnection,
vr: &VerificationRecord,
) -> Result<()> {
let vr_id = vr.rowid;
let vr_str = vr.to_string();
tokio::task::spawn_blocking(move || {
// add some jitter to the verification to prevent everything from stacking up together.
let verif_time = now_jitter(600);
let tx = conn.transaction()?;
{
// update verification time and reset any failure count
let query =
"UPDATE user_verification SET verified_at=?, failure_count=0 WHERE id=?";
let mut stmt = tx.prepare(query)?;
stmt.execute(params![verif_time, vr_id])?;
}
tx.commit()?;
info!("verification updated for {}", vr_str);
let ok: Result<()> = Ok(());
ok
})
.await?
}
/// Reset the failure timestamp on a VerificationRecord
pub async fn fail_verification_record(
&mut self,
mut conn: db::PooledConnection,
vr: &VerificationRecord,
) -> Result<()> {
let vr_id = vr.rowid;
let vr_str = vr.to_string();
let fail_count = vr.failure_count.saturating_add(1);
tokio::task::spawn_blocking(move || {
// add some jitter to the verification to prevent everything from stacking up together.
let fail_time = now_jitter(600);
let tx = conn.transaction()?;
{
let query = "UPDATE user_verification SET failed_at=?, failure_count=? WHERE id=?";
let mut stmt = tx.prepare(query)?;
stmt.execute(params![fail_time, fail_count, vr_id])?;
}
tx.commit()?;
info!("verification failed for {}", vr_str);
let ok: Result<()> = Ok(());
ok
})
.await?
}
/// Delete a VerificationRecord that is no longer valid
pub async fn delete_verification_record(
&mut self,
mut conn: db::PooledConnection,
vr: &VerificationRecord,
) -> Result<()> {
let vr_id = vr.rowid;
let vr_str = vr.to_string();
tokio::task::spawn_blocking(move || {
let tx = conn.transaction()?;
{
let query = "DELETE FROM user_verification WHERE id=?;";
let mut stmt = tx.prepare(query)?;
stmt.execute(params![vr_id])?;
}
tx.commit()?;
info!("verification rescinded for {}", vr_str);
let ok: Result<()> = Ok(());
ok
})
.await?
}
/// Persist an event, create a verification record, and broadcast.
// TODO: have more event-writing logic handled in the db module.
// Right now, these events avoid the rate limit. That is
@@ -513,27 +405,27 @@ impl Verifier {
// disabled/passive, the event has already been persisted.
let should_write_event = self.settings.verified_users.is_enabled();
if should_write_event {
match db::write_event(&mut self.write_pool.get()?, event) {
Ok(updated) => {
if updated != 0 {
info!(
"persisted event (new verified pubkey): {:?} in {:?}",
event.get_event_id_prefix(),
start.elapsed()
);
self.event_tx.send(event.clone()).ok();
}
}
Err(err) => {
warn!("event insert failed: {:?}", err);
if let Error::SqlError(r) = err {
warn!("because: : {:?}", r);
}
}
}
match self.repo.write_event(event).await {
Ok(updated) => {
if updated != 0 {
info!(
"persisted event (new verified pubkey): {:?} in {:?}",
event.get_event_id_prefix(),
start.elapsed()
);
self.event_tx.send(event.clone()).ok();
}
}
Err(err) => {
warn!("event insert failed: {:?}", err);
if let Error::SqlError(r) = err {
warn!("because: : {:?}", r);
}
}
}
}
// write the verification record
save_verification_record(self.write_pool.get()?, event, name).await?;
self.repo.create_verification_record(&event.id, name).await?;
Ok(())
}
}
@@ -563,7 +455,7 @@ pub struct VerificationRecord {
/// Check with settings to determine if a given domain is allowed to
/// publish.
pub fn is_domain_allowed(
#[must_use] pub fn is_domain_allowed(
domain: &str,
whitelist: &Option<Vec<String>>,
blacklist: &Option<Vec<String>>,
@@ -583,7 +475,7 @@ pub fn is_domain_allowed(
impl VerificationRecord {
/// Check if the record is recent enough to be considered valid,
/// and the domain is allowed.
pub fn is_valid(&self, verified_users_settings: &VerifiedUsers) -> bool {
#[must_use] pub fn is_valid(&self, verified_users_settings: &VerifiedUsers) -> bool {
//let settings = SETTINGS.read().unwrap();
// how long a verification record is good for
let nip05_expiration = &verified_users_settings.verify_expiration_duration;
@@ -630,130 +522,6 @@ impl std::fmt::Display for VerificationRecord {
}
}
/// Create a new verification record based on an event
pub async fn save_verification_record(
mut conn: db::PooledConnection,
event: &Event,
name: &str,
) -> Result<()> {
let e = hex::decode(&event.id).ok();
let n = name.to_owned();
let a_prefix = event.get_author_prefix();
tokio::task::spawn_blocking(move || {
let tx = conn.transaction()?;
{
// if we create a /new/ one, we should get rid of any old ones. or group the new ones by name and only consider the latest.
let query = "INSERT INTO user_verification (metadata_event, name, verified_at) VALUES ((SELECT id from event WHERE event_hash=?), ?, strftime('%s','now'));";
let mut stmt = tx.prepare(query)?;
stmt.execute(params![e, n])?;
// get the row ID
let v_id = tx.last_insert_rowid();
// delete everything else by this name
let del_query = "DELETE FROM user_verification WHERE name = ? AND id != ?;";
let mut del_stmt = tx.prepare(del_query)?;
let count = del_stmt.execute(params![n,v_id])?;
if count > 0 {
info!("removed {} old verification records for ({:?},{:?})", count, n, a_prefix);
}
}
tx.commit()?;
info!("saved new verification record for ({:?},{:?})", n, a_prefix);
let ok: Result<()> = Ok(());
ok
}).await?
}
/// Retrieve the most recent verification record for a given pubkey (async).
pub async fn get_latest_user_verification(
conn: db::PooledConnection,
pubkey: &str,
) -> Result<VerificationRecord> {
let p = pubkey.to_owned();
tokio::task::spawn_blocking(move || query_latest_user_verification(conn, p)).await?
}
/// Query database for the latest verification record for a given pubkey.
pub fn query_latest_user_verification(
mut conn: db::PooledConnection,
pubkey: String,
) -> Result<VerificationRecord> {
let tx = conn.transaction()?;
let query = "SELECT v.id, v.name, e.event_hash, e.created_at, v.verified_at, v.failed_at, v.failure_count FROM user_verification v LEFT JOIN event e ON e.id=v.metadata_event WHERE e.author=? ORDER BY e.created_at DESC, v.verified_at DESC, v.failed_at DESC LIMIT 1;";
let mut stmt = tx.prepare_cached(query)?;
let fields = stmt.query_row(params![hex::decode(&pubkey).ok()], |r| {
let rowid: u64 = r.get(0)?;
let rowname: String = r.get(1)?;
let eventid: Vec<u8> = r.get(2)?;
let created_at: u64 = r.get(3)?;
// create a tuple since we can't throw non-rusqlite errors in this closure
Ok((
rowid,
rowname,
eventid,
created_at,
r.get(4).ok(),
r.get(5).ok(),
r.get(6)?,
))
})?;
Ok(VerificationRecord {
rowid: fields.0,
name: Nip05Name::try_from(&fields.1[..])?,
address: pubkey,
event: hex::encode(fields.2),
event_created: fields.3,
last_success: fields.4,
last_failure: fields.5,
failure_count: fields.6,
})
}
/// Retrieve the oldest user verification (async)
pub async fn get_oldest_user_verification(
conn: db::PooledConnection,
earliest: u64,
) -> Result<VerificationRecord> {
tokio::task::spawn_blocking(move || query_oldest_user_verification(conn, earliest)).await?
}
pub fn query_oldest_user_verification(
mut conn: db::PooledConnection,
earliest: u64,
) -> Result<VerificationRecord> {
let tx = conn.transaction()?;
let query = "SELECT v.id, v.name, e.event_hash, e.author, e.created_at, v.verified_at, v.failed_at, v.failure_count FROM user_verification v INNER JOIN event e ON e.id=v.metadata_event WHERE (v.verified_at < ? OR v.verified_at IS NULL) AND (v.failed_at < ? OR v.failed_at IS NULL) ORDER BY v.verified_at ASC, v.failed_at ASC LIMIT 1;";
let mut stmt = tx.prepare_cached(query)?;
let fields = stmt.query_row(params![earliest, earliest], |r| {
let rowid: u64 = r.get(0)?;
let rowname: String = r.get(1)?;
let eventid: Vec<u8> = r.get(2)?;
let pubkey: Vec<u8> = r.get(3)?;
let created_at: u64 = r.get(4)?;
// create a tuple since we can't throw non-rusqlite errors in this closure
Ok((
rowid,
rowname,
eventid,
pubkey,
created_at,
r.get(5).ok(),
r.get(6).ok(),
r.get(7)?,
))
})?;
let vr = VerificationRecord {
rowid: fields.0,
name: Nip05Name::try_from(&fields.1[..])?,
address: hex::encode(fields.3),
event: hex::encode(fields.2),
event_created: fields.4,
last_success: fields.5,
last_failure: fields.6,
failure_count: fields.7,
};
Ok(vr)
}
#[cfg(test)]
mod tests {
use super::*;
@@ -762,7 +530,7 @@ mod tests {
fn local_from_inet() {
let addr = "bob@example.com";
let parsed = Nip05Name::try_from(addr);
assert!(!parsed.is_err());
assert!(parsed.is_ok());
let v = parsed.unwrap();
assert_eq!(v.local, "bob");
assert_eq!(v.domain, "example.com");


@@ -19,18 +19,14 @@ pub enum Notice {
}
impl EventResultStatus {
pub fn to_bool(&self) -> bool {
#[must_use] pub fn to_bool(&self) -> bool {
match self {
Self::Saved => true,
Self::Duplicate => true,
Self::Invalid => false,
Self::Blocked => false,
Self::RateLimited => false,
Self::Error => false,
Self::Duplicate | Self::Saved => true,
Self::Invalid | Self::Blocked | Self::RateLimited | Self::Error => false,
}
}
pub fn prefix(&self) -> &'static str {
#[must_use] pub fn prefix(&self) -> &'static str {
match self {
Self::Saved => "saved",
Self::Duplicate => "duplicate",
@@ -47,7 +43,7 @@ impl Notice {
// Notice::err_msg(format!("{}", err), id)
//}
pub fn message(msg: String) -> Notice {
#[must_use] pub fn message(msg: String) -> Notice {
Notice::Message(msg)
}
@@ -56,27 +52,27 @@ impl Notice {
Notice::EventResult(EventResult { id, msg, status })
}
pub fn invalid(id: String, msg: &str) -> Notice {
#[must_use] pub fn invalid(id: String, msg: &str) -> Notice {
Notice::prefixed(id, msg, EventResultStatus::Invalid)
}
pub fn blocked(id: String, msg: &str) -> Notice {
#[must_use] pub fn blocked(id: String, msg: &str) -> Notice {
Notice::prefixed(id, msg, EventResultStatus::Blocked)
}
pub fn rate_limited(id: String, msg: &str) -> Notice {
#[must_use] pub fn rate_limited(id: String, msg: &str) -> Notice {
Notice::prefixed(id, msg, EventResultStatus::RateLimited)
}
pub fn duplicate(id: String) -> Notice {
#[must_use] pub fn duplicate(id: String) -> Notice {
Notice::prefixed(id, "", EventResultStatus::Duplicate)
}
pub fn error(id: String, msg: &str) -> Notice {
#[must_use] pub fn error(id: String, msg: &str) -> Notice {
Notice::prefixed(id, msg, EventResultStatus::Error)
}
pub fn saved(id: String) -> Notice {
#[must_use] pub fn saved(id: String) -> Notice {
Notice::EventResult(EventResult {
id,
msg: "".into(),

src/repo/mod.rs (new file, 67 lines)

@@ -0,0 +1,67 @@
use crate::db::QueryResult;
use crate::error::Result;
use crate::event::Event;
use crate::nip05::VerificationRecord;
use crate::subscription::Subscription;
use crate::utils::unix_time;
use async_trait::async_trait;
use rand::Rng;
pub mod sqlite;
pub mod sqlite_migration;
#[async_trait]
pub trait NostrRepo: Send + Sync {
/// Start the repository (any initialization or maintenance tasks can be kicked off here)
async fn start(&self) -> Result<()>;
/// Run migrations and return current version
async fn migrate_up(&self) -> Result<usize>;
/// Persist event to database
async fn write_event(&self, e: &Event) -> Result<u64>;
/// Perform a database query using a subscription.
///
/// The [`Subscription`] is converted into a SQL query. Each result
/// is published on the `query_tx` channel as it is returned. If a
/// message becomes available on the `abandon_query_rx` channel, the
/// query is immediately aborted.
async fn query_subscription(
&self,
sub: Subscription,
client_id: String,
query_tx: tokio::sync::mpsc::Sender<QueryResult>,
mut abandon_query_rx: tokio::sync::oneshot::Receiver<()>,
) -> Result<()>;
/// Perform normal maintenance
async fn optimize_db(&self) -> Result<()>;
/// Create a new verification record connected to a specific event
async fn create_verification_record(&self, event_id: &str, name: &str) -> Result<()>;
/// Update verification timestamp
async fn update_verification_timestamp(&self, id: u64) -> Result<()>;
/// Update verification record as failed
async fn fail_verification(&self, id: u64) -> Result<()>;
/// Delete verification record
async fn delete_verification(&self, id: u64) -> Result<()>;
/// Get the latest verification record for a given pubkey.
async fn get_latest_user_verification(&self, pub_key: &str) -> Result<VerificationRecord>;
/// Get oldest verification before timestamp
async fn get_oldest_user_verification(&self, before: u64) -> Result<VerificationRecord>;
}
// Current time, with a slight forward jitter in seconds
pub(crate) fn now_jitter(sec: u64) -> u64 {
// random time between now, and 10min in future.
let mut rng = rand::thread_rng();
let jitter_amount = rng.gen_range(0..sec);
let now = unix_time();
now.saturating_add(jitter_amount)
}
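
A sketch of how a caller might drive `query_subscription` (the channel setup is illustrative; the signature is as defined above):

let (query_tx, mut query_rx) = tokio::sync::mpsc::channel::<QueryResult>(256);
let (abandon_tx, abandon_rx) = tokio::sync::oneshot::channel::<()>();
repo.query_subscription(sub, client_id, query_tx, abandon_rx).await?;
while let Some(res) = query_rx.recv().await {
    // each QueryResult pairs a subscription id with an event JSON string
    println!("{}: {}", res.sub_id, res.event);
}
// from another task, sending () on abandon_tx aborts a long-running query:
// abandon_tx.send(()).ok();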

src/repo/sqlite.rs (new file, 948 lines)

@@ -0,0 +1,948 @@
//! Event persistence and querying
//use crate::config::SETTINGS;
use crate::config::Settings;
use crate::error::Result;
use crate::event::{single_char_tagname, Event};
use crate::hexrange::hex_range;
use crate::hexrange::HexSearch;
use crate::repo::sqlite_migration::{STARTUP_SQL,upgrade_db};
use crate::utils::{is_hex, is_lower_hex};
use crate::nip05::{Nip05Name, VerificationRecord};
use crate::subscription::{ReqFilter, Subscription};
use hex;
use r2d2;
use r2d2_sqlite::SqliteConnectionManager;
use rusqlite::params;
use rusqlite::types::ToSql;
use rusqlite::OpenFlags;
use tokio::sync::{Mutex, MutexGuard};
use std::fmt::Write as _;
use std::path::Path;
use std::sync::Arc;
use std::thread;
use std::time::Duration;
use std::time::Instant;
use tokio::task;
use tracing::{debug, info, trace, warn};
use async_trait::async_trait;
use crate::db::QueryResult;
use crate::repo::{now_jitter, NostrRepo};
pub type SqlitePool = r2d2::Pool<r2d2_sqlite::SqliteConnectionManager>;
pub type PooledConnection = r2d2::PooledConnection<r2d2_sqlite::SqliteConnectionManager>;
pub const DB_FILE: &str = "nostr.db";
#[derive(Clone)]
pub struct SqliteRepo {
/// Pool for reading events and NIP-05 status
read_pool: SqlitePool,
/// Pool for writing events and NIP-05 verification
write_pool: SqlitePool,
/// Pool for performing checkpoints/optimization
maint_pool: SqlitePool,
/// Flag to indicate a checkpoint is underway
checkpoint_in_progress: Arc<Mutex<u64>>,
/// Flag to limit writer concurrency
write_in_progress: Arc<Mutex<u64>>,
}
impl SqliteRepo {
// build all the pools needed
#[must_use] pub fn new(settings: &Settings) -> SqliteRepo {
let maint_pool = build_pool(
"maintenance",
settings,
OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE,
1,
2,
true,
);
let read_pool = build_pool(
"reader",
settings,
OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE,
settings.database.min_conn,
settings.database.max_conn,
true,
);
let write_pool = build_pool(
"writer",
settings,
OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE,
1,
2,
false,
);
// this is used to block new reads during critical checkpoints
let checkpoint_in_progress = Arc::new(Mutex::new(0));
// SQLite can only effectively write single threaded, so don't
// block multiple worker threads unnecessarily.
let write_in_progress = Arc::new(Mutex::new(0));
SqliteRepo {
read_pool,
write_pool,
maint_pool,
checkpoint_in_progress,
write_in_progress,
}
}
/// Persist an event to the database, returning rows added.
pub fn persist_event(conn: &mut PooledConnection, e: &Event) -> Result<u64> {
// enable auto vacuum
conn.execute_batch("pragma auto_vacuum = FULL")?;
// start transaction
let tx = conn.transaction()?;
// get relevant fields from event and convert to blobs.
let id_blob = hex::decode(&e.id).ok();
let pubkey_blob: Option<Vec<u8>> = hex::decode(&e.pubkey).ok();
let delegator_blob: Option<Vec<u8>> = e.delegated_by.as_ref().and_then(|d| hex::decode(d).ok());
let event_str = serde_json::to_string(&e).ok();
// check for replaceable events that would hide this one; we won't even attempt to insert these.
if e.is_replaceable() {
let repl_count = tx.query_row(
"SELECT e.id FROM event e INDEXED BY author_index WHERE e.author=? AND e.kind=? AND e.created_at > ? LIMIT 1;",
params![pubkey_blob, e.kind, e.created_at], |row| row.get::<usize, usize>(0));
if repl_count.ok().is_some() {
return Ok(0);
}
}
// ignore if the event hash is a duplicate.
let mut ins_count = tx.execute(
"INSERT OR IGNORE INTO event (event_hash, created_at, kind, author, delegated_by, content, first_seen, hidden) VALUES (?1, ?2, ?3, ?4, ?5, ?6, strftime('%s','now'), FALSE);",
params![id_blob, e.created_at, e.kind, pubkey_blob, delegator_blob, event_str]
)? as u64;
if ins_count == 0 {
// if the event was a duplicate, no need to insert event or
// pubkey references.
tx.rollback().ok();
return Ok(ins_count);
}
// remember primary key of the event most recently inserted.
let ev_id = tx.last_insert_rowid();
// add all tags to the tag table
for tag in &e.tags {
// ensure we have 2 values.
if tag.len() >= 2 {
let tagname = &tag[0];
let tagval = &tag[1];
// only single-char tags are searchable
let tagchar_opt = single_char_tagname(tagname);
match &tagchar_opt {
Some(_) => {
// if tagvalue is lowercase hex;
if is_lower_hex(tagval) && (tagval.len() % 2 == 0) {
tx.execute(
"INSERT OR IGNORE INTO tag (event_id, name, value_hex) VALUES (?1, ?2, ?3)",
params![ev_id, &tagname, hex::decode(tagval).ok()],
)?;
} else {
tx.execute(
"INSERT OR IGNORE INTO tag (event_id, name, value) VALUES (?1, ?2, ?3)",
params![ev_id, &tagname, &tagval],
)?;
}
}
None => {}
}
}
}
// if this event is replaceable update, remove other replaceable
// event with the same kind from the same author that was issued
// earlier than this.
if e.is_replaceable() {
let author = hex::decode(&e.pubkey).ok();
// this is a backwards check - hide any events that were older.
let update_count = tx.execute(
"DELETE FROM event WHERE kind=? and author=? and id NOT IN (SELECT id FROM event INDEXED BY author_kind_index WHERE kind=? AND author=? ORDER BY created_at DESC LIMIT 1)",
params![e.kind, author, e.kind, author],
)?;
if update_count > 0 {
info!(
"removed {} older replaceable kind {} events for author: {:?}",
update_count,
e.kind,
e.get_author_prefix()
);
}
}
// if this event is a deletion, hide the referenced events from the same author.
if e.kind == 5 {
let event_candidates = e.tag_values_by_name("e");
// first parameter will be author
let mut params: Vec<Box<dyn ToSql>> = vec![Box::new(hex::decode(&e.pubkey)?)];
event_candidates
.iter()
.filter(|x| is_hex(x) && x.len() == 64)
.filter_map(|x| hex::decode(x).ok())
.for_each(|x| params.push(Box::new(x)));
let query = format!(
"UPDATE event SET hidden=TRUE WHERE kind!=5 AND author=? AND event_hash IN ({})",
repeat_vars(params.len() - 1)
);
let mut stmt = tx.prepare(&query)?;
let update_count = stmt.execute(rusqlite::params_from_iter(params))?;
info!(
"hid {} deleted events for author {:?}",
update_count,
e.get_author_prefix()
);
} else {
// check if a deletion has already been recorded for this event.
// Only relevant for non-deletion events
let del_count = tx.query_row(
"SELECT e.id FROM event e LEFT JOIN tag t ON e.id=t.event_id WHERE e.author=? AND t.name='e' AND e.kind=5 AND t.value_hex=? LIMIT 1;",
params![pubkey_blob, id_blob], |row| row.get::<usize, usize>(0));
// check if the query returned a result, meaning we should
// hide the current event
if del_count.ok().is_some() {
// a deletion already existed, mark original event as hidden.
info!(
"hid event: {:?} due to existing deletion by author: {:?}",
e.get_event_id_prefix(),
e.get_author_prefix()
);
let _update_count =
tx.execute("UPDATE event SET hidden=TRUE WHERE id=?", params![ev_id])?;
// event was deleted, so let caller know nothing new
// arrived, preventing this from being sent to active
// subscriptions
ins_count = 0;
}
}
tx.commit()?;
Ok(ins_count)
}
}
#[async_trait]
impl NostrRepo for SqliteRepo {
async fn start(&self) -> Result<()> {
db_checkpoint_task(self.maint_pool.clone(), Duration::from_secs(60), self.checkpoint_in_progress.clone()).await
}
async fn migrate_up(&self) -> Result<usize> {
let _write_guard = self.write_in_progress.lock().await;
let mut conn = self.write_pool.get()?;
task::spawn_blocking(move || {
upgrade_db(&mut conn)
}).await?
}
/// Persist event to database
async fn write_event(&self, e: &Event) -> Result<u64> {
let _write_guard = self.write_in_progress.lock().await;
// spawn a blocking thread
let mut conn = self.write_pool.get()?;
let e = e.clone();
task::spawn_blocking(move || {
SqliteRepo::persist_event(&mut conn, &e)
}).await?
}
/// Perform a database query using a subscription.
///
/// The [`Subscription`] is converted into a SQL query. Each result
/// is published on the `query_tx` channel as it is returned. If a
/// message becomes available on the `abandon_query_rx` channel, the
/// query is immediately aborted.
async fn query_subscription(
&self,
sub: Subscription,
client_id: String,
query_tx: tokio::sync::mpsc::Sender<QueryResult>,
mut abandon_query_rx: tokio::sync::oneshot::Receiver<()>,
) -> Result<()> {
let pre_spawn_start = Instant::now();
let repo = self.clone();
task::spawn_blocking(move || {
{
// if we are waiting on a checkpoint, stop until it is complete
let _x = repo.checkpoint_in_progress.blocking_lock();
}
let db_queue_time = pre_spawn_start.elapsed();
// if the queue time was very long (>5 seconds), spare the DB and abort.
if db_queue_time > Duration::from_secs(5) {
info!(
"shedding DB query load queued for {:?} (cid: {}, sub: {:?})",
db_queue_time, client_id, sub.id
);
return Ok(());
}
// otherwise, report queuing time if it is slow
else if db_queue_time > Duration::from_secs(1) {
debug!(
"(slow) DB query queued for {:?} (cid: {}, sub: {:?})",
db_queue_time, client_id, sub.id
);
}
let start = Instant::now();
let mut row_count: usize = 0;
// generate SQL query
let (q, p, idxs) = query_from_sub(&sub);
let sql_gen_elapsed = start.elapsed();
if sql_gen_elapsed > Duration::from_millis(10) {
debug!("SQL (slow) generated in {:?}", start.elapsed());
}
// cutoff for displaying slow queries
let slow_cutoff = Duration::from_millis(2000);
// any client that doesn't cause us to generate new rows in 5
// seconds gets dropped.
let abort_cutoff = Duration::from_secs(5);
let start = Instant::now();
let mut slow_first_event;
let mut last_successful_send = Instant::now();
if let Ok(mut conn) = repo.read_pool.get() {
// execute the query.
// make the actual SQL query (with parameters inserted) available
conn.trace(Some(|x| {trace!("SQL trace: {:?}", x)}));
let mut stmt = conn.prepare_cached(&q)?;
let mut event_rows = stmt.query(rusqlite::params_from_iter(p))?;
let mut first_result = true;
while let Some(row) = event_rows.next()? {
let first_event_elapsed = start.elapsed();
slow_first_event = first_event_elapsed >= slow_cutoff;
if first_result {
debug!(
"first result in {:?} (cid: {}, sub: {:?}) [used indexes: {:?}]",
first_event_elapsed, client_id, sub.id, idxs
);
first_result = false;
}
// logging for slow queries; show sub and SQL.
// to reduce logging; only show 1/16th of clients (leading 0)
if row_count == 0 && slow_first_event && client_id.starts_with('0') {
debug!(
"query req (slow): {:?} (cid: {}, sub: {:?})",
sub, client_id, sub.id
);
}
// check if a checkpoint is trying to run, and abort
if row_count % 100 == 0 {
{
if repo.checkpoint_in_progress.try_lock().is_err() {
// lock was held, abort this query
debug!("query aborted due to checkpoint (cid: {}, sub: {:?})", client_id, sub.id);
return Ok(());
}
}
}
// check if this is still active; every 100 rows
if row_count % 100 == 0 && abandon_query_rx.try_recv().is_ok() {
debug!("query aborted (cid: {}, sub: {:?})", client_id, sub.id);
return Ok(());
}
row_count += 1;
let event_json = row.get(0)?;
loop {
if query_tx.capacity() != 0 {
// we have capacity to add another item
break;
}
// the queue is full
trace!("db reader thread is stalled");
if last_successful_send + abort_cutoff < Instant::now() {
// the queue has been full for too long, abort
info!("aborting database query due to slow client (cid: {}, sub: {:?})",
client_id, sub.id);
let ok: Result<()> = Ok(());
return ok;
}
// check if a checkpoint is trying to run, and abort
if repo.checkpoint_in_progress.try_lock().is_err() {
// lock was held, abort this query
debug!("query aborted due to checkpoint (cid: {}, sub: {:?})", client_id, sub.id);
return Ok(());
}
// give the queue a chance to clear before trying again
thread::sleep(Duration::from_millis(100));
}
// TODO: we could use try_send, but we'd have to juggle
// getting the query result back as part of the error
// result.
query_tx
.blocking_send(QueryResult {
sub_id: sub.get_id(),
event: event_json,
})
.ok();
last_successful_send = Instant::now();
}
query_tx
.blocking_send(QueryResult {
sub_id: sub.get_id(),
event: "EOSE".to_string(),
})
.ok();
debug!(
"query completed in {:?} (cid: {}, sub: {:?}, db_time: {:?}, rows: {})",
pre_spawn_start.elapsed(),
client_id,
sub.id,
start.elapsed(),
row_count
);
} else {
warn!("Could not get a database connection for querying");
}
let ok: Result<()> = Ok(());
ok
});
Ok(())
}
/// Perform normal maintenance
async fn optimize_db(&self) -> Result<()> {
let conn = self.write_pool.get()?;
task::spawn_blocking(move || {
let start = Instant::now();
conn.execute_batch("PRAGMA optimize;").ok();
info!("optimize ran in {:?}", start.elapsed());
}).await?;
Ok(())
}
/// Create a new verification record connected to a specific event
async fn create_verification_record(&self, event_id: &str, name: &str) -> Result<()> {
let e = hex::decode(event_id).ok();
let n = name.to_owned();
let mut conn = self.write_pool.get()?;
tokio::task::spawn_blocking(move || {
let tx = conn.transaction()?;
{
// if we create a /new/ one, we should get rid of any old ones. or group the new ones by name and only consider the latest.
let query = "INSERT INTO user_verification (metadata_event, name, verified_at) VALUES ((SELECT id from event WHERE event_hash=?), ?, strftime('%s','now'));";
let mut stmt = tx.prepare(query)?;
stmt.execute(params![e, n])?;
// get the row ID
let v_id = tx.last_insert_rowid();
// delete everything else by this name
let del_query = "DELETE FROM user_verification WHERE name = ? AND id != ?;";
let mut del_stmt = tx.prepare(del_query)?;
let count = del_stmt.execute(params![n,v_id])?;
if count > 0 {
info!("removed {} old verification records for ({:?})", count, n);
}
}
tx.commit()?;
info!("saved new verification record for ({:?})", n);
let ok: Result<()> = Ok(());
ok
}).await?
}
/// Update verification timestamp
async fn update_verification_timestamp(&self, id: u64) -> Result<()> {
let mut conn = self.write_pool.get()?;
tokio::task::spawn_blocking(move || {
// add some jitter to the verification to prevent everything from stacking up together.
let verif_time = now_jitter(600);
let tx = conn.transaction()?;
{
// update verification time and reset any failure count
let query =
"UPDATE user_verification SET verified_at=?, failure_count=0 WHERE id=?";
let mut stmt = tx.prepare(query)?;
stmt.execute(params![verif_time, id])?;
}
tx.commit()?;
let ok: Result<()> = Ok(());
ok
})
.await?
}
/// Update verification record as failed
async fn fail_verification(&self, id: u64) -> Result<()> {
let mut conn = self.write_pool.get()?;
tokio::task::spawn_blocking(move || {
// add some jitter to the verification to prevent everything from stacking up together.
let fail_time = now_jitter(600);
let tx = conn.transaction()?;
{
let query = "UPDATE user_verification SET failed_at=?, failure_count=failure_count+1 WHERE id=?";
let mut stmt = tx.prepare(query)?;
stmt.execute(params![fail_time, id])?;
}
tx.commit()?;
let ok: Result<()> = Ok(());
ok
})
.await?
}
/// Delete verification record
async fn delete_verification(&self, id: u64) -> Result<()> {
let mut conn = self.write_pool.get()?;
tokio::task::spawn_blocking(move || {
let tx = conn.transaction()?;
{
let query = "DELETE FROM user_verification WHERE id=?;";
let mut stmt = tx.prepare(query)?;
stmt.execute(params![id])?;
}
tx.commit()?;
let ok: Result<()> = Ok(());
ok
})
.await?
}
/// Get the latest verification record for a given pubkey.
async fn get_latest_user_verification(&self, pub_key: &str) -> Result<VerificationRecord> {
let mut conn = self.read_pool.get()?;
let pub_key = pub_key.to_owned();
tokio::task::spawn_blocking(move || {
let tx = conn.transaction()?;
let query = "SELECT v.id, v.name, e.event_hash, e.created_at, v.verified_at, v.failed_at, v.failure_count FROM user_verification v LEFT JOIN event e ON e.id=v.metadata_event WHERE e.author=? ORDER BY e.created_at DESC, v.verified_at DESC, v.failed_at DESC LIMIT 1;";
let mut stmt = tx.prepare_cached(query)?;
let fields = stmt.query_row(params![hex::decode(&pub_key).ok()], |r| {
let rowid: u64 = r.get(0)?;
let rowname: String = r.get(1)?;
let eventid: Vec<u8> = r.get(2)?;
let created_at: u64 = r.get(3)?;
// create a tuple since we can't throw non-rusqlite errors in this closure
Ok((
rowid,
rowname,
eventid,
created_at,
r.get(4).ok(),
r.get(5).ok(),
r.get(6)?,
))
})?;
Ok(VerificationRecord {
rowid: fields.0,
name: Nip05Name::try_from(&fields.1[..])?,
address: pub_key,
event: hex::encode(fields.2),
event_created: fields.3,
last_success: fields.4,
last_failure: fields.5,
failure_count: fields.6,
})
}).await?
}
/// Get oldest verification before timestamp
async fn get_oldest_user_verification(&self, before: u64) -> Result<VerificationRecord> {
let mut conn = self.read_pool.get()?;
tokio::task::spawn_blocking(move || {
let tx = conn.transaction()?;
let query = "SELECT v.id, v.name, e.event_hash, e.author, e.created_at, v.verified_at, v.failed_at, v.failure_count FROM user_verification v INNER JOIN event e ON e.id=v.metadata_event WHERE (v.verified_at < ? OR v.verified_at IS NULL) AND (v.failed_at < ? OR v.failed_at IS NULL) ORDER BY v.verified_at ASC, v.failed_at ASC LIMIT 1;";
let mut stmt = tx.prepare_cached(query)?;
let fields = stmt.query_row(params![before, before], |r| {
let rowid: u64 = r.get(0)?;
let rowname: String = r.get(1)?;
let eventid: Vec<u8> = r.get(2)?;
let pubkey: Vec<u8> = r.get(3)?;
let created_at: u64 = r.get(4)?;
// create a tuple since we can't throw non-rusqlite errors in this closure
Ok((
rowid,
rowname,
eventid,
pubkey,
created_at,
r.get(5).ok(),
r.get(6).ok(),
r.get(7)?,
))
})?;
let vr = VerificationRecord {
rowid: fields.0,
name: Nip05Name::try_from(&fields.1[..])?,
address: hex::encode(fields.3),
event: hex::encode(fields.2),
event_created: fields.4,
last_success: fields.5,
last_failure: fields.6,
failure_count: fields.7,
};
Ok(vr)
}).await?
}
}
/// Decide if there is an index that should be used explicitly
fn override_index(f: &ReqFilter) -> Option<String> {
// queries for multiple kinds default to kind_index, which is
// significantly slower than kind_created_at_index.
if let Some(ks) = &f.kinds {
if f.ids.is_none() &&
ks.len() > 1 &&
f.since.is_none() &&
f.until.is_none() &&
f.tags.is_none() &&
f.authors.is_none() {
return Some("kind_created_at_index".into());
}
}
// if there is an author, it is much better to force the authors index.
if f.authors.is_some() {
if f.since.is_none() && f.until.is_none() {
if f.kinds.is_none() {
// with no use of kinds/created_at, just author
return Some("author_index".into());
}
// prefer author_kind if there are kinds
return Some("author_kind_index".into());
}
// finally, prefer author_created_at if time is provided
return Some("author_created_at_index".into());
}
None
}
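// Illustrative mapping of filters to the overrides above (not in the original source):
//   multiple kinds, nothing else     -> "kind_created_at_index"
//   authors only                     -> "author_index"
//   authors + kinds, no time bounds  -> "author_kind_index"
//   authors + since/until            -> "author_created_at_index"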
/// Create a dynamic SQL subquery and params from a subscription filter (and optional explicit index used)
fn query_from_filter(f: &ReqFilter) -> (String, Vec<Box<dyn ToSql>>, Option<String>) {
// build a dynamic SQL query. all user-input is either an integer
// (sqli-safe), or a string that is filtered to only contain
// hexadecimal characters. Strings that require escaping (tag
// names/values) use parameters.
// if the filter is malformed, don't return anything.
if f.force_no_match {
let empty_query = "SELECT e.content, e.created_at FROM event e WHERE 1=0".to_owned();
// query parameters for SQLite
let empty_params: Vec<Box<dyn ToSql>> = vec![];
return (empty_query, empty_params, None);
}
// check if the index needs to be overriden
let idx_name = override_index(f);
let idx_stmt = idx_name.as_ref().map_or_else(|| "".to_owned(), |i| format!("INDEXED BY {}",i));
let mut query = format!("SELECT e.content, e.created_at FROM event e {}", idx_stmt);
// query parameters for SQLite
let mut params: Vec<Box<dyn ToSql>> = vec![];
// individual filter components (single conditions such as an author or event ID)
let mut filter_components: Vec<String> = Vec::new();
// Query for "authors", allowing prefix matches
if let Some(authvec) = &f.authors {
// take each author and convert to a hexsearch
let mut auth_searches: Vec<String> = vec![];
for auth in authvec {
match hex_range(auth) {
Some(HexSearch::Exact(ex)) => {
auth_searches.push("author=?".to_owned());
params.push(Box::new(ex));
}
Some(HexSearch::Range(lower, upper)) => {
auth_searches.push(
"(author>? AND author<?)".to_owned(),
);
params.push(Box::new(lower));
params.push(Box::new(upper));
}
Some(HexSearch::LowerOnly(lower)) => {
auth_searches.push("author>?".to_owned());
params.push(Box::new(lower));
}
None => {
info!("Could not parse hex range from author {:?}", auth);
}
}
}
if !authvec.is_empty() {
let auth_clause = format!("({})", auth_searches.join(" OR "));
filter_components.push(auth_clause);
} else {
filter_components.push("false".to_owned());
}
}
// Query for Kind
if let Some(ks) = &f.kinds {
// kind is number, no escaping needed
let str_kinds: Vec<String> = ks.iter().map(std::string::ToString::to_string).collect();
let kind_clause = format!("kind IN ({})", str_kinds.join(", "));
filter_components.push(kind_clause);
}
// Query for event, allowing prefix matches
if let Some(idvec) = &f.ids {
// take each author and convert to a hexsearch
let mut id_searches: Vec<String> = vec![];
for id in idvec {
match hex_range(id) {
Some(HexSearch::Exact(ex)) => {
id_searches.push("event_hash=?".to_owned());
params.push(Box::new(ex));
}
Some(HexSearch::Range(lower, upper)) => {
id_searches.push("(event_hash>? AND event_hash<?)".to_owned());
params.push(Box::new(lower));
params.push(Box::new(upper));
}
Some(HexSearch::LowerOnly(lower)) => {
id_searches.push("event_hash>?".to_owned());
params.push(Box::new(lower));
}
None => {
info!("Could not parse hex range from id {:?}", id);
}
}
}
if idvec.is_empty() {
// if the ids list was empty, we should never return
// any results.
filter_components.push("false".to_owned());
} else {
let id_clause = format!("({})", id_searches.join(" OR "));
filter_components.push(id_clause);
}
}
// Query for tags
if let Some(map) = &f.tags {
for (key, val) in map.iter() {
let mut str_vals: Vec<Box<dyn ToSql>> = vec![];
let mut blob_vals: Vec<Box<dyn ToSql>> = vec![];
for v in val {
if (v.len() % 2 == 0) && is_lower_hex(v) {
if let Ok(h) = hex::decode(v) {
blob_vals.push(Box::new(h));
}
} else {
str_vals.push(Box::new(v.clone()));
}
}
// create clauses with "?" params for each tag value being searched
let str_clause = format!("value IN ({})", repeat_vars(str_vals.len()));
let blob_clause = format!("value_hex IN ({})", repeat_vars(blob_vals.len()));
// find evidence of the target tag name/value existing for this event.
let tag_clause = format!(
"e.id IN (SELECT e.id FROM event e LEFT JOIN tag t on e.id=t.event_id WHERE hidden!=TRUE and (name=? AND ({} OR {})))",
str_clause, blob_clause
);
// add the tag name as the first parameter
params.push(Box::new(key.to_string()));
// add all tag values that are plain strings as params
params.append(&mut str_vals);
// add all tag values that are blobs as params
params.append(&mut blob_vals);
filter_components.push(tag_clause);
}
}
    // Query for timestamp (since)
if f.since.is_some() {
let created_clause = format!("created_at > {}", f.since.unwrap());
filter_components.push(created_clause);
}
    // Query for timestamp (until)
if f.until.is_some() {
let until_clause = format!("created_at < {}", f.until.unwrap());
filter_components.push(until_clause);
}
// never display hidden events
query.push_str(" WHERE hidden!=TRUE");
// build filter component conditions
if !filter_components.is_empty() {
query.push_str(" AND ");
query.push_str(&filter_components.join(" AND "));
}
// Apply per-filter limit to this subquery.
// The use of a LIMIT implies a DESC order, to capture only the most recent events.
if let Some(lim) = f.limit {
let _ = write!(query, " ORDER BY e.created_at DESC LIMIT {}", lim);
} else {
query.push_str(" ORDER BY e.created_at ASC");
}
(query, params, idx_name)
}
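// For illustration, a hypothetical filter {"authors":["<64-char-hex>"],
// "kinds":[1], "limit":10} would produce roughly:
//   SELECT e.content, e.created_at FROM event e INDEXED BY author_kind_index
//   WHERE hidden!=TRUE AND (author=?) AND kind IN (1)
//   ORDER BY e.created_at DESC LIMIT 10
// with one parameter bound to the decoded author key.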
/// Create a dynamic SQL query string and params from a subscription.
fn query_from_sub(sub: &Subscription) -> (String, Vec<Box<dyn ToSql>>, Vec<String>) {
// build a dynamic SQL query for an entire subscription, based on
// SQL subqueries for filters.
let mut subqueries: Vec<String> = Vec::new();
let mut indexes = vec![];
// subquery params
let mut params: Vec<Box<dyn ToSql>> = vec![];
// for every filter in the subscription, generate a subquery
for f in &sub.filters {
let (f_subquery, mut f_params, index) = query_from_filter(f);
if let Some(i) = index {
indexes.push(i);
}
subqueries.push(f_subquery);
params.append(&mut f_params);
}
// encapsulate subqueries into select statements
let subqueries_selects: Vec<String> = subqueries
.iter()
.map(|s| format!("SELECT distinct content, created_at FROM ({})", s))
.collect();
let query: String = subqueries_selects.join(" UNION ");
    (query, params, indexes)
}
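// e.g. a subscription with two filters yields:
//   SELECT distinct content, created_at FROM (<filter 1 subquery>)
//   UNION
//   SELECT distinct content, created_at FROM (<filter 2 subquery>)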
/// Build a database connection pool.
/// # Panics
///
/// Will panic if the pool could not be created.
#[must_use]
pub fn build_pool(
name: &str,
settings: &Settings,
flags: OpenFlags,
min_size: u32,
max_size: u32,
wait_for_db: bool,
) -> SqlitePool {
let db_dir = &settings.database.data_directory;
let full_path = Path::new(db_dir).join(DB_FILE);
    // small hack; if the database doesn't exist yet, the writer
    // thread hasn't finished creating it. Give it a chance to work;
    // this is only an issue the first time we run.
if !settings.database.in_memory {
while !full_path.exists() && wait_for_db {
debug!("Database reader pool is waiting on the database to be created...");
thread::sleep(Duration::from_millis(500));
}
}
let manager = if settings.database.in_memory {
SqliteConnectionManager::memory()
.with_flags(flags)
.with_init(|c| c.execute_batch(STARTUP_SQL))
} else {
SqliteConnectionManager::file(&full_path)
.with_flags(flags)
.with_init(|c| c.execute_batch(STARTUP_SQL))
};
let pool: SqlitePool = r2d2::Pool::builder()
.test_on_check_out(true) // no noticeable performance hit
.min_idle(Some(min_size))
.max_size(max_size)
.max_lifetime(Some(Duration::from_secs(30)))
.build(manager)
.unwrap();
info!(
"Built a connection pool {:?} (min={}, max={})",
name, min_size, max_size
);
pool
}
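// Example usage, mirroring how a read-only pool for client queries could
// be built (a sketch; connection counts come from the settings):
//
// let pool = build_pool(
//     "client query",
//     &settings,
//     rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY,
//     settings.database.min_conn,
//     settings.database.max_conn,
//     true, // block until the writer has created the DB file
// );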
/// Perform database WAL checkpoint on a regular basis
pub async fn db_checkpoint_task(pool: SqlitePool, frequency: Duration, checkpoint_in_progress: Arc<Mutex<u64>>) -> Result<()> {
tokio::task::spawn(async move {
// WAL size in pages.
let mut current_wal_size = 0;
// WAL threshold for more aggressive checkpointing (10,000 pages, or about 40MB)
        let wal_threshold = 1000 * 10;
// default threshold for the busy timer
let busy_wait_default = Duration::from_secs(1);
// if the WAL file is getting too big, switch to this
let busy_wait_default_long = Duration::from_secs(10);
loop {
tokio::select! {
_ = tokio::time::sleep(frequency) => {
if let Ok(mut conn) = pool.get() {
                    let mut _guard: Option<MutexGuard<u64>> = None;
                    // the busy timer blocks writers, so don't set it
                    // higher than the maximum latency you can accept
                    // for event writes.
if current_wal_size <= wal_threshold {
conn.busy_timeout(busy_wait_default).ok();
} else {
// if the wal size has exceeded a threshold, increase the busy timeout.
conn.busy_timeout(busy_wait_default_long).ok();
// take a lock that will prevent new readers.
info!("blocking new readers to perform wal_checkpoint");
_guard = Some(checkpoint_in_progress.lock().await);
}
debug!("running wal_checkpoint(TRUNCATE)");
if let Ok(new_size) = checkpoint_db(&mut conn) {
current_wal_size = new_size;
}
}
}
};
}
});
Ok(())
}
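// Sketch of spawning the checkpoint task at startup (the one-minute
// frequency and the mutex binding are illustrative; the Mutex is tokio's):
//
// let safe_to_read = Arc::new(Mutex::new(0));
// db_checkpoint_task(pool.clone(), Duration::from_secs(60), safe_to_read.clone()).await?;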
#[derive(Debug)]
enum SqliteStatus {
Ok,
Busy,
Error,
Other(u64),
}
/// Checkpoint/Truncate WAL. Returns the number of WAL pages remaining.
pub fn checkpoint_db(conn: &mut PooledConnection) -> Result<usize> {
let query = "PRAGMA wal_checkpoint(TRUNCATE);";
let start = Instant::now();
let (cp_result, wal_size, _frames_checkpointed) = conn.query_row(query, [], |row| {
let checkpoint_result: u64 = row.get(0)?;
let wal_size: u64 = row.get(1)?;
let frames_checkpointed: u64 = row.get(2)?;
Ok((checkpoint_result, wal_size, frames_checkpointed))
})?;
let result = match cp_result {
0 => SqliteStatus::Ok,
1 => SqliteStatus::Busy,
2 => SqliteStatus::Error,
x => SqliteStatus::Other(x),
};
info!(
"checkpoint ran in {:?} (result: {:?}, WAL size: {})",
start.elapsed(),
result,
wal_size
);
Ok(wal_size as usize)
}
/// Produce a comma-separated list of `count` '?' placeholders.
fn repeat_vars(count: usize) -> String {
if count == 0 {
return "".to_owned();
}
let mut s = "?,".repeat(count);
// Remove trailing comma
s.pop();
s
}
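// e.g. repeat_vars(3) == "?,?,?" and repeat_vars(0) == ""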
/// Display database pool stats once per minute
pub async fn monitor_pool(name: &str, pool: SqlitePool) {
let sleep_dur = Duration::from_secs(60);
loop {
log_pool_stats(name, &pool);
tokio::time::sleep(sleep_dur).await;
}
}
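// Sketch: spawn the monitor alongside the server (the task takes
// ownership, so clone the pool first):
//
// let pool_monitor = pool.clone();
// tokio::spawn(async move { monitor_pool("reader", pool_monitor).await; });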
/// Log pool stats
fn log_pool_stats(name: &str, pool: &SqlitePool) {
let state: r2d2::State = pool.state();
let in_use_cxns = state.connections - state.idle_connections;
    debug!(
        "DB pool {:?} usage (in_use: {}, total: {}, max: {})",
        name,
        in_use_cxns,
        state.connections,
        pool.max_size()
    );
}
/// Check if the pool is fully utilized
fn _pool_at_capacity(pool: &SqlitePool) -> bool {
let state: r2d2::State = pool.state();
state.idle_connections == 0
}

View File

@ -113,7 +113,7 @@ pub fn db_tag_count(conn: &mut Connection) -> Result<usize> {
Ok(count)
}
fn mig_init(conn: &mut PooledConnection) -> Result<usize> {
fn mig_init(conn: &mut PooledConnection) -> usize {
match conn.execute_batch(INIT_SQL) {
Ok(()) => {
info!(
@ -126,11 +126,11 @@ fn mig_init(conn: &mut PooledConnection) -> Result<usize> {
panic!("database could not be initialized");
}
}
Ok(DB_VERSION)
DB_VERSION
}
/// Upgrade DB to latest version, and execute pragma settings
pub fn upgrade_db(conn: &mut PooledConnection) -> Result<()> {
pub fn upgrade_db(conn: &mut PooledConnection) -> Result<usize> {
// check the version.
let mut curr_version = curr_db_version(conn)?;
info!("DB version = {:?}", curr_version);
@ -141,11 +141,11 @@ pub fn upgrade_db(conn: &mut PooledConnection) -> Result<()> {
);
debug!(
"SQLite max table/blob/text length: {} MB",
(conn.limit(Limit::SQLITE_LIMIT_LENGTH) as f64 / (1024 * 1024) as f64).floor()
(f64::from(conn.limit(Limit::SQLITE_LIMIT_LENGTH)) / f64::from(1024 * 1024)).floor()
);
debug!(
"SQLite max SQL length: {} MB",
(conn.limit(Limit::SQLITE_LIMIT_SQL_LENGTH) as f64 / (1024 * 1024) as f64).floor()
(f64::from(conn.limit(Limit::SQLITE_LIMIT_SQL_LENGTH)) / f64::from(1024 * 1024)).floor()
);
match curr_version.cmp(&DB_VERSION) {
@ -153,7 +153,7 @@ pub fn upgrade_db(conn: &mut PooledConnection) -> Result<()> {
Ordering::Less => {
// initialize from scratch
if curr_version == 0 {
curr_version = mig_init(conn)?;
curr_version = mig_init(conn);
}
// for initialized but out-of-date schemas, proceed to
// upgrade sequentially until we are current.
@ -223,7 +223,7 @@ pub fn upgrade_db(conn: &mut PooledConnection) -> Result<()> {
// Setup PRAGMA
conn.execute_batch(STARTUP_SQL)?;
debug!("SQLite PRAGMA startup completed");
Ok(())
Ok(DB_VERSION)
}
pub fn rebuild_tags(conn: &mut PooledConnection) -> Result<()> {

View File

@ -3,6 +3,7 @@ use crate::close::Close;
use crate::close::CloseCmd;
use crate::config::{Settings, VerifiedUsersMode};
use crate::conn;
use crate::repo::NostrRepo;
use crate::db;
use crate::db::SubmittedEvent;
use crate::error::{Error, Result};
@ -22,10 +23,8 @@ use hyper::upgrade::Upgraded;
use hyper::{
header, server::conn::AddrStream, upgrade, Body, Request, Response, Server, StatusCode,
};
use rusqlite::OpenFlags;
use serde::{Deserialize, Serialize};
use serde_json::json;
use tokio::sync::Mutex;
use std::collections::HashMap;
use std::convert::Infallible;
use std::net::SocketAddr;
@ -40,23 +39,22 @@ use tokio::sync::broadcast::{self, Receiver, Sender};
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio_tungstenite::WebSocketStream;
use tracing::*;
use tracing::{debug, error, info, trace, warn};
use tungstenite::error::CapacityError::MessageTooLong;
use tungstenite::error::Error as WsError;
use tungstenite::handshake;
use tungstenite::protocol::Message;
use tungstenite::protocol::WebSocketConfig;
/// Handle arbitrary HTTP requests, including for WebSocket upgrades.
/// Handle arbitrary HTTP requests, including for `WebSocket` upgrades.
async fn handle_web_request(
mut request: Request<Body>,
pool: db::SqlitePool,
repo: Arc<dyn NostrRepo>,
settings: Settings,
remote_addr: SocketAddr,
broadcast: Sender<Event>,
event_tx: tokio::sync::mpsc::Sender<SubmittedEvent>,
shutdown: Receiver<()>,
safe_to_read: Arc<Mutex<u64>>,
) -> Result<Response<Body>, Infallible> {
match (
request.uri().path(),
@ -111,14 +109,13 @@ async fn handle_web_request(
};
// spawn a nostr server with our websocket
tokio::spawn(nostr_server(
pool,
repo,
client_info,
settings,
ws_stream,
broadcast,
event_tx,
shutdown,
safe_to_read,
));
}
// todo: trace, don't print...
@ -184,7 +181,7 @@ async fn handle_web_request(
fn get_header_string(header: &str, headers: &HeaderMap) -> Option<String> {
headers
.get(header)
.and_then(|x| x.to_str().ok().map(|x| x.to_string()))
.and_then(|x| x.to_str().ok().map(std::string::ToString::to_string))
}
// return on a control-c or internally requested shutdown signal
@ -211,7 +208,7 @@ async fn ctrl_c_or_signal(mut shutdown_signal: Receiver<()>) {
}
/// Start running a Nostr relay server.
pub fn start_server(settings: Settings, shutdown_rx: MpscReceiver<()>) -> Result<(), Error> {
pub fn start_server(settings: &Settings, shutdown_rx: MpscReceiver<()>) -> Result<(), Error> {
trace!("Config: {:?}", settings);
// do some config validation.
if !Path::new(&settings.database.data_directory).is_dir() {
@ -274,8 +271,6 @@ pub fn start_server(settings: Settings, shutdown_rx: MpscReceiver<()>) -> Result
let broadcast_buffer_limit = settings.limits.broadcast_buffer;
let persist_buffer_limit = settings.limits.event_persist_buffer;
let verified_users_active = settings.verified_users.is_active();
let db_min_conn = settings.database.min_conn;
let db_max_conn = settings.database.max_conn;
let settings = settings.clone();
info!("listening on: {}", socket_addr);
// all client-submitted valid events are broadcast to every
@ -298,23 +293,26 @@ pub fn start_server(settings: Settings, shutdown_rx: MpscReceiver<()>) -> Result
// overwhelming this will drop events and won't register
// metadata events.
let (metadata_tx, metadata_rx) = broadcast::channel::<Event>(4096);
// start the database writer thread. Give it a channel for
// build a repository for events
let repo = db::build_repo(&settings).await;
// start the database writer task. Give it a channel for
// writing events, and for publishing events that have been
// written (to all connected clients).
db::db_writer(
settings.clone(),
event_rx,
bcast_tx.clone(),
metadata_tx.clone(),
shutdown_listen,
)
.await;
tokio::task::spawn(
db::db_writer(
repo.clone(),
settings.clone(),
event_rx,
bcast_tx.clone(),
metadata_tx.clone(),
shutdown_listen,
));
info!("db writer created");
// create a nip-05 verifier thread; if enabled.
if settings.verified_users.mode != VerifiedUsersMode::Disabled {
let verifier_opt =
nip05::Verifier::new(metadata_rx, bcast_tx.clone(), settings.clone());
nip05::Verifier::new(repo.clone(), metadata_rx, bcast_tx.clone(), settings.clone());
if let Ok(mut v) = verifier_opt {
if verified_users_active {
tokio::task::spawn(async move {
@ -324,35 +322,19 @@ pub fn start_server(settings: Settings, shutdown_rx: MpscReceiver<()>) -> Result
}
}
}
// build a connection pool for DB maintenance
let maintenance_pool = db::build_pool(
"maintenance writer",
&settings,
OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE,
1,
2,
false,
);
// Create a mutex that will block readers, so that a
// checkpoint can be performed quickly.
let safe_to_read = Arc::new(Mutex::new(0));
db::db_optimize_task(maintenance_pool.clone()).await;
db::db_checkpoint_task(maintenance_pool, safe_to_read.clone()).await;
// listen for (external to tokio) shutdown request
let controlled_shutdown = invoke_shutdown.clone();
tokio::spawn(async move {
info!("control message listener started");
        // Shutdown messages propagate only from this listener to controlled_shutdown, not the reverse; recv here is a blocking sync receive inside an async task, so it cannot be woken from the async side.
match shutdown_rx.recv() {
Ok(()) => {
info!("control message requesting shutdown");
controlled_shutdown.send(()).ok();
}
},
Err(std::sync::mpsc::RecvError) => {
// FIXME: spurious error on startup?
debug!("shutdown requestor is disconnected");
trace!("shutdown requestor is disconnected (this is normal)");
}
};
});
@ -366,41 +348,30 @@ pub fn start_server(settings: Settings, shutdown_rx: MpscReceiver<()>) -> Result
info!("shutting down due to SIGINT (main)");
ctrl_c_shutdown.send(()).ok();
});
// build a connection pool for sqlite connections
let pool = db::build_pool(
"client query",
&settings,
rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY,
db_min_conn,
db_max_conn,
true,
);
// spawn a task to check the pool size.
let pool_monitor = pool.clone();
tokio::spawn(async move {db::monitor_pool("reader", pool_monitor).await;});
//let pool_monitor = pool.clone();
//tokio::spawn(async move {db::monitor_pool("reader", pool_monitor).await;});
// A `Service` is needed for every connection, so this
// creates one from our `handle_request` function.
let make_svc = make_service_fn(|conn: &AddrStream| {
let svc_pool = pool.clone();
let repo = repo.clone();
let remote_addr = conn.remote_addr();
let bcast = bcast_tx.clone();
let event = event_tx.clone();
let stop = invoke_shutdown.clone();
let settings = settings.clone();
let safe_to_read = safe_to_read.clone();
async move {
// service_fn converts our function into a `Service`
Ok::<_, Infallible>(service_fn(move |request: Request<Body>| {
handle_web_request(
request,
svc_pool.clone(),
repo.clone(),
settings.clone(),
remote_addr,
bcast.clone(),
event.clone(),
stop.subscribe(),
safe_to_read.clone(),
)
}))
}
@ -428,9 +399,9 @@ pub enum NostrMessage {
CloseMsg(CloseCmd),
}
/// Convert Message to NostrMessage
fn convert_to_msg(msg: String, max_bytes: Option<usize>) -> Result<NostrMessage> {
let parsed_res: Result<NostrMessage> = serde_json::from_str(&msg).map_err(|e| e.into());
/// Convert Message to `NostrMessage`
fn convert_to_msg(msg: &str, max_bytes: Option<usize>) -> Result<NostrMessage> {
let parsed_res: Result<NostrMessage> = serde_json::from_str(msg).map_err(std::convert::Into::into);
match parsed_res {
Ok(m) => {
if let NostrMessage::SubMsg(_) = m {
@ -455,8 +426,8 @@ fn convert_to_msg(msg: String, max_bytes: Option<usize>) -> Result<NostrMessage>
}
}
/// Turn a string into a NOTICE message ready to send over a WebSocket
fn make_notice_message(notice: Notice) -> Message {
/// Turn a string into a NOTICE message ready to send over a `WebSocket`
fn make_notice_message(notice: &Notice) -> Message {
let json = match notice {
Notice::Message(ref msg) => json!(["NOTICE", msg]),
Notice::EventResult(ref res) => json!(["OK", res.id, res.status.to_bool(), res.msg]),
@ -474,14 +445,13 @@ struct ClientInfo {
/// Handle new client connections. This runs through an event loop
/// for all client communication.
async fn nostr_server(
pool: db::SqlitePool,
repo: Arc<dyn NostrRepo>,
client_info: ClientInfo,
settings: Settings,
mut ws_stream: WebSocketStream<Upgraded>,
broadcast: Sender<Event>,
event_tx: mpsc::Sender<SubmittedEvent>,
mut shutdown: Receiver<()>,
safe_to_read: Arc<Mutex<u64>>,
) {
// the time this websocket nostr server started
let orig_start = Instant::now();
@ -559,7 +529,7 @@ async fn nostr_server(
ws_stream.send(Message::Ping(Vec::new())).await.ok();
},
Some(notice_msg) = notice_rx.recv() => {
ws_stream.send(make_notice_message(notice_msg)).await.ok();
ws_stream.send(make_notice_message(&notice_msg)).await.ok();
},
Some(query_result) = query_rx.recv() => {
// database informed us of a query result we asked for
@ -603,11 +573,11 @@ async fn nostr_server(
// Consume text messages from the client, parse into Nostr messages.
let nostr_msg = match ws_next {
Some(Ok(Message::Text(m))) => {
convert_to_msg(m,settings.limits.max_event_bytes)
convert_to_msg(&m,settings.limits.max_event_bytes)
},
Some(Ok(Message::Binary(_))) => {
ws_stream.send(
make_notice_message(Notice::message("binary messages are not accepted".into()))).await.ok();
make_notice_message(&Notice::message("binary messages are not accepted".into()))).await.ok();
continue;
},
Some(Ok(Message::Ping(_) | Message::Pong(_))) => {
@ -617,7 +587,7 @@ async fn nostr_server(
},
Some(Err(WsError::Capacity(MessageTooLong{size, max_size}))) => {
ws_stream.send(
make_notice_message(Notice::message(format!("message too large ({} > {})",size, max_size)))).await.ok();
make_notice_message(&Notice::message(format!("message too large ({} > {})",size, max_size)))).await.ok();
continue;
},
None |
@ -662,13 +632,13 @@ async fn nostr_server(
if let Some(fut_sec) = settings.options.reject_future_seconds {
let msg = format!("The event created_at field is out of the acceptable range (+{}sec) for this relay.",fut_sec);
let notice = Notice::invalid(e.id, &msg);
ws_stream.send(make_notice_message(notice)).await.ok();
ws_stream.send(make_notice_message(&notice)).await.ok();
}
}
},
Err(e) => {
info!("client sent an invalid event (cid: {})", cid);
ws_stream.send(make_notice_message(Notice::invalid(evid, &format!("{}", e)))).await.ok();
ws_stream.send(make_notice_message(&Notice::invalid(evid, &format!("{}", e)))).await.ok();
}
}
},
@ -679,31 +649,31 @@ async fn nostr_server(
// * registering the subscription so future events can be matched
// * making a channel to cancel to request later
// * sending a request for a SQL query
// Do nothing if the sub already exists.
if !conn.has_subscription(&s) {
if let Some(ref lim) = sub_lim_opt {
lim.until_ready_with_jitter(jitter).await;
}
// Do nothing if the sub already exists.
if conn.has_subscription(&s) {
info!("client sent duplicate subscription, ignoring (cid: {}, sub: {:?})", cid, s.id);
} else {
if let Some(ref lim) = sub_lim_opt {
lim.until_ready_with_jitter(jitter).await;
}
let (abandon_query_tx, abandon_query_rx) = oneshot::channel::<()>();
match conn.subscribe(s.clone()) {
Ok(()) => {
Ok(()) => {
// when we insert, if there was a previous query running with the same name, cancel it.
if let Some(previous_query) = running_queries.insert(s.id.to_owned(), abandon_query_tx) {
previous_query.send(()).ok();
if let Some(previous_query) = running_queries.insert(s.id.clone(), abandon_query_tx) {
previous_query.send(()).ok();
}
if s.needs_historical_events() {
// start a database query. this spawns a blocking database query on a worker thread.
db::db_query(s, cid.to_owned(), pool.clone(), query_tx.clone(), abandon_query_rx,safe_to_read.clone()).await;
if s.needs_historical_events() {
// start a database query. this spawns a blocking database query on a worker thread.
repo.query_subscription(s, cid.clone(), query_tx.clone(), abandon_query_rx).await.ok();
}
},
Err(e) => {
info!("Subscription error: {} (cid: {}, sub: {:?})", e, cid, s.id);
ws_stream.send(make_notice_message(Notice::message(format!("Subscription error: {}", e)))).await.ok();
}
},
Err(e) => {
info!("Subscription error: {} (cid: {}, sub: {:?})", e, cid, s.id);
ws_stream.send(make_notice_message(&Notice::message(format!("Subscription error: {}", e)))).await.ok();
}
}
} else {
info!("client sent duplicate subscription, ignoring (cid: {}, sub: {:?})", cid, s.id);
}
}
},
Ok(NostrMessage::CloseMsg(cc)) => {
// closing a request simply removes the subscription.
@ -720,7 +690,7 @@ async fn nostr_server(
conn.unsubscribe(&c);
} else {
info!("invalid command ignored");
ws_stream.send(make_notice_message(Notice::message("could not parse command".into()))).await.ok();
ws_stream.send(make_notice_message(&Notice::message("could not parse command".into()))).await.ok();
}
},
Err(Error::ConnError) => {
@ -729,11 +699,11 @@ async fn nostr_server(
}
Err(Error::EventMaxLengthError(s)) => {
info!("client sent event larger ({} bytes) than max size (cid: {})", s, cid);
ws_stream.send(make_notice_message(Notice::message("event exceeded max size".into()))).await.ok();
ws_stream.send(make_notice_message(&Notice::message("event exceeded max size".into()))).await.ok();
},
Err(Error::ProtoParseError) => {
info!("client sent event that could not be parsed (cid: {})", cid);
ws_stream.send(make_notice_message(Notice::message("could not parse command".into()))).await.ok();
ws_stream.send(make_notice_message(&Notice::message("could not parse command".into()))).await.ok();
},
Err(e) => {
info!("got non-fatal error from client (cid: {}, error: {:?}", cid, e);

View File

@ -68,7 +68,7 @@ impl<'de> Deserialize<'de> for ReqFilter {
let empty_string = "".into();
let mut ts = None;
// iterate through each key, and assign values that exist
for (key, val) in filter.into_iter() {
for (key, val) in filter {
// ids
if key == "ids" {
let raw_ids: Option<Vec<String>>= Deserialize::deserialize(val).ok();
@ -107,7 +107,7 @@ impl<'de> Deserialize<'de> for ReqFilter {
if let Some(m) = ts.as_mut() {
let tag_vals: Option<Vec<String>> = Deserialize::deserialize(val).ok();
if let Some(v) = tag_vals {
let hs = HashSet::from_iter(v.into_iter());
let hs = v.into_iter().collect::<HashSet<_>>();
m.insert(tag_search.to_owned(), hs);
}
};
@ -197,20 +197,20 @@ impl<'de> Deserialize<'de> for Subscription {
impl Subscription {
/// Get a copy of the subscription identifier.
pub fn get_id(&self) -> String {
#[must_use] pub fn get_id(&self) -> String {
self.id.clone()
}
/// Determine if any filter is requesting historical (database)
/// queries. If every filter has limit:0, we do not need to query the DB.
pub fn needs_historical_events(&self) -> bool {
#[must_use] pub fn needs_historical_events(&self) -> bool {
self.filters.iter().any(|f| f.limit!=Some(0))
}
/// Determine if this subscription matches a given [`Event`]. Any
/// individual filter match is sufficient.
pub fn interested_in_event(&self, event: &Event) -> bool {
for f in self.filters.iter() {
#[must_use] pub fn interested_in_event(&self, event: &Event) -> bool {
for f in &self.filters {
if f.interested_in_event(event) {
return true;
}
@ -233,23 +233,20 @@ impl ReqFilter {
fn ids_match(&self, event: &Event) -> bool {
self.ids
.as_ref()
.map(|vs| prefix_match(vs, &event.id))
.unwrap_or(true)
.map_or(true, |vs| prefix_match(vs, &event.id))
}
fn authors_match(&self, event: &Event) -> bool {
self.authors
.as_ref()
.map(|vs| prefix_match(vs, &event.pubkey))
.unwrap_or(true)
.map_or(true, |vs| prefix_match(vs, &event.pubkey))
}
fn delegated_authors_match(&self, event: &Event) -> bool {
if let Some(delegated_pubkey) = &event.delegated_by {
self.authors
.as_ref()
.map(|vs| prefix_match(vs, delegated_pubkey))
.unwrap_or(true)
.map_or(true, |vs| prefix_match(vs, delegated_pubkey))
} else {
false
}
@ -275,16 +272,15 @@ impl ReqFilter {
fn kind_match(&self, kind: u64) -> bool {
self.kinds
.as_ref()
.map(|ks| ks.contains(&kind))
.unwrap_or(true)
.map_or(true, |ks| ks.contains(&kind))
}
/// Determine if all populated fields in this filter match the provided event.
pub fn interested_in_event(&self, event: &Event) -> bool {
#[must_use] pub fn interested_in_event(&self, event: &Event) -> bool {
// self.id.as_ref().map(|v| v == &event.id).unwrap_or(true)
self.ids_match(event)
&& self.since.map(|t| event.created_at > t).unwrap_or(true)
&& self.until.map(|t| event.created_at < t).unwrap_or(true)
&& self.since.map_or(true, |t| event.created_at > t)
&& self.until.map_or(true, |t| event.created_at < t)
&& self.kind_match(event.kind)
&& (self.authors_match(event) || self.delegated_authors_match(event))
&& self.tag_match(event)

View File

@ -2,7 +2,7 @@
use std::time::SystemTime;
/// Seconds since 1970.
pub fn unix_time() -> u64 {
#[must_use] pub fn unix_time() -> u64 {
SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.map(|x| x.as_secs())
@ -10,12 +10,12 @@ pub fn unix_time() -> u64 {
}
/// Check if a string contains only hex characters.
pub fn is_hex(s: &str) -> bool {
#[must_use] pub fn is_hex(s: &str) -> bool {
s.chars().all(|x| char::is_ascii_hexdigit(&x))
}
/// Check if a string contains only lower-case hex chars.
pub fn is_lower_hex(s: &str) -> bool {
#[must_use] pub fn is_lower_hex(s: &str) -> bool {
s.chars().all(|x| {
(char::is_ascii_lowercase(&x) || char::is_ascii_digit(&x)) && char::is_ascii_hexdigit(&x)
})