diff --git a/Cargo.toml b/Cargo.toml index 401165f..a2a9868 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ repository = "https://git.sr.ht/~gheartsfield/nostr-rs-relay" license = "MIT" keywords = ["nostr", "server"] categories = ["network-programming", "web-programming"] +default-run = "nostr-rs-relay" [dependencies] clap = { version = "4.0.32", features = ["env", "default", "derive"]} diff --git a/src/event.rs b/src/event.rs index 5fb3542..076d8e6 100644 --- a/src/event.rs +++ b/src/event.rs @@ -19,6 +19,7 @@ use std::collections::HashMap; use std::collections::HashSet; use std::str::FromStr; use tracing::{debug, info}; +use crate::subscription::TagOperand; lazy_static! { /// Secp256k1 verification instance. @@ -28,7 +29,8 @@ lazy_static! { /// Event command in network format. #[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)] pub struct EventCmd { - cmd: String, // expecting static "EVENT" + // expecting static "EVENT" + cmd: String, event: Event, } @@ -63,8 +65,8 @@ type Tag = Vec>; /// Deserializer that ensures we always have a [`Tag`]. fn tag_from_string<'de, D>(deserializer: D) -> Result -where - D: Deserializer<'de>, + where + D: Deserializer<'de>, { let opt = Option::deserialize(deserializer)?; Ok(opt.unwrap_or_default()) @@ -409,13 +411,15 @@ impl Event { /// Determine if the given tag and value set intersect with tags in this event. #[must_use] - pub fn generic_tag_val_intersect(&self, tagname: char, check: &HashSet) -> bool { + pub fn generic_tag_val_intersect(&self, tagname: char, check: &TagOperand) -> bool { match &self.tagidx { // check if this is indexable tagname Some(idx) => match idx.get(&tagname) { Some(valset) => { - let common = valset.intersection(check); - common.count() > 0 + match &check { + TagOperand::And(v) => valset.intersection(v).count() == v.len(), + TagOperand::Or(v) => valset.intersection(v).count() > 0 + } } None => false, }, @@ -464,7 +468,7 @@ mod tests { fn empty_event_tag_match() { let event = Event::simple_event(); assert!(!event - .generic_tag_val_intersect('e', &HashSet::from(["foo".to_owned(), "bar".to_owned()]))); + .generic_tag_val_intersect('e', &TagOperand::Or(HashSet::from(["foo".to_owned(), "bar".to_owned()])))); } #[test] @@ -475,7 +479,7 @@ mod tests { assert!( event.generic_tag_val_intersect( 'e', - &HashSet::from(["foo".to_owned(), "bar".to_owned()]) + &TagOperand::Or(HashSet::from(["foo".to_owned(), "bar".to_owned()])), ) ); } diff --git a/src/repo/postgres.rs b/src/repo/postgres.rs index 75ba1a8..dd90ce1 100644 --- a/src/repo/postgres.rs +++ b/src/repo/postgres.rs @@ -4,7 +4,7 @@ use crate::event::{single_char_tagname, Event}; use crate::nip05::{Nip05Name, VerificationRecord}; use crate::payment::{InvoiceInfo, InvoiceStatus}; use crate::repo::{now_jitter, NostrRepo}; -use crate::subscription::{ReqFilter, Subscription}; +use crate::subscription::{ReqFilter, Subscription, TagOperand}; use async_std::stream::StreamExt; use async_trait::async_trait; use chrono::{DateTime, TimeZone, Utc}; @@ -725,10 +725,68 @@ fn query_from_filter(f: &ReqFilter) -> Option> { return None; } - let mut query = QueryBuilder::new("SELECT e.\"content\", e.created_at FROM \"event\" e WHERE "); + let mut query = QueryBuilder::new("SELECT e.\"content\", e.created_at FROM \"event\" e"); // This tracks whether we need to push a prefix AND before adding another clause let mut push_and = false; + + // Query for tags + if let Some(map) = &f.tags { + if !map.is_empty() { + let mut tag_ctr = 1; + for (key, val) in map.iter() { + let has_plain_values = val.into_iter().any(|v| !is_lower_hex(v)); + let has_hex_values = val.into_iter().any(|v| v.len() % 2 == 0 && is_lower_hex(v)); + + if let TagOperand::Or(v_or) = val { + query.push(format!(" JOIN tag t{0} on e.id = t{0}.event_id AND t{0}.\"name\" = ", tag_ctr)) + .push_bind(key.to_string()) + .push(" AND ("); + + if has_plain_values { + query.push(format!("t{0}.\"value\" in (", tag_ctr)); + let mut tag_query = query.separated(", "); + for v in v_or.iter() + .filter(|v| !is_lower_hex(v)) { + tag_query.push_bind(v.as_bytes()); + } + } + if has_plain_values && has_hex_values { + query.push(") OR "); + } + if has_hex_values { + query.push(format!("t{0}.\"value_hex\" in (", tag_ctr)); + let mut tag_query = query.separated(", "); + for v in v_or.iter() + .filter(|v| v.len() % 2 == 0 && is_lower_hex(v)) { + tag_query.push_bind(hex::decode(v).ok()); + } + } + + tag_ctr += 1; + query.push("))"); + } else if let TagOperand::And(v_and) = val { + for vx in v_and.iter() { + query.push(format!(" JOIN \"tag\" t{0} on e.id = t{0}.event_id AND t{0}.\"name\" = ", tag_ctr)) + .push_bind(key.to_string()) + .push(" AND "); + + if !is_lower_hex(vx) { + query.push(format!("t{0}.\"value\" = ", tag_ctr)) + .push_bind(vx.as_bytes()); + } else { + query.push(format!("t{0}.\"value_hex\" = ", tag_ctr)) + .push_bind(hex::decode(vx).ok()); + } + + tag_ctr += 1; + } + } + } + } + } + + query.push(" WHERE "); // Query for "authors", allowing prefix matches if let Some(auth_vec) = &f.authors { // filter out non-hex values @@ -870,48 +928,6 @@ fn query_from_filter(f: &ReqFilter) -> Option> { } } - // Query for tags - if let Some(map) = &f.tags { - if !map.is_empty() { - if push_and { - query.push(" AND "); - } - push_and = true; - - for (key, val) in map.iter() { - query.push("e.id IN (SELECT ee.id FROM \"event\" ee LEFT JOIN tag t on ee.id = t.event_id WHERE ee.hidden != 1::bit(1) and (t.\"name\" = ") - .push_bind(key.to_string()) - .push(" AND ("); - - let has_plain_values = val.iter().any(|v| !is_lower_hex(v)); - let has_hex_values = val.iter().any(|v| v.len() % 2 == 0 && is_lower_hex(v)); - if has_plain_values { - query.push("value in ("); - // plain value match first - let mut tag_query = query.separated(", "); - for v in val.iter() - .filter(|v| !is_lower_hex(v)) { - tag_query.push_bind(v.as_bytes()); - } - } - if has_plain_values && has_hex_values { - query.push(") OR "); - } - if has_hex_values { - query.push("value_hex in ("); - // plain value match first - let mut tag_query = query.separated(", "); - for v in val.iter() - .filter(|v| v.len() % 2 == 0 && is_lower_hex(v)) { - tag_query.push_bind(hex::decode(v).ok()); - } - } - - query.push("))))"); - } - } - } - // Query for timestamp if f.since.is_some() { if push_and { @@ -981,6 +997,7 @@ impl FromRow<'_, PgRow> for VerificationRecord { #[cfg(test)] mod tests { use std::collections::{HashMap, HashSet}; + use crate::subscription::TagOperand; use super::*; #[test] @@ -993,13 +1010,13 @@ mod tests { authors: Some(vec!["84de35e2584d2b144aae823c9ed0b0f3deda09648530b93d1a2a146d1dea9864".to_owned()]), limit: None, tags: Some(HashMap::from([ - ('p', HashSet::from(["63fe6318dc58583cfe16810f86dd09e18bfd76aabc24a0081ce2856f330504ed".to_owned()])) + ('p', TagOperand::Or(HashSet::from(["63fe6318dc58583cfe16810f86dd09e18bfd76aabc24a0081ce2856f330504ed".to_owned()]))) ])), force_no_match: false, }; let q = query_from_filter(&filter).unwrap(); - assert_eq!(q.sql(), "SELECT e.\"content\", e.created_at FROM \"event\" e WHERE (e.pub_key in ($1) OR e.delegated_by in ($2)) AND e.kind in ($3) AND e.id IN (SELECT ee.id FROM \"event\" ee LEFT JOIN tag t on ee.id = t.event_id WHERE ee.hidden != 1::bit(1) and (t.\"name\" = $4 AND (value_hex in ($5)))) AND e.hidden != 1::bit(1) AND (e.expires_at IS NULL OR e.expires_at > now()) ORDER BY e.created_at ASC LIMIT 1000") + assert_eq!(q.sql(), "SELECT e.\"content\", e.created_at FROM \"event\" e JOIN tag t1 on e.id = t1.event_id AND t1.\"name\" = $1 AND (t1.\"value_hex\" in ($2)) WHERE (e.pub_key in ($3) OR e.delegated_by in ($4)) AND e.kind in ($5) AND e.hidden != 1::bit(1) AND (e.expires_at IS NULL OR e.expires_at > now()) ORDER BY e.created_at ASC LIMIT 1000") } #[test] @@ -1012,13 +1029,13 @@ mod tests { authors: Some(vec!["84de35e2584d2b144aae823c9ed0b0f3deda09648530b93d1a2a146d1dea9864".to_owned()]), limit: None, tags: Some(HashMap::from([ - ('d', HashSet::from(["test".to_owned()])) + ('d', TagOperand::Or(HashSet::from(["test".to_owned()]))) ])), force_no_match: false, }; let q = query_from_filter(&filter).unwrap(); - assert_eq!(q.sql(), "SELECT e.\"content\", e.created_at FROM \"event\" e WHERE (e.pub_key in ($1) OR e.delegated_by in ($2)) AND e.kind in ($3) AND e.id IN (SELECT ee.id FROM \"event\" ee LEFT JOIN tag t on ee.id = t.event_id WHERE ee.hidden != 1::bit(1) and (t.\"name\" = $4 AND (value in ($5)))) AND e.hidden != 1::bit(1) AND (e.expires_at IS NULL OR e.expires_at > now()) ORDER BY e.created_at ASC LIMIT 1000") + assert_eq!(q.sql(), "SELECT e.\"content\", e.created_at FROM \"event\" e JOIN tag t1 on e.id = t1.event_id AND t1.\"name\" = $1 AND (t1.\"value\" in ($2)) WHERE (e.pub_key in ($3) OR e.delegated_by in ($4)) AND e.kind in ($5) AND e.hidden != 1::bit(1) AND (e.expires_at IS NULL OR e.expires_at > now()) ORDER BY e.created_at ASC LIMIT 1000") } #[test] @@ -1031,12 +1048,31 @@ mod tests { authors: Some(vec!["84de35e2584d2b144aae823c9ed0b0f3deda09648530b93d1a2a146d1dea9864".to_owned()]), limit: None, tags: Some(HashMap::from([ - ('d', HashSet::from(["test".to_owned(), "63fe6318dc58583cfe16810f86dd09e18bfd76aabc24a0081ce2856f330504ed".to_owned()])) + ('d', TagOperand::Or(HashSet::from(["test".to_owned(), "63fe6318dc58583cfe16810f86dd09e18bfd76aabc24a0081ce2856f330504ed".to_owned()]))) ])), force_no_match: false, }; let q = query_from_filter(&filter).unwrap(); - assert_eq!(q.sql(), "SELECT e.\"content\", e.created_at FROM \"event\" e WHERE (e.pub_key in ($1) OR e.delegated_by in ($2)) AND e.kind in ($3) AND e.id IN (SELECT ee.id FROM \"event\" ee LEFT JOIN tag t on ee.id = t.event_id WHERE ee.hidden != 1::bit(1) and (t.\"name\" = $4 AND (value in ($5) OR value_hex in ($6)))) AND e.hidden != 1::bit(1) AND (e.expires_at IS NULL OR e.expires_at > now()) ORDER BY e.created_at ASC LIMIT 1000") + assert_eq!(q.sql(), "SELECT e.\"content\", e.created_at FROM \"event\" e JOIN tag t1 on e.id = t1.event_id AND t1.\"name\" = $1 AND (t1.\"value\" in ($2) OR t1.\"value_hex\" in ($3)) WHERE (e.pub_key in ($4) OR e.delegated_by in ($5)) AND e.kind in ($6) AND e.hidden != 1::bit(1) AND (e.expires_at IS NULL OR e.expires_at > now()) ORDER BY e.created_at ASC LIMIT 1000") + } + + #[test] + fn test_query_gen_tag_value_hex_and() { + let filter = ReqFilter { + ids: None, + kinds: Some(vec![1000]), + since: None, + until: None, + authors: Some(vec!["84de35e2584d2b144aae823c9ed0b0f3deda09648530b93d1a2a146d1dea9864".to_owned()]), + limit: None, + tags: Some(HashMap::from([ + ('p', TagOperand::And(HashSet::from(["63fe6318dc58583cfe16810f86dd09e18bfd76aabc24a0081ce2856f330504ed".to_owned(), "84de35e2584d2b144aae823c9ed0b0f3deda09648530b93d1a2a146d1dea9864".to_owned()]))) + ])), + force_no_match: false, + }; + + let q = query_from_filter(&filter).unwrap(); + assert_eq!(q.sql(), "SELECT e.\"content\", e.created_at FROM \"event\" e JOIN \"tag\" t1 on e.id = t1.event_id AND t1.\"name\" = $1 AND t1.\"value_hex\" = $2 JOIN \"tag\" t2 on e.id = t2.event_id AND t2.\"name\" = $3 AND t2.\"value_hex\" = $4 WHERE (e.pub_key in ($5) OR e.delegated_by in ($6)) AND e.kind in ($7) AND e.hidden != 1::bit(1) AND (e.expires_at IS NULL OR e.expires_at > now()) ORDER BY e.created_at ASC LIMIT 1000") } } \ No newline at end of file diff --git a/src/subscription.rs b/src/subscription.rs index 17aaceb..ed66a0d 100644 --- a/src/subscription.rs +++ b/src/subscription.rs @@ -1,4 +1,5 @@ //! Subscription and filter parsing +use std::collections::hash_set::Iter; use crate::error::Result; use crate::event::Event; use serde::de::Unexpected; @@ -15,6 +16,25 @@ pub struct Subscription { pub filters: Vec, } +/// Tag query is AND or OR operation +#[derive(Serialize, PartialEq, Eq, Debug, Clone)] +pub enum TagOperand { + And(HashSet), + Or(HashSet), +} + +impl<'a> IntoIterator for &'a TagOperand { + type Item = &'a String; + type IntoIter = Iter<'a, String>; + + fn into_iter(self) -> Self::IntoIter { + match self { + TagOperand::Or(vv) => vv.iter(), + TagOperand::And(vv) => vv.iter() + } + } +} + /// Filter for requests /// /// Corresponds to client-provided subscription request elements. Any @@ -35,7 +55,7 @@ pub struct ReqFilter { /// Limit number of results pub limit: Option, /// Set of tags - pub tags: Option>>, + pub tags: Option>, /// Force no matches due to malformed data // we can't represent it in the req filter, so we don't want to // erroneously match. This basically indicates the req tried to @@ -45,8 +65,8 @@ pub struct ReqFilter { impl Serialize for ReqFilter { fn serialize(&self, serializer: S) -> Result - where - S: Serializer, + where + S: Serializer, { let mut map = serializer.serialize_map(None)?; if let Some(ids) = &self.ids { @@ -70,7 +90,7 @@ impl Serialize for ReqFilter { // serialize tags if let Some(tags) = &self.tags { for (k, v) in tags { - let vals: Vec<&String> = v.iter().collect(); + let vals: Vec<&String> = v.into_iter().collect(); map.serialize_entry(&format!("#{k}"), &vals)?; } } @@ -80,8 +100,8 @@ impl Serialize for ReqFilter { impl<'de> Deserialize<'de> for ReqFilter { fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, + where + D: Deserializer<'de>, { let received: Value = Deserialize::deserialize(deserializer)?; let filter = received.as_object().ok_or_else(|| { @@ -101,7 +121,7 @@ impl<'de> Deserialize<'de> for ReqFilter { force_no_match: false, }; let empty_string = "".into(); - let mut ts = None; + let mut ts: Option> = None; // iterate through each key, and assign values that exist for (key, val) in filter { // ids @@ -135,24 +155,31 @@ impl<'de> Deserialize<'de> for ReqFilter { } } rf.authors = raw_authors; - } else if key.starts_with('#') && key.len() > 1 && val.is_array() { - if let Some(tag_search) = tag_search_char_from_filter(key) { - if ts.is_none() { - // Initialize the tag if necessary - ts = Some(HashMap::new()); - } - if let Some(m) = ts.as_mut() { - let tag_vals: Option> = Deserialize::deserialize(val).ok(); - if let Some(v) = tag_vals { - let hs = v.into_iter().collect::>(); - m.insert(tag_search.to_owned(), hs); - } - }; - } else { - // tag search that is multi-character, don't add to subscription - rf.force_no_match = true; - continue; + } else if key.starts_with('#') && key.len() > 1 && key.len() < 4 && val.is_array() { + if ts.is_none() { + // Initialize the tag if necessary + ts = Some(HashMap::new()); } + if let Some(m) = ts.as_mut() { + let tag_vals: Option> = Deserialize::deserialize(val).ok(); + if let Some(v) = tag_vals { + let hs = v.into_iter().collect::>(); + let hs_op = match key.len() { + 2 => Some(TagOperand::Or(hs)), + 3 => { + if key.chars().nth(2).unwrap() == '&' { + Some(TagOperand::And(hs)) + } else { + None + } + } + _ => None + }; + if let Some(hs_some) = hs_op { + m.insert(key.chars().nth(1).unwrap(), hs_some); + } + } + }; } } rf.tags = ts; @@ -160,32 +187,12 @@ impl<'de> Deserialize<'de> for ReqFilter { } } -/// Attempt to form a single-char identifier from a tag search filter -fn tag_search_char_from_filter(tagname: &str) -> Option { - let tagname_nohash = &tagname[1..]; - // We return the tag character if and only if the tagname consists - // of a single char. - let mut tagnamechars = tagname_nohash.chars(); - let firstchar = tagnamechars.next(); - match firstchar { - Some(_) => { - // check second char - if tagnamechars.next().is_none() { - firstchar - } else { - None - } - } - None => None, - } -} - impl<'de> Deserialize<'de> for Subscription { /// Custom deserializer for subscriptions, which have a more /// complex structure than the other message types. fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, + where + D: Deserializer<'de>, { let mut v: Value = Deserialize::deserialize(deserializer)?; // this should be a 3-or-more element array.