mirror of
https://github.com/scsibug/nostr-rs-relay.git
synced 2024-12-22 16:35:51 -05:00
feat: tag and query (postgres)
This commit is contained in:
parent
7120de4ff8
commit
ae5489a97e
|
@ -10,6 +10,7 @@ repository = "https://git.sr.ht/~gheartsfield/nostr-rs-relay"
|
|||
license = "MIT"
|
||||
keywords = ["nostr", "server"]
|
||||
categories = ["network-programming", "web-programming"]
|
||||
default-run = "nostr-rs-relay"
|
||||
|
||||
[dependencies]
|
||||
clap = { version = "4.0.32", features = ["env", "default", "derive"]}
|
||||
|
|
20
src/event.rs
20
src/event.rs
|
@ -19,6 +19,7 @@ use std::collections::HashMap;
|
|||
use std::collections::HashSet;
|
||||
use std::str::FromStr;
|
||||
use tracing::{debug, info};
|
||||
use crate::subscription::TagOperand;
|
||||
|
||||
lazy_static! {
|
||||
/// Secp256k1 verification instance.
|
||||
|
@ -28,7 +29,8 @@ lazy_static! {
|
|||
/// Event command in network format.
|
||||
#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)]
|
||||
pub struct EventCmd {
|
||||
cmd: String, // expecting static "EVENT"
|
||||
// expecting static "EVENT"
|
||||
cmd: String,
|
||||
event: Event,
|
||||
}
|
||||
|
||||
|
@ -63,8 +65,8 @@ type Tag = Vec<Vec<String>>;
|
|||
|
||||
/// Deserializer that ensures we always have a [`Tag`].
|
||||
fn tag_from_string<'de, D>(deserializer: D) -> Result<Tag, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let opt = Option::deserialize(deserializer)?;
|
||||
Ok(opt.unwrap_or_default())
|
||||
|
@ -409,13 +411,15 @@ impl Event {
|
|||
|
||||
/// Determine if the given tag and value set intersect with tags in this event.
|
||||
#[must_use]
|
||||
pub fn generic_tag_val_intersect(&self, tagname: char, check: &HashSet<String>) -> bool {
|
||||
pub fn generic_tag_val_intersect(&self, tagname: char, check: &TagOperand) -> bool {
|
||||
match &self.tagidx {
|
||||
// check if this is indexable tagname
|
||||
Some(idx) => match idx.get(&tagname) {
|
||||
Some(valset) => {
|
||||
let common = valset.intersection(check);
|
||||
common.count() > 0
|
||||
match &check {
|
||||
TagOperand::And(v) => valset.intersection(v).count() == v.len(),
|
||||
TagOperand::Or(v) => valset.intersection(v).count() > 0
|
||||
}
|
||||
}
|
||||
None => false,
|
||||
},
|
||||
|
@ -464,7 +468,7 @@ mod tests {
|
|||
fn empty_event_tag_match() {
|
||||
let event = Event::simple_event();
|
||||
assert!(!event
|
||||
.generic_tag_val_intersect('e', &HashSet::from(["foo".to_owned(), "bar".to_owned()])));
|
||||
.generic_tag_val_intersect('e', &TagOperand::Or(HashSet::from(["foo".to_owned(), "bar".to_owned()]))));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -475,7 +479,7 @@ mod tests {
|
|||
assert!(
|
||||
event.generic_tag_val_intersect(
|
||||
'e',
|
||||
&HashSet::from(["foo".to_owned(), "bar".to_owned()])
|
||||
&TagOperand::Or(HashSet::from(["foo".to_owned(), "bar".to_owned()])),
|
||||
)
|
||||
);
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ use crate::event::{single_char_tagname, Event};
|
|||
use crate::nip05::{Nip05Name, VerificationRecord};
|
||||
use crate::payment::{InvoiceInfo, InvoiceStatus};
|
||||
use crate::repo::{now_jitter, NostrRepo};
|
||||
use crate::subscription::{ReqFilter, Subscription};
|
||||
use crate::subscription::{ReqFilter, Subscription, TagOperand};
|
||||
use async_std::stream::StreamExt;
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, TimeZone, Utc};
|
||||
|
@ -725,10 +725,68 @@ fn query_from_filter(f: &ReqFilter) -> Option<QueryBuilder<Postgres>> {
|
|||
return None;
|
||||
}
|
||||
|
||||
let mut query = QueryBuilder::new("SELECT e.\"content\", e.created_at FROM \"event\" e WHERE ");
|
||||
let mut query = QueryBuilder::new("SELECT e.\"content\", e.created_at FROM \"event\" e");
|
||||
|
||||
// This tracks whether we need to push a prefix AND before adding another clause
|
||||
let mut push_and = false;
|
||||
|
||||
// Query for tags
|
||||
if let Some(map) = &f.tags {
|
||||
if !map.is_empty() {
|
||||
let mut tag_ctr = 1;
|
||||
for (key, val) in map.iter() {
|
||||
let has_plain_values = val.into_iter().any(|v| !is_lower_hex(v));
|
||||
let has_hex_values = val.into_iter().any(|v| v.len() % 2 == 0 && is_lower_hex(v));
|
||||
|
||||
if let TagOperand::Or(v_or) = val {
|
||||
query.push(format!(" JOIN tag t{0} on e.id = t{0}.event_id AND t{0}.\"name\" = ", tag_ctr))
|
||||
.push_bind(key.to_string())
|
||||
.push(" AND (");
|
||||
|
||||
if has_plain_values {
|
||||
query.push(format!("t{0}.\"value\" in (", tag_ctr));
|
||||
let mut tag_query = query.separated(", ");
|
||||
for v in v_or.iter()
|
||||
.filter(|v| !is_lower_hex(v)) {
|
||||
tag_query.push_bind(v.as_bytes());
|
||||
}
|
||||
}
|
||||
if has_plain_values && has_hex_values {
|
||||
query.push(") OR ");
|
||||
}
|
||||
if has_hex_values {
|
||||
query.push(format!("t{0}.\"value_hex\" in (", tag_ctr));
|
||||
let mut tag_query = query.separated(", ");
|
||||
for v in v_or.iter()
|
||||
.filter(|v| v.len() % 2 == 0 && is_lower_hex(v)) {
|
||||
tag_query.push_bind(hex::decode(v).ok());
|
||||
}
|
||||
}
|
||||
|
||||
tag_ctr += 1;
|
||||
query.push("))");
|
||||
} else if let TagOperand::And(v_and) = val {
|
||||
for vx in v_and.iter() {
|
||||
query.push(format!(" JOIN \"tag\" t{0} on e.id = t{0}.event_id AND t{0}.\"name\" = ", tag_ctr))
|
||||
.push_bind(key.to_string())
|
||||
.push(" AND ");
|
||||
|
||||
if !is_lower_hex(vx) {
|
||||
query.push(format!("t{0}.\"value\" = ", tag_ctr))
|
||||
.push_bind(vx.as_bytes());
|
||||
} else {
|
||||
query.push(format!("t{0}.\"value_hex\" = ", tag_ctr))
|
||||
.push_bind(hex::decode(vx).ok());
|
||||
}
|
||||
|
||||
tag_ctr += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
query.push(" WHERE ");
|
||||
// Query for "authors", allowing prefix matches
|
||||
if let Some(auth_vec) = &f.authors {
|
||||
// filter out non-hex values
|
||||
|
@ -870,48 +928,6 @@ fn query_from_filter(f: &ReqFilter) -> Option<QueryBuilder<Postgres>> {
|
|||
}
|
||||
}
|
||||
|
||||
// Query for tags
|
||||
if let Some(map) = &f.tags {
|
||||
if !map.is_empty() {
|
||||
if push_and {
|
||||
query.push(" AND ");
|
||||
}
|
||||
push_and = true;
|
||||
|
||||
for (key, val) in map.iter() {
|
||||
query.push("e.id IN (SELECT ee.id FROM \"event\" ee LEFT JOIN tag t on ee.id = t.event_id WHERE ee.hidden != 1::bit(1) and (t.\"name\" = ")
|
||||
.push_bind(key.to_string())
|
||||
.push(" AND (");
|
||||
|
||||
let has_plain_values = val.iter().any(|v| !is_lower_hex(v));
|
||||
let has_hex_values = val.iter().any(|v| v.len() % 2 == 0 && is_lower_hex(v));
|
||||
if has_plain_values {
|
||||
query.push("value in (");
|
||||
// plain value match first
|
||||
let mut tag_query = query.separated(", ");
|
||||
for v in val.iter()
|
||||
.filter(|v| !is_lower_hex(v)) {
|
||||
tag_query.push_bind(v.as_bytes());
|
||||
}
|
||||
}
|
||||
if has_plain_values && has_hex_values {
|
||||
query.push(") OR ");
|
||||
}
|
||||
if has_hex_values {
|
||||
query.push("value_hex in (");
|
||||
// plain value match first
|
||||
let mut tag_query = query.separated(", ");
|
||||
for v in val.iter()
|
||||
.filter(|v| v.len() % 2 == 0 && is_lower_hex(v)) {
|
||||
tag_query.push_bind(hex::decode(v).ok());
|
||||
}
|
||||
}
|
||||
|
||||
query.push("))))");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Query for timestamp
|
||||
if f.since.is_some() {
|
||||
if push_and {
|
||||
|
@ -981,6 +997,7 @@ impl FromRow<'_, PgRow> for VerificationRecord {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use crate::subscription::TagOperand;
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
|
@ -993,13 +1010,13 @@ mod tests {
|
|||
authors: Some(vec!["84de35e2584d2b144aae823c9ed0b0f3deda09648530b93d1a2a146d1dea9864".to_owned()]),
|
||||
limit: None,
|
||||
tags: Some(HashMap::from([
|
||||
('p', HashSet::from(["63fe6318dc58583cfe16810f86dd09e18bfd76aabc24a0081ce2856f330504ed".to_owned()]))
|
||||
('p', TagOperand::Or(HashSet::from(["63fe6318dc58583cfe16810f86dd09e18bfd76aabc24a0081ce2856f330504ed".to_owned()])))
|
||||
])),
|
||||
force_no_match: false,
|
||||
};
|
||||
|
||||
let q = query_from_filter(&filter).unwrap();
|
||||
assert_eq!(q.sql(), "SELECT e.\"content\", e.created_at FROM \"event\" e WHERE (e.pub_key in ($1) OR e.delegated_by in ($2)) AND e.kind in ($3) AND e.id IN (SELECT ee.id FROM \"event\" ee LEFT JOIN tag t on ee.id = t.event_id WHERE ee.hidden != 1::bit(1) and (t.\"name\" = $4 AND (value_hex in ($5)))) AND e.hidden != 1::bit(1) AND (e.expires_at IS NULL OR e.expires_at > now()) ORDER BY e.created_at ASC LIMIT 1000")
|
||||
assert_eq!(q.sql(), "SELECT e.\"content\", e.created_at FROM \"event\" e JOIN tag t1 on e.id = t1.event_id AND t1.\"name\" = $1 AND (t1.\"value_hex\" in ($2)) WHERE (e.pub_key in ($3) OR e.delegated_by in ($4)) AND e.kind in ($5) AND e.hidden != 1::bit(1) AND (e.expires_at IS NULL OR e.expires_at > now()) ORDER BY e.created_at ASC LIMIT 1000")
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -1012,13 +1029,13 @@ mod tests {
|
|||
authors: Some(vec!["84de35e2584d2b144aae823c9ed0b0f3deda09648530b93d1a2a146d1dea9864".to_owned()]),
|
||||
limit: None,
|
||||
tags: Some(HashMap::from([
|
||||
('d', HashSet::from(["test".to_owned()]))
|
||||
('d', TagOperand::Or(HashSet::from(["test".to_owned()])))
|
||||
])),
|
||||
force_no_match: false,
|
||||
};
|
||||
|
||||
let q = query_from_filter(&filter).unwrap();
|
||||
assert_eq!(q.sql(), "SELECT e.\"content\", e.created_at FROM \"event\" e WHERE (e.pub_key in ($1) OR e.delegated_by in ($2)) AND e.kind in ($3) AND e.id IN (SELECT ee.id FROM \"event\" ee LEFT JOIN tag t on ee.id = t.event_id WHERE ee.hidden != 1::bit(1) and (t.\"name\" = $4 AND (value in ($5)))) AND e.hidden != 1::bit(1) AND (e.expires_at IS NULL OR e.expires_at > now()) ORDER BY e.created_at ASC LIMIT 1000")
|
||||
assert_eq!(q.sql(), "SELECT e.\"content\", e.created_at FROM \"event\" e JOIN tag t1 on e.id = t1.event_id AND t1.\"name\" = $1 AND (t1.\"value\" in ($2)) WHERE (e.pub_key in ($3) OR e.delegated_by in ($4)) AND e.kind in ($5) AND e.hidden != 1::bit(1) AND (e.expires_at IS NULL OR e.expires_at > now()) ORDER BY e.created_at ASC LIMIT 1000")
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -1031,12 +1048,31 @@ mod tests {
|
|||
authors: Some(vec!["84de35e2584d2b144aae823c9ed0b0f3deda09648530b93d1a2a146d1dea9864".to_owned()]),
|
||||
limit: None,
|
||||
tags: Some(HashMap::from([
|
||||
('d', HashSet::from(["test".to_owned(), "63fe6318dc58583cfe16810f86dd09e18bfd76aabc24a0081ce2856f330504ed".to_owned()]))
|
||||
('d', TagOperand::Or(HashSet::from(["test".to_owned(), "63fe6318dc58583cfe16810f86dd09e18bfd76aabc24a0081ce2856f330504ed".to_owned()])))
|
||||
])),
|
||||
force_no_match: false,
|
||||
};
|
||||
|
||||
let q = query_from_filter(&filter).unwrap();
|
||||
assert_eq!(q.sql(), "SELECT e.\"content\", e.created_at FROM \"event\" e WHERE (e.pub_key in ($1) OR e.delegated_by in ($2)) AND e.kind in ($3) AND e.id IN (SELECT ee.id FROM \"event\" ee LEFT JOIN tag t on ee.id = t.event_id WHERE ee.hidden != 1::bit(1) and (t.\"name\" = $4 AND (value in ($5) OR value_hex in ($6)))) AND e.hidden != 1::bit(1) AND (e.expires_at IS NULL OR e.expires_at > now()) ORDER BY e.created_at ASC LIMIT 1000")
|
||||
assert_eq!(q.sql(), "SELECT e.\"content\", e.created_at FROM \"event\" e JOIN tag t1 on e.id = t1.event_id AND t1.\"name\" = $1 AND (t1.\"value\" in ($2) OR t1.\"value_hex\" in ($3)) WHERE (e.pub_key in ($4) OR e.delegated_by in ($5)) AND e.kind in ($6) AND e.hidden != 1::bit(1) AND (e.expires_at IS NULL OR e.expires_at > now()) ORDER BY e.created_at ASC LIMIT 1000")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_query_gen_tag_value_hex_and() {
|
||||
let filter = ReqFilter {
|
||||
ids: None,
|
||||
kinds: Some(vec![1000]),
|
||||
since: None,
|
||||
until: None,
|
||||
authors: Some(vec!["84de35e2584d2b144aae823c9ed0b0f3deda09648530b93d1a2a146d1dea9864".to_owned()]),
|
||||
limit: None,
|
||||
tags: Some(HashMap::from([
|
||||
('p', TagOperand::And(HashSet::from(["63fe6318dc58583cfe16810f86dd09e18bfd76aabc24a0081ce2856f330504ed".to_owned(), "84de35e2584d2b144aae823c9ed0b0f3deda09648530b93d1a2a146d1dea9864".to_owned()])))
|
||||
])),
|
||||
force_no_match: false,
|
||||
};
|
||||
|
||||
let q = query_from_filter(&filter).unwrap();
|
||||
assert_eq!(q.sql(), "SELECT e.\"content\", e.created_at FROM \"event\" e JOIN \"tag\" t1 on e.id = t1.event_id AND t1.\"name\" = $1 AND t1.\"value_hex\" = $2 JOIN \"tag\" t2 on e.id = t2.event_id AND t2.\"name\" = $3 AND t2.\"value_hex\" = $4 WHERE (e.pub_key in ($5) OR e.delegated_by in ($6)) AND e.kind in ($7) AND e.hidden != 1::bit(1) AND (e.expires_at IS NULL OR e.expires_at > now()) ORDER BY e.created_at ASC LIMIT 1000")
|
||||
}
|
||||
}
|
|
@ -1,4 +1,5 @@
|
|||
//! Subscription and filter parsing
|
||||
use std::collections::hash_set::Iter;
|
||||
use crate::error::Result;
|
||||
use crate::event::Event;
|
||||
use serde::de::Unexpected;
|
||||
|
@ -15,6 +16,25 @@ pub struct Subscription {
|
|||
pub filters: Vec<ReqFilter>,
|
||||
}
|
||||
|
||||
/// Tag query is AND or OR operation
|
||||
#[derive(Serialize, PartialEq, Eq, Debug, Clone)]
|
||||
pub enum TagOperand {
|
||||
And(HashSet<String>),
|
||||
Or(HashSet<String>),
|
||||
}
|
||||
|
||||
impl<'a> IntoIterator for &'a TagOperand {
|
||||
type Item = &'a String;
|
||||
type IntoIter = Iter<'a, String>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
match self {
|
||||
TagOperand::Or(vv) => vv.iter(),
|
||||
TagOperand::And(vv) => vv.iter()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Filter for requests
|
||||
///
|
||||
/// Corresponds to client-provided subscription request elements. Any
|
||||
|
@ -35,7 +55,7 @@ pub struct ReqFilter {
|
|||
/// Limit number of results
|
||||
pub limit: Option<u64>,
|
||||
/// Set of tags
|
||||
pub tags: Option<HashMap<char, HashSet<String>>>,
|
||||
pub tags: Option<HashMap<char, TagOperand>>,
|
||||
/// Force no matches due to malformed data
|
||||
// we can't represent it in the req filter, so we don't want to
|
||||
// erroneously match. This basically indicates the req tried to
|
||||
|
@ -45,8 +65,8 @@ pub struct ReqFilter {
|
|||
|
||||
impl Serialize for ReqFilter {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
let mut map = serializer.serialize_map(None)?;
|
||||
if let Some(ids) = &self.ids {
|
||||
|
@ -70,7 +90,7 @@ impl Serialize for ReqFilter {
|
|||
// serialize tags
|
||||
if let Some(tags) = &self.tags {
|
||||
for (k, v) in tags {
|
||||
let vals: Vec<&String> = v.iter().collect();
|
||||
let vals: Vec<&String> = v.into_iter().collect();
|
||||
map.serialize_entry(&format!("#{k}"), &vals)?;
|
||||
}
|
||||
}
|
||||
|
@ -80,8 +100,8 @@ impl Serialize for ReqFilter {
|
|||
|
||||
impl<'de> Deserialize<'de> for ReqFilter {
|
||||
fn deserialize<D>(deserializer: D) -> Result<ReqFilter, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let received: Value = Deserialize::deserialize(deserializer)?;
|
||||
let filter = received.as_object().ok_or_else(|| {
|
||||
|
@ -101,7 +121,7 @@ impl<'de> Deserialize<'de> for ReqFilter {
|
|||
force_no_match: false,
|
||||
};
|
||||
let empty_string = "".into();
|
||||
let mut ts = None;
|
||||
let mut ts: Option<HashMap<char, TagOperand>> = None;
|
||||
// iterate through each key, and assign values that exist
|
||||
for (key, val) in filter {
|
||||
// ids
|
||||
|
@ -135,24 +155,31 @@ impl<'de> Deserialize<'de> for ReqFilter {
|
|||
}
|
||||
}
|
||||
rf.authors = raw_authors;
|
||||
} else if key.starts_with('#') && key.len() > 1 && val.is_array() {
|
||||
if let Some(tag_search) = tag_search_char_from_filter(key) {
|
||||
if ts.is_none() {
|
||||
// Initialize the tag if necessary
|
||||
ts = Some(HashMap::new());
|
||||
}
|
||||
if let Some(m) = ts.as_mut() {
|
||||
let tag_vals: Option<Vec<String>> = Deserialize::deserialize(val).ok();
|
||||
if let Some(v) = tag_vals {
|
||||
let hs = v.into_iter().collect::<HashSet<_>>();
|
||||
m.insert(tag_search.to_owned(), hs);
|
||||
}
|
||||
};
|
||||
} else {
|
||||
// tag search that is multi-character, don't add to subscription
|
||||
rf.force_no_match = true;
|
||||
continue;
|
||||
} else if key.starts_with('#') && key.len() > 1 && key.len() < 4 && val.is_array() {
|
||||
if ts.is_none() {
|
||||
// Initialize the tag if necessary
|
||||
ts = Some(HashMap::new());
|
||||
}
|
||||
if let Some(m) = ts.as_mut() {
|
||||
let tag_vals: Option<Vec<String>> = Deserialize::deserialize(val).ok();
|
||||
if let Some(v) = tag_vals {
|
||||
let hs = v.into_iter().collect::<HashSet<_>>();
|
||||
let hs_op = match key.len() {
|
||||
2 => Some(TagOperand::Or(hs)),
|
||||
3 => {
|
||||
if key.chars().nth(2).unwrap() == '&' {
|
||||
Some(TagOperand::And(hs))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
_ => None
|
||||
};
|
||||
if let Some(hs_some) = hs_op {
|
||||
m.insert(key.chars().nth(1).unwrap(), hs_some);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
rf.tags = ts;
|
||||
|
@ -160,32 +187,12 @@ impl<'de> Deserialize<'de> for ReqFilter {
|
|||
}
|
||||
}
|
||||
|
||||
/// Attempt to form a single-char identifier from a tag search filter
|
||||
fn tag_search_char_from_filter(tagname: &str) -> Option<char> {
|
||||
let tagname_nohash = &tagname[1..];
|
||||
// We return the tag character if and only if the tagname consists
|
||||
// of a single char.
|
||||
let mut tagnamechars = tagname_nohash.chars();
|
||||
let firstchar = tagnamechars.next();
|
||||
match firstchar {
|
||||
Some(_) => {
|
||||
// check second char
|
||||
if tagnamechars.next().is_none() {
|
||||
firstchar
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
None => None,
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for Subscription {
|
||||
/// Custom deserializer for subscriptions, which have a more
|
||||
/// complex structure than the other message types.
|
||||
fn deserialize<D>(deserializer: D) -> Result<Subscription, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let mut v: Value = Deserialize::deserialize(deserializer)?;
|
||||
// this should be a 3-or-more element array.
|
||||
|
|
Loading…
Reference in New Issue
Block a user