feat: tag and query (postgres)

This commit is contained in:
Kieran 2023-11-28 13:54:09 +00:00
parent 7120de4ff8
commit ae5489a97e
No known key found for this signature in database
GPG Key ID: DE71CEB3925BE941
4 changed files with 152 additions and 104 deletions

View File

@ -10,6 +10,7 @@ repository = "https://git.sr.ht/~gheartsfield/nostr-rs-relay"
license = "MIT"
keywords = ["nostr", "server"]
categories = ["network-programming", "web-programming"]
default-run = "nostr-rs-relay"
[dependencies]
clap = { version = "4.0.32", features = ["env", "default", "derive"]}

View File

@ -19,6 +19,7 @@ use std::collections::HashMap;
use std::collections::HashSet;
use std::str::FromStr;
use tracing::{debug, info};
use crate::subscription::TagOperand;
lazy_static! {
/// Secp256k1 verification instance.
@ -28,7 +29,8 @@ lazy_static! {
/// Event command in network format.
#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Clone)]
pub struct EventCmd {
cmd: String, // expecting static "EVENT"
// expecting static "EVENT"
cmd: String,
event: Event,
}
@ -63,8 +65,8 @@ type Tag = Vec<Vec<String>>;
/// Deserializer that ensures we always have a [`Tag`].
fn tag_from_string<'de, D>(deserializer: D) -> Result<Tag, D::Error>
where
D: Deserializer<'de>,
where
D: Deserializer<'de>,
{
let opt = Option::deserialize(deserializer)?;
Ok(opt.unwrap_or_default())
@ -409,13 +411,15 @@ impl Event {
/// Determine if the given tag and value set intersect with tags in this event.
#[must_use]
pub fn generic_tag_val_intersect(&self, tagname: char, check: &HashSet<String>) -> bool {
pub fn generic_tag_val_intersect(&self, tagname: char, check: &TagOperand) -> bool {
match &self.tagidx {
// check if this is indexable tagname
Some(idx) => match idx.get(&tagname) {
Some(valset) => {
let common = valset.intersection(check);
common.count() > 0
match &check {
TagOperand::And(v) => valset.intersection(v).count() == v.len(),
TagOperand::Or(v) => valset.intersection(v).count() > 0
}
}
None => false,
},
@ -464,7 +468,7 @@ mod tests {
fn empty_event_tag_match() {
let event = Event::simple_event();
assert!(!event
.generic_tag_val_intersect('e', &HashSet::from(["foo".to_owned(), "bar".to_owned()])));
.generic_tag_val_intersect('e', &TagOperand::Or(HashSet::from(["foo".to_owned(), "bar".to_owned()]))));
}
#[test]
@ -475,7 +479,7 @@ mod tests {
assert!(
event.generic_tag_val_intersect(
'e',
&HashSet::from(["foo".to_owned(), "bar".to_owned()])
&TagOperand::Or(HashSet::from(["foo".to_owned(), "bar".to_owned()])),
)
);
}

View File

@ -4,7 +4,7 @@ use crate::event::{single_char_tagname, Event};
use crate::nip05::{Nip05Name, VerificationRecord};
use crate::payment::{InvoiceInfo, InvoiceStatus};
use crate::repo::{now_jitter, NostrRepo};
use crate::subscription::{ReqFilter, Subscription};
use crate::subscription::{ReqFilter, Subscription, TagOperand};
use async_std::stream::StreamExt;
use async_trait::async_trait;
use chrono::{DateTime, TimeZone, Utc};
@ -725,10 +725,68 @@ fn query_from_filter(f: &ReqFilter) -> Option<QueryBuilder<Postgres>> {
return None;
}
let mut query = QueryBuilder::new("SELECT e.\"content\", e.created_at FROM \"event\" e WHERE ");
let mut query = QueryBuilder::new("SELECT e.\"content\", e.created_at FROM \"event\" e");
// This tracks whether we need to push a prefix AND before adding another clause
let mut push_and = false;
// Query for tags
if let Some(map) = &f.tags {
if !map.is_empty() {
let mut tag_ctr = 1;
for (key, val) in map.iter() {
let has_plain_values = val.into_iter().any(|v| !is_lower_hex(v));
let has_hex_values = val.into_iter().any(|v| v.len() % 2 == 0 && is_lower_hex(v));
if let TagOperand::Or(v_or) = val {
query.push(format!(" JOIN tag t{0} on e.id = t{0}.event_id AND t{0}.\"name\" = ", tag_ctr))
.push_bind(key.to_string())
.push(" AND (");
if has_plain_values {
query.push(format!("t{0}.\"value\" in (", tag_ctr));
let mut tag_query = query.separated(", ");
for v in v_or.iter()
.filter(|v| !is_lower_hex(v)) {
tag_query.push_bind(v.as_bytes());
}
}
if has_plain_values && has_hex_values {
query.push(") OR ");
}
if has_hex_values {
query.push(format!("t{0}.\"value_hex\" in (", tag_ctr));
let mut tag_query = query.separated(", ");
for v in v_or.iter()
.filter(|v| v.len() % 2 == 0 && is_lower_hex(v)) {
tag_query.push_bind(hex::decode(v).ok());
}
}
tag_ctr += 1;
query.push("))");
} else if let TagOperand::And(v_and) = val {
for vx in v_and.iter() {
query.push(format!(" JOIN \"tag\" t{0} on e.id = t{0}.event_id AND t{0}.\"name\" = ", tag_ctr))
.push_bind(key.to_string())
.push(" AND ");
if !is_lower_hex(vx) {
query.push(format!("t{0}.\"value\" = ", tag_ctr))
.push_bind(vx.as_bytes());
} else {
query.push(format!("t{0}.\"value_hex\" = ", tag_ctr))
.push_bind(hex::decode(vx).ok());
}
tag_ctr += 1;
}
}
}
}
}
query.push(" WHERE ");
// Query for "authors", allowing prefix matches
if let Some(auth_vec) = &f.authors {
// filter out non-hex values
@ -870,48 +928,6 @@ fn query_from_filter(f: &ReqFilter) -> Option<QueryBuilder<Postgres>> {
}
}
// Query for tags
if let Some(map) = &f.tags {
if !map.is_empty() {
if push_and {
query.push(" AND ");
}
push_and = true;
for (key, val) in map.iter() {
query.push("e.id IN (SELECT ee.id FROM \"event\" ee LEFT JOIN tag t on ee.id = t.event_id WHERE ee.hidden != 1::bit(1) and (t.\"name\" = ")
.push_bind(key.to_string())
.push(" AND (");
let has_plain_values = val.iter().any(|v| !is_lower_hex(v));
let has_hex_values = val.iter().any(|v| v.len() % 2 == 0 && is_lower_hex(v));
if has_plain_values {
query.push("value in (");
// plain value match first
let mut tag_query = query.separated(", ");
for v in val.iter()
.filter(|v| !is_lower_hex(v)) {
tag_query.push_bind(v.as_bytes());
}
}
if has_plain_values && has_hex_values {
query.push(") OR ");
}
if has_hex_values {
query.push("value_hex in (");
// hex value match second
let mut tag_query = query.separated(", ");
for v in val.iter()
.filter(|v| v.len() % 2 == 0 && is_lower_hex(v)) {
tag_query.push_bind(hex::decode(v).ok());
}
}
query.push("))))");
}
}
}
// Query for timestamp
if f.since.is_some() {
if push_and {
@ -981,6 +997,7 @@ impl FromRow<'_, PgRow> for VerificationRecord {
#[cfg(test)]
mod tests {
use std::collections::{HashMap, HashSet};
use crate::subscription::TagOperand;
use super::*;
#[test]
@ -993,13 +1010,13 @@ mod tests {
authors: Some(vec!["84de35e2584d2b144aae823c9ed0b0f3deda09648530b93d1a2a146d1dea9864".to_owned()]),
limit: None,
tags: Some(HashMap::from([
('p', HashSet::from(["63fe6318dc58583cfe16810f86dd09e18bfd76aabc24a0081ce2856f330504ed".to_owned()]))
('p', TagOperand::Or(HashSet::from(["63fe6318dc58583cfe16810f86dd09e18bfd76aabc24a0081ce2856f330504ed".to_owned()])))
])),
force_no_match: false,
};
let q = query_from_filter(&filter).unwrap();
assert_eq!(q.sql(), "SELECT e.\"content\", e.created_at FROM \"event\" e WHERE (e.pub_key in ($1) OR e.delegated_by in ($2)) AND e.kind in ($3) AND e.id IN (SELECT ee.id FROM \"event\" ee LEFT JOIN tag t on ee.id = t.event_id WHERE ee.hidden != 1::bit(1) and (t.\"name\" = $4 AND (value_hex in ($5)))) AND e.hidden != 1::bit(1) AND (e.expires_at IS NULL OR e.expires_at > now()) ORDER BY e.created_at ASC LIMIT 1000")
assert_eq!(q.sql(), "SELECT e.\"content\", e.created_at FROM \"event\" e JOIN tag t1 on e.id = t1.event_id AND t1.\"name\" = $1 AND (t1.\"value_hex\" in ($2)) WHERE (e.pub_key in ($3) OR e.delegated_by in ($4)) AND e.kind in ($5) AND e.hidden != 1::bit(1) AND (e.expires_at IS NULL OR e.expires_at > now()) ORDER BY e.created_at ASC LIMIT 1000")
}
#[test]
@ -1012,13 +1029,13 @@ mod tests {
authors: Some(vec!["84de35e2584d2b144aae823c9ed0b0f3deda09648530b93d1a2a146d1dea9864".to_owned()]),
limit: None,
tags: Some(HashMap::from([
('d', HashSet::from(["test".to_owned()]))
('d', TagOperand::Or(HashSet::from(["test".to_owned()])))
])),
force_no_match: false,
};
let q = query_from_filter(&filter).unwrap();
assert_eq!(q.sql(), "SELECT e.\"content\", e.created_at FROM \"event\" e WHERE (e.pub_key in ($1) OR e.delegated_by in ($2)) AND e.kind in ($3) AND e.id IN (SELECT ee.id FROM \"event\" ee LEFT JOIN tag t on ee.id = t.event_id WHERE ee.hidden != 1::bit(1) and (t.\"name\" = $4 AND (value in ($5)))) AND e.hidden != 1::bit(1) AND (e.expires_at IS NULL OR e.expires_at > now()) ORDER BY e.created_at ASC LIMIT 1000")
assert_eq!(q.sql(), "SELECT e.\"content\", e.created_at FROM \"event\" e JOIN tag t1 on e.id = t1.event_id AND t1.\"name\" = $1 AND (t1.\"value\" in ($2)) WHERE (e.pub_key in ($3) OR e.delegated_by in ($4)) AND e.kind in ($5) AND e.hidden != 1::bit(1) AND (e.expires_at IS NULL OR e.expires_at > now()) ORDER BY e.created_at ASC LIMIT 1000")
}
#[test]
@ -1031,12 +1048,31 @@ mod tests {
authors: Some(vec!["84de35e2584d2b144aae823c9ed0b0f3deda09648530b93d1a2a146d1dea9864".to_owned()]),
limit: None,
tags: Some(HashMap::from([
('d', HashSet::from(["test".to_owned(), "63fe6318dc58583cfe16810f86dd09e18bfd76aabc24a0081ce2856f330504ed".to_owned()]))
('d', TagOperand::Or(HashSet::from(["test".to_owned(), "63fe6318dc58583cfe16810f86dd09e18bfd76aabc24a0081ce2856f330504ed".to_owned()])))
])),
force_no_match: false,
};
let q = query_from_filter(&filter).unwrap();
assert_eq!(q.sql(), "SELECT e.\"content\", e.created_at FROM \"event\" e WHERE (e.pub_key in ($1) OR e.delegated_by in ($2)) AND e.kind in ($3) AND e.id IN (SELECT ee.id FROM \"event\" ee LEFT JOIN tag t on ee.id = t.event_id WHERE ee.hidden != 1::bit(1) and (t.\"name\" = $4 AND (value in ($5) OR value_hex in ($6)))) AND e.hidden != 1::bit(1) AND (e.expires_at IS NULL OR e.expires_at > now()) ORDER BY e.created_at ASC LIMIT 1000")
assert_eq!(q.sql(), "SELECT e.\"content\", e.created_at FROM \"event\" e JOIN tag t1 on e.id = t1.event_id AND t1.\"name\" = $1 AND (t1.\"value\" in ($2) OR t1.\"value_hex\" in ($3)) WHERE (e.pub_key in ($4) OR e.delegated_by in ($5)) AND e.kind in ($6) AND e.hidden != 1::bit(1) AND (e.expires_at IS NULL OR e.expires_at > now()) ORDER BY e.created_at ASC LIMIT 1000")
}
/// An AND tag filter holding two hex values must produce one `JOIN "tag"`
/// per value, each bound against `value_hex`, ahead of the `WHERE` clause.
#[test]
fn test_query_gen_tag_value_hex_and() {
    let author = "84de35e2584d2b144aae823c9ed0b0f3deda09648530b93d1a2a146d1dea9864";
    let tagged = "63fe6318dc58583cfe16810f86dd09e18bfd76aabc24a0081ce2856f330504ed";

    // Both values are lowercase hex, so the generated SQL is the same
    // whichever order the set happens to iterate in.
    let required = HashSet::from([tagged.to_owned(), author.to_owned()]);
    let mut tag_map = HashMap::new();
    tag_map.insert('p', TagOperand::And(required));

    let filter = ReqFilter {
        ids: None,
        kinds: Some(vec![1000]),
        since: None,
        until: None,
        authors: Some(vec![author.to_owned()]),
        limit: None,
        tags: Some(tag_map),
        force_no_match: false,
    };

    let expected_sql = "SELECT e.\"content\", e.created_at FROM \"event\" e JOIN \"tag\" t1 on e.id = t1.event_id AND t1.\"name\" = $1 AND t1.\"value_hex\" = $2 JOIN \"tag\" t2 on e.id = t2.event_id AND t2.\"name\" = $3 AND t2.\"value_hex\" = $4 WHERE (e.pub_key in ($5) OR e.delegated_by in ($6)) AND e.kind in ($7) AND e.hidden != 1::bit(1) AND (e.expires_at IS NULL OR e.expires_at > now()) ORDER BY e.created_at ASC LIMIT 1000";
    let built = query_from_filter(&filter).unwrap();
    assert_eq!(built.sql(), expected_sql);
}
}

View File

@ -1,4 +1,5 @@
//! Subscription and filter parsing
use std::collections::hash_set::Iter;
use crate::error::Result;
use crate::event::Event;
use serde::de::Unexpected;
@ -15,6 +16,25 @@ pub struct Subscription {
pub filters: Vec<ReqFilter>,
}
/// Tag query is an AND or an OR operation.
///
/// Wraps the set of values requested for one tag name and records how those
/// values combine when an event is matched. When a filter is deserialized, a
/// `#x` key produces [`TagOperand::Or`] and a `#x&` key produces
/// [`TagOperand::And`].
#[derive(Serialize, PartialEq, Eq, Debug, Clone)]
pub enum TagOperand {
/// Every value in the set must be present on the event for it to match.
And(HashSet<String>),
/// At least one value in the set must be present on the event for it to match.
Or(HashSet<String>),
}
/// Borrowing iteration over the values held by a [`TagOperand`]; the
/// variant only affects matching, not which values are yielded.
impl<'a> IntoIterator for &'a TagOperand {
    type Item = &'a String;
    type IntoIter = Iter<'a, String>;

    /// Returns an iterator over the wrapped value set.
    fn into_iter(self) -> Self::IntoIter {
        match self {
            // Both variants wrap the same set type, so one arm covers them.
            TagOperand::And(values) | TagOperand::Or(values) => values.iter(),
        }
    }
}
/// Filter for requests
///
/// Corresponds to client-provided subscription request elements. Any
@ -35,7 +55,7 @@ pub struct ReqFilter {
/// Limit number of results
pub limit: Option<u64>,
/// Set of tags
pub tags: Option<HashMap<char, HashSet<String>>>,
pub tags: Option<HashMap<char, TagOperand>>,
/// Force no matches due to malformed data
// we can't represent it in the req filter, so we don't want to
// erroneously match. This basically indicates the req tried to
@ -45,8 +65,8 @@ pub struct ReqFilter {
impl Serialize for ReqFilter {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
where
S: Serializer,
{
let mut map = serializer.serialize_map(None)?;
if let Some(ids) = &self.ids {
@ -70,7 +90,7 @@ impl Serialize for ReqFilter {
// serialize tags
if let Some(tags) = &self.tags {
for (k, v) in tags {
let vals: Vec<&String> = v.iter().collect();
let vals: Vec<&String> = v.into_iter().collect();
map.serialize_entry(&format!("#{k}"), &vals)?;
}
}
@ -80,8 +100,8 @@ impl Serialize for ReqFilter {
impl<'de> Deserialize<'de> for ReqFilter {
fn deserialize<D>(deserializer: D) -> Result<ReqFilter, D::Error>
where
D: Deserializer<'de>,
where
D: Deserializer<'de>,
{
let received: Value = Deserialize::deserialize(deserializer)?;
let filter = received.as_object().ok_or_else(|| {
@ -101,7 +121,7 @@ impl<'de> Deserialize<'de> for ReqFilter {
force_no_match: false,
};
let empty_string = "".into();
let mut ts = None;
let mut ts: Option<HashMap<char, TagOperand>> = None;
// iterate through each key, and assign values that exist
for (key, val) in filter {
// ids
@ -135,24 +155,31 @@ impl<'de> Deserialize<'de> for ReqFilter {
}
}
rf.authors = raw_authors;
} else if key.starts_with('#') && key.len() > 1 && val.is_array() {
if let Some(tag_search) = tag_search_char_from_filter(key) {
if ts.is_none() {
// Initialize the tag if necessary
ts = Some(HashMap::new());
}
if let Some(m) = ts.as_mut() {
let tag_vals: Option<Vec<String>> = Deserialize::deserialize(val).ok();
if let Some(v) = tag_vals {
let hs = v.into_iter().collect::<HashSet<_>>();
m.insert(tag_search.to_owned(), hs);
}
};
} else {
// tag search that is multi-character, don't add to subscription
rf.force_no_match = true;
continue;
} else if key.starts_with('#') && key.len() > 1 && key.len() < 4 && val.is_array() {
if ts.is_none() {
// Initialize the tag if necessary
ts = Some(HashMap::new());
}
if let Some(m) = ts.as_mut() {
let tag_vals: Option<Vec<String>> = Deserialize::deserialize(val).ok();
if let Some(v) = tag_vals {
let hs = v.into_iter().collect::<HashSet<_>>();
let hs_op = match key.len() {
2 => Some(TagOperand::Or(hs)),
3 => {
if key.chars().nth(2).unwrap() == '&' {
Some(TagOperand::And(hs))
} else {
None
}
}
_ => None
};
if let Some(hs_some) = hs_op {
m.insert(key.chars().nth(1).unwrap(), hs_some);
}
}
};
}
}
rf.tags = ts;
@ -160,32 +187,12 @@ impl<'de> Deserialize<'de> for ReqFilter {
}
}
/// Attempt to form a single-char identifier from a tag search filter.
///
/// The first character of `tagname` (the `#` marker in a filter key such as
/// `#e`) is skipped. The tag character is returned if and only if exactly one
/// character follows it; otherwise `None` is returned.
///
/// This never panics: an empty `tagname`, or one whose first character is
/// multi-byte, yields a result rather than failing on a byte-offset slice.
fn tag_search_char_from_filter(tagname: &str) -> Option<char> {
    // Skip by character rather than slicing at byte offset 1, which would
    // panic on "" and on a leading multi-byte character.
    let mut rest = tagname.chars().skip(1);
    match (rest.next(), rest.next()) {
        // Exactly one character after the marker: that is the tag name.
        (Some(tag), None) => Some(tag),
        // Nothing after the marker, or a multi-character tag name.
        _ => None,
    }
}
impl<'de> Deserialize<'de> for Subscription {
/// Custom deserializer for subscriptions, which have a more
/// complex structure than the other message types.
fn deserialize<D>(deserializer: D) -> Result<Subscription, D::Error>
where
D: Deserializer<'de>,
where
D: Deserializer<'de>,
{
let mut v: Value = Deserialize::deserialize(deserializer)?;
// this should be a 3-or-more element array.