refactor: remove global/singleton settings object

This commit is contained in:
Greg Heartsfield 2022-09-06 06:12:07 -05:00
parent e48bae10e6
commit 2b03f11e5e
7 changed files with 148 additions and 127 deletions

View File

@ -1,16 +1,9 @@
//! Configuration file and settings management //! Configuration file and settings management
use config::{Config, ConfigError, File}; use config::{Config, ConfigError, File};
use lazy_static::lazy_static;
use log::*; use log::*;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::sync::RwLock;
use std::time::Duration; use std::time::Duration;
// initialize a singleton default configuration
lazy_static! {
pub static ref SETTINGS: RwLock<Settings> = RwLock::new(Settings::default());
}
#[derive(Debug, Serialize, Deserialize, Clone)] #[derive(Debug, Serialize, Deserialize, Clone)]
#[allow(unused)] #[allow(unused)]
pub struct Info { pub struct Info {
@ -21,7 +14,7 @@ pub struct Info {
pub contact: Option<String>, pub contact: Option<String>,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(unused)] #[allow(unused)]
pub struct Database { pub struct Database {
pub data_directory: String, pub data_directory: String,
@ -30,7 +23,7 @@ pub struct Database {
pub max_conn: u32, pub max_conn: u32,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(unused)] #[allow(unused)]
pub struct Network { pub struct Network {
pub port: u16, pub port: u16,
@ -38,13 +31,13 @@ pub struct Network {
} }
// //
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(unused)] #[allow(unused)]
pub struct Options { pub struct Options {
pub reject_future_seconds: Option<usize>, // if defined, reject any events with a timestamp more than X seconds in the future pub reject_future_seconds: Option<usize>, // if defined, reject any events with a timestamp more than X seconds in the future
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(unused)] #[allow(unused)]
pub struct Retention { pub struct Retention {
// TODO: implement // TODO: implement
@ -54,7 +47,7 @@ pub struct Retention {
pub whitelist_addresses: Option<Vec<String>>, // whitelisted addresses (never delete) pub whitelist_addresses: Option<Vec<String>>, // whitelisted addresses (never delete)
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(unused)] #[allow(unused)]
pub struct Limits { pub struct Limits {
pub messages_per_sec: Option<u32>, // Artificially slow down event writing to limit disk consumption (averaged over 1 minute) pub messages_per_sec: Option<u32>, // Artificially slow down event writing to limit disk consumption (averaged over 1 minute)
@ -65,7 +58,7 @@ pub struct Limits {
pub event_persist_buffer: usize, // events to buffer for database commits (block senders if database writes are too slow) pub event_persist_buffer: usize, // events to buffer for database commits (block senders if database writes are too slow)
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(unused)] #[allow(unused)]
pub struct Authorization { pub struct Authorization {
pub pubkey_whitelist: Option<Vec<String>>, // If present, only allow these pubkeys to publish events pub pubkey_whitelist: Option<Vec<String>>, // If present, only allow these pubkeys to publish events
@ -79,7 +72,7 @@ pub enum VerifiedUsersMode {
Disabled, Disabled,
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(unused)] #[allow(unused)]
pub struct VerifiedUsers { pub struct VerifiedUsers {
pub mode: VerifiedUsersMode, // Mode of operation: "enabled" (enforce) or "passive" (check only). If none, this is simply disabled. pub mode: VerifiedUsersMode, // Mode of operation: "enabled" (enforce) or "passive" (check only). If none, this is simply disabled.
@ -125,7 +118,7 @@ impl VerifiedUsers {
} }
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(unused)] #[allow(unused)]
pub struct Settings { pub struct Settings {
pub info: Info, pub info: Info,
@ -158,7 +151,7 @@ impl Settings {
// use defaults // use defaults
.add_source(Config::try_from(default)?) .add_source(Config::try_from(default)?)
// override with file contents // override with file contents
.add_source(File::with_name("config")) .add_source(File::with_name("config.toml"))
.build()?; .build()?;
let mut settings: Settings = config.try_deserialize()?; let mut settings: Settings = config.try_deserialize()?;
// ensure connection pool size is logical // ensure connection pool size is logical

View File

@ -1,7 +1,7 @@
//! Event persistence and querying //! Event persistence and querying
use crate::config::SETTINGS; //use crate::config::SETTINGS;
use crate::error::Error; use crate::config::Settings;
use crate::error::Result; use crate::error::{Error, Result};
use crate::event::{single_char_tagname, Event}; use crate::event::{single_char_tagname, Event};
use crate::hexrange::hex_range; use crate::hexrange::hex_range;
use crate::hexrange::HexSearch; use crate::hexrange::HexSearch;
@ -18,7 +18,6 @@ use r2d2;
use r2d2_sqlite::SqliteConnectionManager; use r2d2_sqlite::SqliteConnectionManager;
use rusqlite::params; use rusqlite::params;
use rusqlite::types::ToSql; use rusqlite::types::ToSql;
use rusqlite::Connection;
use rusqlite::OpenFlags; use rusqlite::OpenFlags;
use std::fmt::Write as _; use std::fmt::Write as _;
use std::path::Path; use std::path::Path;
@ -42,13 +41,12 @@ pub const DB_FILE: &str = "nostr.db";
/// Build a database connection pool. /// Build a database connection pool.
pub fn build_pool( pub fn build_pool(
name: &str, name: &str,
settings: Settings,
flags: OpenFlags, flags: OpenFlags,
min_size: u32, min_size: u32,
max_size: u32, max_size: u32,
wait_for_db: bool, wait_for_db: bool,
) -> SqlitePool { ) -> SqlitePool {
let settings = SETTINGS.read().unwrap();
let db_dir = &settings.database.data_directory; let db_dir = &settings.database.data_directory;
let full_path = Path::new(db_dir).join(DB_FILE); let full_path = Path::new(db_dir).join(DB_FILE);
// small hack; if the database doesn't exist yet, that means the // small hack; if the database doesn't exist yet, that means the
@ -81,43 +79,36 @@ pub fn build_pool(
pool pool
} }
/// Build a single database connection, with provided flags
pub fn build_conn(flags: OpenFlags) -> Result<Connection> {
let settings = SETTINGS.read().unwrap();
let db_dir = &settings.database.data_directory;
let full_path = Path::new(db_dir).join(DB_FILE);
// create a connection
Ok(Connection::open_with_flags(&full_path, flags)?)
}
/// Spawn a database writer that persists events to the SQLite store. /// Spawn a database writer that persists events to the SQLite store.
pub async fn db_writer( pub async fn db_writer(
settings: Settings,
mut event_rx: tokio::sync::mpsc::Receiver<SubmittedEvent>, mut event_rx: tokio::sync::mpsc::Receiver<SubmittedEvent>,
bcast_tx: tokio::sync::broadcast::Sender<Event>, bcast_tx: tokio::sync::broadcast::Sender<Event>,
metadata_tx: tokio::sync::broadcast::Sender<Event>, metadata_tx: tokio::sync::broadcast::Sender<Event>,
mut shutdown: tokio::sync::broadcast::Receiver<()>, mut shutdown: tokio::sync::broadcast::Receiver<()>,
) -> tokio::task::JoinHandle<Result<()>> { ) -> tokio::task::JoinHandle<Result<()>> {
let settings = SETTINGS.read().unwrap();
// are we performing NIP-05 checking? // are we performing NIP-05 checking?
let nip05_active = settings.verified_users.is_active(); let nip05_active = settings.verified_users.is_active();
// are we requriing NIP-05 user verification? // are we requriing NIP-05 user verification?
let nip05_enabled = settings.verified_users.is_enabled(); let nip05_enabled = settings.verified_users.is_enabled();
task::spawn_blocking(move || { task::spawn_blocking(move || {
// get database configuration settings
let settings = SETTINGS.read().unwrap();
let db_dir = &settings.database.data_directory; let db_dir = &settings.database.data_directory;
let full_path = Path::new(db_dir).join(DB_FILE); let full_path = Path::new(db_dir).join(DB_FILE);
// create a connection pool // create a connection pool
let pool = build_pool( let pool = build_pool(
"event writer", "event writer",
settings.clone(),
OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE, OpenFlags::SQLITE_OPEN_READ_WRITE | OpenFlags::SQLITE_OPEN_CREATE,
1, 1,
4, 4,
false, false,
); );
if settings.database.in_memory {
info!("using in-memory database, this will not persist a restart!");
} else {
info!("opened database {:?} for writing", full_path); info!("opened database {:?} for writing", full_path);
}
upgrade_db(&mut pool.get()?)?; upgrade_db(&mut pool.get()?)?;
// Make a copy of the whitelist // Make a copy of the whitelist
@ -178,7 +169,7 @@ pub async fn db_writer(
if nip05_enabled { if nip05_enabled {
match nip05::query_latest_user_verification(pool.get()?, event.pubkey.to_owned()) { match nip05::query_latest_user_verification(pool.get()?, event.pubkey.to_owned()) {
Ok(uv) => { Ok(uv) => {
if uv.is_valid() { if uv.is_valid(&settings.verified_users) {
info!( info!(
"new event from verified author ({:?},{:?})", "new event from verified author ({:?},{:?})",
uv.name.to_string(), uv.name.to_string(),

View File

@ -1,5 +1,4 @@
//! Event parsing and validation //! Event parsing and validation
use crate::config;
use crate::error::Error::*; use crate::error::Error::*;
use crate::error::Result; use crate::error::Result;
use crate::nip05; use crate::nip05;
@ -156,13 +155,8 @@ impl Event {
.collect() .collect()
} }
/// Check if this event has a valid signature. pub fn is_valid_timestamp(&self, reject_future_seconds: Option<usize>) -> bool {
fn is_valid(&self) -> bool { if let Some(allowable_future) = reject_future_seconds {
// TODO: return a Result with a reason for invalid events
// don't bother to validate an event with a timestamp in the distant future.
let config = config::SETTINGS.read().unwrap();
let max_future_sec = config.options.reject_future_seconds;
if let Some(allowable_future) = max_future_sec {
let curr_time = unix_time(); let curr_time = unix_time();
// calculate difference, plus how far future we allow // calculate difference, plus how far future we allow
if curr_time + (allowable_future as u64) < self.created_at { if curr_time + (allowable_future as u64) < self.created_at {
@ -174,6 +168,12 @@ impl Event {
return false; return false;
} }
} }
true
}
/// Check if this event has a valid signature.
fn is_valid(&self) -> bool {
// TODO: return a Result with a reason for invalid events
// validation is performed by: // validation is performed by:
// * parsing JSON string into event fields // * parsing JSON string into event fields
// * create an array: // * create an array:
@ -194,7 +194,6 @@ impl Event {
return false; return false;
} }
// * validate the message digest (sig) using the pubkey & computed sha256 message hash. // * validate the message digest (sig) using the pubkey & computed sha256 message hash.
let sig = schnorr::Signature::from_str(&self.sig).unwrap(); let sig = schnorr::Signature::from_str(&self.sig).unwrap();
if let Ok(msg) = secp256k1::Message::from_slice(digest.as_ref()) { if let Ok(msg) = secp256k1::Message::from_slice(digest.as_ref()) {
if let Ok(pubkey) = XOnlyPublicKey::from_str(&self.pubkey) { if let Ok(pubkey) = XOnlyPublicKey::from_str(&self.pubkey) {

View File

@ -8,6 +8,7 @@ pub mod hexrange;
pub mod info; pub mod info;
pub mod nip05; pub mod nip05;
pub mod schema; pub mod schema;
pub mod server;
pub mod subscription; pub mod subscription;
pub mod utils; pub mod utils;
// Public API for creating relays programatically
pub mod server;

View File

@ -4,6 +4,8 @@ use nostr_rs_relay::config;
use nostr_rs_relay::error::{Error, Result}; use nostr_rs_relay::error::{Error, Result};
use nostr_rs_relay::server::start_server; use nostr_rs_relay::server::start_server;
use std::env; use std::env;
use std::sync::mpsc as syncmpsc;
use std::sync::mpsc::{Receiver as MpscReceiver, Sender as MpscSender};
use std::thread; use std::thread;
/// Return a requested DB name from command line arguments. /// Return a requested DB name from command line arguments.
@ -19,22 +21,23 @@ fn main() -> Result<(), Error> {
// setup logger // setup logger
let _ = env_logger::try_init(); let _ = env_logger::try_init();
info!("Starting up from main"); info!("Starting up from main");
// get database directory from args // get database directory from args
let args: Vec<String> = env::args().collect(); let args: Vec<String> = env::args().collect();
let db_dir: Option<String> = db_from_args(args); let db_dir: Option<String> = db_from_args(args);
{ // configure settings from config.toml
let mut settings = config::SETTINGS.write().unwrap();
// replace default settings with those read from config.toml // replace default settings with those read from config.toml
let mut c = config::Settings::new(); let mut settings = config::Settings::new();
// update with database location // update with database location
if let Some(db) = db_dir { if let Some(db) = db_dir {
c.database.data_directory = db; settings.database.data_directory = db;
}
*settings = c;
} }
let (_, ctrl_rx): (MpscSender<()>, MpscReceiver<()>) = syncmpsc::channel();
// run this in a new thread // run this in a new thread
let handle = thread::spawn(|| { let handle = thread::spawn(|| {
let _ = start_server(); // we should have a 'control plane' channel to monitor and bump the server.
// this will let us do stuff like clear the database, shutdown, etc.
let _ = start_server(settings, ctrl_rx);
}); });
// block on nostr thread to finish. // block on nostr thread to finish.
handle.join().unwrap(); handle.join().unwrap();

View File

@ -4,7 +4,7 @@
//! address with their public key, in metadata events. This module //! address with their public key, in metadata events. This module
//! consumes a stream of metadata events, and keeps a database table //! consumes a stream of metadata events, and keeps a database table
//! updated with the current NIP-05 verification status. //! updated with the current NIP-05 verification status.
use crate::config::SETTINGS; use crate::config::VerifiedUsers;
use crate::db; use crate::db;
use crate::error::{Error, Result}; use crate::error::{Error, Result};
use crate::event::Event; use crate::event::Event;
@ -31,6 +31,8 @@ pub struct Verifier {
read_pool: db::SqlitePool, read_pool: db::SqlitePool,
/// SQLite write query pool /// SQLite write query pool
write_pool: db::SqlitePool, write_pool: db::SqlitePool,
/// Settings
settings: crate::config::Settings,
/// HTTP client /// HTTP client
client: hyper::Client<HttpsConnector<HttpConnector>, hyper::Body>, client: hyper::Client<HttpsConnector<HttpConnector>, hyper::Body>,
/// After all accounts are updated, wait this long before checking again. /// After all accounts are updated, wait this long before checking again.
@ -138,11 +140,13 @@ impl Verifier {
pub fn new( pub fn new(
metadata_rx: tokio::sync::broadcast::Receiver<Event>, metadata_rx: tokio::sync::broadcast::Receiver<Event>,
event_tx: tokio::sync::broadcast::Sender<Event>, event_tx: tokio::sync::broadcast::Sender<Event>,
settings: crate::config::Settings,
) -> Result<Self> { ) -> Result<Self> {
info!("creating NIP-05 verifier"); info!("creating NIP-05 verifier");
// build a database connection for reading and writing. // build a database connection for reading and writing.
let write_pool = db::build_pool( let write_pool = db::build_pool(
"nip05 writer", "nip05 writer",
settings.clone(),
rusqlite::OpenFlags::SQLITE_OPEN_READ_WRITE, rusqlite::OpenFlags::SQLITE_OPEN_READ_WRITE,
1, // min conns 1, // min conns
4, // max conns 4, // max conns
@ -150,6 +154,7 @@ impl Verifier {
); );
let read_pool = db::build_pool( let read_pool = db::build_pool(
"nip05 reader", "nip05 reader",
settings.clone(),
rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY, rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY,
1, // min conns 1, // min conns
8, // max conns 8, // max conns
@ -174,6 +179,7 @@ impl Verifier {
event_tx, event_tx,
read_pool, read_pool,
write_pool, write_pool,
settings,
client, client,
wait_after_finish, wait_after_finish,
http_wait_duration, http_wait_duration,
@ -214,7 +220,11 @@ impl Verifier {
pubkey: &str, pubkey: &str,
) -> Result<UserWebVerificationStatus> { ) -> Result<UserWebVerificationStatus> {
// determine if this domain should be checked // determine if this domain should be checked
if !is_domain_allowed(&nip.domain) { if !is_domain_allowed(
&nip.domain,
&self.settings.verified_users.domain_whitelist,
&self.settings.verified_users.domain_blacklist,
) {
return Ok(UserWebVerificationStatus::DomainNotAllowed); return Ok(UserWebVerificationStatus::DomainNotAllowed);
} }
let url = nip let url = nip
@ -347,15 +357,11 @@ impl Verifier {
/// Reverify the oldest user verification record. /// Reverify the oldest user verification record.
async fn do_reverify(&mut self) -> Result<()> { async fn do_reverify(&mut self) -> Result<()> {
let reverify_setting; let reverify_setting = self
let max_failures; .settings
{ .verified_users
// this block prevents a read handle to settings being .verify_update_frequency_duration;
// captured by the async DB call (guard is not Send) let max_failures = self.settings.verified_users.max_consecutive_failures;
let settings = SETTINGS.read().unwrap();
reverify_setting = settings.verified_users.verify_update_frequency_duration;
max_failures = settings.verified_users.max_consecutive_failures;
}
// get from settings, but default to 6hrs between re-checking an account // get from settings, but default to 6hrs between re-checking an account
let reverify_dur = reverify_setting.unwrap_or_else(|| Duration::from_secs(60 * 60 * 6)); let reverify_dur = reverify_setting.unwrap_or_else(|| Duration::from_secs(60 * 60 * 6));
// find all verification records that have success or failure OLDER than the reverify_dur. // find all verification records that have success or failure OLDER than the reverify_dur.
@ -506,11 +512,7 @@ impl Verifier {
let start = Instant::now(); let start = Instant::now();
// we should only do this if we are enabled. if we are // we should only do this if we are enabled. if we are
// disabled/passive, the event has already been persisted. // disabled/passive, the event has already been persisted.
let should_write_event; let should_write_event = self.settings.verified_users.is_enabled();
{
let settings = SETTINGS.read().unwrap();
should_write_event = settings.verified_users.is_enabled()
}
if should_write_event { if should_write_event {
match db::write_event(&mut self.write_pool.get()?, event) { match db::write_event(&mut self.write_pool.get()?, event) {
Ok(updated) => { Ok(updated) => {
@ -562,15 +564,18 @@ pub struct VerificationRecord {
/// Check with settings to determine if a given domain is allowed to /// Check with settings to determine if a given domain is allowed to
/// publish. /// publish.
pub fn is_domain_allowed(domain: &str) -> bool { pub fn is_domain_allowed(
let settings = SETTINGS.read().unwrap(); domain: &str,
whitelist: &Option<Vec<String>>,
blacklist: &Option<Vec<String>>,
) -> bool {
// if there is a whitelist, domain must be present in it. // if there is a whitelist, domain must be present in it.
if let Some(wl) = &settings.verified_users.domain_whitelist { if let Some(wl) = whitelist {
// workaround for Vec contains not accepting &str // workaround for Vec contains not accepting &str
return wl.iter().any(|x| x == domain); return wl.iter().any(|x| x == domain);
} }
// otherwise, check that user is not in the blacklist // otherwise, check that user is not in the blacklist
if let Some(bl) = &settings.verified_users.domain_blacklist { if let Some(bl) = blacklist {
return !bl.iter().any(|x| x == domain); return !bl.iter().any(|x| x == domain);
} }
true true
@ -579,17 +584,21 @@ pub fn is_domain_allowed(domain: &str) -> bool {
impl VerificationRecord { impl VerificationRecord {
/// Check if the record is recent enough to be considered valid, /// Check if the record is recent enough to be considered valid,
/// and the domain is allowed. /// and the domain is allowed.
pub fn is_valid(&self) -> bool { pub fn is_valid(&self, verified_users_settings: &VerifiedUsers) -> bool {
let settings = SETTINGS.read().unwrap(); //let settings = SETTINGS.read().unwrap();
// how long a verification record is good for // how long a verification record is good for
let nip05_expiration = &settings.verified_users.verify_expiration_duration; let nip05_expiration = &verified_users_settings.verify_expiration_duration;
if let Some(e) = nip05_expiration { if let Some(e) = nip05_expiration {
if !self.is_current(e) { if !self.is_current(e) {
return false; return false;
} }
} }
// check domains // check domains
is_domain_allowed(&self.name.domain) is_domain_allowed(
&self.name.domain,
&verified_users_settings.domain_whitelist,
&verified_users_settings.domain_blacklist,
)
} }
/// Check if this record has been validated since the given /// Check if this record has been validated since the given

View File

@ -1,7 +1,7 @@
//! Server process //! Server process
use crate::close::Close; use crate::close::Close;
use crate::close::CloseCmd; use crate::close::CloseCmd;
use crate::config; use crate::config::Settings;
use crate::conn; use crate::conn;
use crate::db; use crate::db;
use crate::db::SubmittedEvent; use crate::db::SubmittedEvent;
@ -26,6 +26,7 @@ use std::collections::HashMap;
use std::convert::Infallible; use std::convert::Infallible;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::path::Path; use std::path::Path;
use std::sync::mpsc::Receiver as MpscReceiver;
use std::time::Duration; use std::time::Duration;
use std::time::Instant; use std::time::Instant;
use tokio::runtime::Builder; use tokio::runtime::Builder;
@ -43,6 +44,7 @@ use tungstenite::protocol::WebSocketConfig;
async fn handle_web_request( async fn handle_web_request(
mut request: Request<Body>, mut request: Request<Body>,
pool: db::SqlitePool, pool: db::SqlitePool,
settings: Settings,
remote_addr: SocketAddr, remote_addr: SocketAddr,
broadcast: Sender<Event>, broadcast: Sender<Event>,
event_tx: tokio::sync::mpsc::Sender<SubmittedEvent>, event_tx: tokio::sync::mpsc::Sender<SubmittedEvent>,
@ -68,12 +70,11 @@ async fn handle_web_request(
//if successfully upgraded //if successfully upgraded
Ok(upgraded) => { Ok(upgraded) => {
// set WebSocket configuration options // set WebSocket configuration options
let mut config = WebSocketConfig::default(); let config = WebSocketConfig {
{ max_message_size: settings.limits.max_ws_message_bytes,
let settings = config::SETTINGS.read().unwrap(); max_frame_size: settings.limits.max_ws_frame_bytes,
config.max_message_size = settings.limits.max_ws_message_bytes; ..Default::default()
config.max_frame_size = settings.limits.max_ws_frame_bytes; };
}
//create a websocket stream from the upgraded object //create a websocket stream from the upgraded object
let ws_stream = WebSocketStream::from_raw_socket( let ws_stream = WebSocketStream::from_raw_socket(
//pass the upgraded object //pass the upgraded object
@ -85,7 +86,7 @@ async fn handle_web_request(
.await; .await;
tokio::spawn(nostr_server( tokio::spawn(nostr_server(
pool, ws_stream, broadcast, event_tx, shutdown, pool, settings, ws_stream, broadcast, event_tx, shutdown,
)); ));
} }
Err(e) => println!( Err(e) => println!(
@ -118,10 +119,9 @@ async fn handle_web_request(
if let Some(media_types) = accept_header { if let Some(media_types) = accept_header {
if let Ok(mt_str) = media_types.to_str() { if let Ok(mt_str) = media_types.to_str() {
if mt_str.contains("application/nostr+json") { if mt_str.contains("application/nostr+json") {
let config = config::SETTINGS.read().unwrap();
// build a relay info response // build a relay info response
debug!("Responding to server info request"); debug!("Responding to server info request");
let rinfo = RelayInfo::from(config.info.clone()); let rinfo = RelayInfo::from(settings.info);
let b = Body::from(serde_json::to_string_pretty(&rinfo).unwrap()); let b = Body::from(serde_json::to_string_pretty(&rinfo).unwrap());
return Ok(Response::builder() return Ok(Response::builder()
.status(200) .status(200)
@ -148,16 +148,25 @@ async fn handle_web_request(
} }
} }
async fn shutdown_signal() { // return on a control-c or internally requested shutdown signal
// Wait for the CTRL+C signal async fn ctrl_c_or_signal(mut shutdown_signal: Receiver<()>) {
tokio::signal::ctrl_c() loop {
.await tokio::select! {
.expect("failed to install CTRL+C signal handler"); _ = shutdown_signal.recv() => {
info!("Shutting down webserver as requested");
// server shutting down, exit loop
break;
},
_ = tokio::signal::ctrl_c() => {
info!("Shutting down webserver due to SIGINT");
break;
}
}
}
} }
/// Start running a Nostr relay server. /// Start running a Nostr relay server.
pub fn start_server() -> Result<(), Error> { pub fn start_server(settings: Settings, shutdown_rx: MpscReceiver<()>) -> Result<(), Error> {
let settings = config::SETTINGS.read().unwrap();
trace!("Config: {:?}", settings); trace!("Config: {:?}", settings);
// do some config validation. // do some config validation.
if !Path::new(&settings.database.data_directory).is_dir() { if !Path::new(&settings.database.data_directory).is_dir() {
@ -204,21 +213,12 @@ pub fn start_server() -> Result<(), Error> {
.unwrap(); .unwrap();
// start tokio // start tokio
rt.block_on(async { rt.block_on(async {
let broadcast_buffer_limit; let broadcast_buffer_limit = settings.limits.broadcast_buffer;
let persist_buffer_limit; let persist_buffer_limit = settings.limits.event_persist_buffer;
let verified_users_active; let verified_users_active = settings.verified_users.is_active();
let db_min_conn; let db_min_conn = settings.database.min_conn;
let db_max_conn; let db_max_conn = settings.database.max_conn;
// hack to prove we drop the mutexguard prior to any await points let settings = settings.clone();
// (https://github.com/rust-lang/rust-clippy/issues/6446)
{
let settings = config::SETTINGS.read().unwrap();
broadcast_buffer_limit = settings.limits.broadcast_buffer;
persist_buffer_limit = settings.limits.event_persist_buffer;
verified_users_active = settings.verified_users.is_active();
db_min_conn = settings.database.min_conn;
db_max_conn = settings.database.max_conn;
}
info!("listening on: {}", socket_addr); info!("listening on: {}", socket_addr);
// all client-submitted valid events are broadcast to every // all client-submitted valid events are broadcast to every
// other client on this channel. This should be large enough // other client on this channel. This should be large enough
@ -244,6 +244,7 @@ pub fn start_server() -> Result<(), Error> {
// writing events, and for publishing events that have been // writing events, and for publishing events that have been
// written (to all connected clients). // written (to all connected clients).
db::db_writer( db::db_writer(
settings.clone(),
event_rx, event_rx,
bcast_tx.clone(), bcast_tx.clone(),
metadata_tx.clone(), metadata_tx.clone(),
@ -253,7 +254,7 @@ pub fn start_server() -> Result<(), Error> {
info!("db writer created"); info!("db writer created");
// create a nip-05 verifier thread // create a nip-05 verifier thread
let verifier_opt = nip05::Verifier::new(metadata_rx, bcast_tx.clone()); let verifier_opt = nip05::Verifier::new(metadata_rx, bcast_tx.clone(), settings.clone());
if let Ok(mut v) = verifier_opt { if let Ok(mut v) = verifier_opt {
if verified_users_active { if verified_users_active {
tokio::task::spawn(async move { tokio::task::spawn(async move {
@ -262,16 +263,31 @@ pub fn start_server() -> Result<(), Error> {
}); });
} }
} }
// // listen for ctrl-c interruupts // listen for (external to tokio) shutdown request
let controlled_shutdown = invoke_shutdown.clone();
tokio::spawn(async move {
info!("control message listener started");
match shutdown_rx.recv() {
Ok(()) => {
info!("control message requesting shutdown");
controlled_shutdown.send(()).ok();
}
Err(std::sync::mpsc::RecvError) => {
debug!("shutdown requestor is disconnected");
}
};
});
// listen for ctrl-c interruupts
let ctrl_c_shutdown = invoke_shutdown.clone(); let ctrl_c_shutdown = invoke_shutdown.clone();
tokio::spawn(async move { tokio::spawn(async move {
tokio::signal::ctrl_c().await.unwrap(); tokio::signal::ctrl_c().await.unwrap();
info!("shutting down due to SIGINT"); info!("shutting down due to SIGINT (main)");
ctrl_c_shutdown.send(()).ok(); ctrl_c_shutdown.send(()).ok();
}); });
// build a connection pool for sqlite connections // build a connection pool for sqlite connections
let pool = db::build_pool( let pool = db::build_pool(
"client query", "client query",
settings.clone(),
rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY rusqlite::OpenFlags::SQLITE_OPEN_READ_ONLY
| rusqlite::OpenFlags::SQLITE_OPEN_SHARED_CACHE, | rusqlite::OpenFlags::SQLITE_OPEN_SHARED_CACHE,
db_min_conn, db_min_conn,
@ -286,12 +302,14 @@ pub fn start_server() -> Result<(), Error> {
let bcast = bcast_tx.clone(); let bcast = bcast_tx.clone();
let event = event_tx.clone(); let event = event_tx.clone();
let stop = invoke_shutdown.clone(); let stop = invoke_shutdown.clone();
let settings = settings.clone();
async move { async move {
// service_fn converts our function into a `Service` // service_fn converts our function into a `Service`
Ok::<_, Infallible>(service_fn(move |request: Request<Body>| { Ok::<_, Infallible>(service_fn(move |request: Request<Body>| {
handle_web_request( handle_web_request(
request, request,
svc_pool.clone(), svc_pool.clone(),
settings.clone(),
remote_addr, remote_addr,
bcast.clone(), bcast.clone(),
event.clone(), event.clone(),
@ -300,14 +318,14 @@ pub fn start_server() -> Result<(), Error> {
})) }))
} }
}); });
let shutdown_listen = invoke_shutdown.subscribe();
let server = Server::bind(&socket_addr) let server = Server::bind(&socket_addr)
.serve(make_svc) .serve(make_svc)
.with_graceful_shutdown(shutdown_signal()); .with_graceful_shutdown(ctrl_c_or_signal(shutdown_listen));
// run hyper // run hyper
if let Err(e) = server.await { if let Err(e) = server.await {
eprintln!("server error: {}", e); eprintln!("server error: {}", e);
} }
// our code
}); });
Ok(()) Ok(())
} }
@ -325,13 +343,12 @@ pub enum NostrMessage {
} }
/// Convert Message to NostrMessage /// Convert Message to NostrMessage
fn convert_to_msg(msg: String) -> Result<NostrMessage> { fn convert_to_msg(msg: String, max_bytes: Option<usize>) -> Result<NostrMessage> {
let config = config::SETTINGS.read().unwrap();
let parsed_res: Result<NostrMessage> = serde_json::from_str(&msg).map_err(|e| e.into()); let parsed_res: Result<NostrMessage> = serde_json::from_str(&msg).map_err(|e| e.into());
match parsed_res { match parsed_res {
Ok(m) => { Ok(m) => {
if let NostrMessage::EventMsg(_) = m { if let NostrMessage::EventMsg(_) = m {
if let Some(max_size) = config.limits.max_event_bytes { if let Some(max_size) = max_bytes {
// check length, ensure that some max size is set. // check length, ensure that some max size is set.
if msg.len() > max_size && max_size > 0 { if msg.len() > max_size && max_size > 0 {
return Err(Error::EventMaxLengthError(msg.len())); return Err(Error::EventMaxLengthError(msg.len()));
@ -357,6 +374,7 @@ fn make_notice_message(msg: &str) -> Message {
/// for all client communication. /// for all client communication.
async fn nostr_server( async fn nostr_server(
pool: db::SqlitePool, pool: db::SqlitePool,
settings: Settings,
mut ws_stream: WebSocketStream<Upgraded>, mut ws_stream: WebSocketStream<Upgraded>,
broadcast: Sender<Event>, broadcast: Sender<Event>,
event_tx: mpsc::Sender<SubmittedEvent>, event_tx: mpsc::Sender<SubmittedEvent>,
@ -398,6 +416,7 @@ async fn nostr_server(
loop { loop {
tokio::select! { tokio::select! {
_ = shutdown.recv() => { _ = shutdown.recv() => {
info!("Shutting client connection down due to shutdown: {:?}", cid);
// server shutting down, exit loop // server shutting down, exit loop
break; break;
}, },
@ -442,7 +461,6 @@ async fn nostr_server(
// create an event response and send it // create an event response and send it
let subesc = s.replace('"', ""); let subesc = s.replace('"', "");
ws_stream.send(Message::Text(format!("[\"EVENT\",\"{}\",{}]", subesc, event_str))).await.ok(); ws_stream.send(Message::Text(format!("[\"EVENT\",\"{}\",{}]", subesc, event_str))).await.ok();
//nostr_stream.send(res).await.ok();
} else { } else {
warn!("could not serialize event {:?}", global_event.get_event_id_prefix()); warn!("could not serialize event {:?}", global_event.get_event_id_prefix());
} }
@ -454,7 +472,7 @@ async fn nostr_server(
// Consume text messages from the client, parse into Nostr messages. // Consume text messages from the client, parse into Nostr messages.
let nostr_msg = match ws_next { let nostr_msg = match ws_next {
Some(Ok(Message::Text(m))) => { Some(Ok(Message::Text(m))) => {
convert_to_msg(m) convert_to_msg(m,settings.limits.max_event_bytes)
}, },
Some(Ok(Message::Binary(_))) => { Some(Ok(Message::Binary(_))) => {
ws_stream.send( ws_stream.send(
@ -503,10 +521,17 @@ async fn nostr_server(
Ok(e) => { Ok(e) => {
let id_prefix:String = e.id.chars().take(8).collect(); let id_prefix:String = e.id.chars().take(8).collect();
debug!("successfully parsed/validated event: {:?} from client: {:?}", id_prefix, cid); debug!("successfully parsed/validated event: {:?} from client: {:?}", id_prefix, cid);
// check if the event is too far in the future.
if e.is_valid_timestamp(settings.options.reject_future_seconds) {
// Write this to the database. // Write this to the database.
let submit_event = SubmittedEvent { event: e.clone(), notice_tx: notice_tx.clone() }; let submit_event = SubmittedEvent { event: e.clone(), notice_tx: notice_tx.clone() };
event_tx.send(submit_event).await.ok(); event_tx.send(submit_event).await.ok();
client_published_event_count += 1; client_published_event_count += 1;
} else {
info!("client {:?} sent a far future-dated event", cid);
ws_stream.send(make_notice_message("event was too far in the future")).await.ok();
}
}, },
Err(_) => { Err(_) => {
info!("client {:?} sent an invalid event", cid); info!("client {:?} sent an invalid event", cid);