From 05bc5d12aeae755c44e83068d35d9ef560c4d9da Mon Sep 17 00:00:00 2001 From: Pascal Seitz Date: Mon, 15 Jul 2024 18:48:47 +0800 Subject: [PATCH] add support for str fast field range query Add support for range queries on fast fields, by converting term bounds to term ordinals bounds. closes https://github.com/quickwit-oss/tantivy/issues/2023 --- src/query/range_query/mod.rs | 4 +- .../range_query/range_query_u64_fastfield.rs | 137 +++++-- src/schema/field_type.rs | 5 + sstable/src/dictionary.rs | 347 +++++++++++++++++- 4 files changed, 445 insertions(+), 48 deletions(-) diff --git a/src/query/range_query/mod.rs b/src/query/range_query/mod.rs index 8ed26c95ab..40effb85b1 100644 --- a/src/query/range_query/mod.rs +++ b/src/query/range_query/mod.rs @@ -12,9 +12,9 @@ pub use self::range_query_u64_fastfield::FastFieldRangeWeight; // TODO is this correct? pub(crate) fn is_type_valid_for_fastfield_range_query(typ: Type) -> bool { match typ { - Type::U64 | Type::I64 | Type::F64 | Type::Bool | Type::Date => true, + Type::Str | Type::U64 | Type::I64 | Type::F64 | Type::Bool | Type::Date => true, Type::IpAddr => true, - Type::Str | Type::Facet | Type::Bytes | Type::Json => false, + Type::Facet | Type::Bytes | Type::Json => false, } } diff --git a/src/query/range_query/range_query_u64_fastfield.rs b/src/query/range_query/range_query_u64_fastfield.rs index 1db436ccb1..e659d9cffb 100644 --- a/src/query/range_query/range_query_u64_fastfield.rs +++ b/src/query/range_query/range_query_u64_fastfield.rs @@ -5,7 +5,7 @@ use std::net::Ipv6Addr; use std::ops::{Bound, RangeInclusive}; -use columnar::{Column, MonotonicallyMappableToU128, MonotonicallyMappableToU64}; +use columnar::{Column, MonotonicallyMappableToU128, MonotonicallyMappableToU64, StrColumn}; use common::BinarySerializable; use super::fast_field_range_doc_set::RangeDocSet; @@ -51,16 +51,22 @@ impl Weight for FastFieldRangeWeight { } let field_name = reader.schema().get_field_name(self.field); let field_type = reader.schema().get_field_entry(self.field).field_type(); + + let term = inner_bound(&self.lower_bound) + .or(inner_bound(&self.upper_bound)) + .expect("At least one bound must be set"); + assert_eq!( + term.typ(), + field_type.value_type(), + "Field is of type {:?}, but got term of type {:?}", + field_type, + term.typ() + ); if field_type.is_ip_addr() { let parse_ip_from_bytes = |term: &Term| { - let ip_u128_bytes: [u8; 16] = - term.serialized_value_bytes().try_into().map_err(|_| { - crate::TantivyError::InvalidArgument( - "Expected 8 bytes for ip address".to_string(), - ) - })?; - let ip_u128 = u128::from_be_bytes(ip_u128_bytes); - crate::Result::::Ok(Ipv6Addr::from_u128(ip_u128)) + term.value().as_ip_addr().ok_or_else(|| { + crate::TantivyError::InvalidArgument("Expected ip address".to_string()) + }) }; let lower_bound = map_bound_res(&self.lower_bound, parse_ip_from_bytes)?; let upper_bound = map_bound_res(&self.upper_bound, parse_ip_from_bytes)?; @@ -79,33 +85,42 @@ impl Weight for FastFieldRangeWeight { let docset = RangeDocSet::new(value_range, ip_addr_column); Ok(Box::new(ConstScorer::new(docset, boost))) } else { - assert!( - maps_to_u64_fastfield(field_type.value_type()), - "{:?}", - field_type - ); - - let term = inner_bound(&self.lower_bound) - .or(inner_bound(&self.upper_bound)) - .expect("At least one bound must be set"); - assert_eq!( - term.typ(), - field_type.value_type(), - "Field is of type {:?}, but got term of type {:?}", - field_type, - term.typ() - ); + let (lower_bound, upper_bound) = if field_type.is_term() { + let Some(str_dict_column): Option = + reader.fast_fields().str(field_name)? + else { + return Ok(Box::new(EmptyScorer)); + }; + let dict = str_dict_column.dictionary(); + + let lower_bound = map_bound(&self.lower_bound, |term| { + term.serialized_value_bytes().to_vec() + }); + let upper_bound = map_bound(&self.upper_bound, |term| { + term.serialized_value_bytes().to_vec() + }); + // Get term ids for terms + let (lower_bound, upper_bound) = + dict.term_bounds_to_ord(lower_bound, upper_bound)?; + (lower_bound, upper_bound) + } else { + assert!( + maps_to_u64_fastfield(field_type.value_type()), + "{:?}", + field_type + ); + let parse_from_bytes = |term: &Term| { + u64::from_be( + BinarySerializable::deserialize(&mut &term.serialized_value_bytes()[..]) + .unwrap(), + ) + }; - let parse_from_bytes = |term: &Term| { - u64::from_be( - BinarySerializable::deserialize(&mut &term.serialized_value_bytes()[..]) - .unwrap(), - ) + let lower_bound = map_bound(&self.lower_bound, parse_from_bytes); + let upper_bound = map_bound(&self.upper_bound, parse_from_bytes); + (lower_bound, upper_bound) }; - let lower_bound = map_bound(&self.lower_bound, parse_from_bytes); - let upper_bound = map_bound(&self.upper_bound, parse_from_bytes); - let fast_field_reader = reader.fast_fields(); let Some((column, _)) = fast_field_reader.u64_lenient_for_type(None, field_name)? else { @@ -202,12 +217,38 @@ pub mod tests { use rand::seq::SliceRandom; use rand::SeedableRng; - use crate::collector::Count; + use crate::collector::{Count, TopDocs}; use crate::query::range_query::range_query_u64_fastfield::FastFieldRangeWeight; use crate::query::{QueryParser, Weight}; - use crate::schema::{NumericOptions, Schema, SchemaBuilder, FAST, INDEXED, STORED, STRING}; + use crate::schema::{ + NumericOptions, Schema, SchemaBuilder, FAST, INDEXED, STORED, STRING, TEXT, + }; use crate::{Index, IndexWriter, Term, TERMINATED}; + #[test] + fn test_text_field_ff_range_query() -> crate::Result<()> { + let mut schema_builder = Schema::builder(); + schema_builder.add_text_field("title", TEXT | FAST); + let schema = schema_builder.build(); + let index = Index::create_in_ram(schema.clone()); + let mut index_writer = index.writer_for_tests()?; + let title = schema.get_field("title").unwrap(); + index_writer.add_document(doc!( + title => "bbb" + ))?; + index_writer.add_document(doc!( + title => "ddd" + ))?; + index_writer.commit()?; + let reader = index.reader()?; + let searcher = reader.searcher(); + let query_parser = QueryParser::for_index(&index, vec![title]); + let query = query_parser.parse_query("title:[ccc TO ddd]")?; + let top_docs = searcher.search(&query, &TopDocs::with_limit(10))?; + assert_eq!(top_docs.len(), 1); + Ok(()) + } + #[derive(Clone, Debug)] pub struct Doc { pub id_name: String, @@ -224,14 +265,14 @@ pub mod tests { fn doc_from_id_1(id: u64) -> Doc { let id = id * 1000; Doc { - id_name: id.to_string(), + id_name: format!("id_name{:010}", id), id, } } fn doc_from_id_2(id: u64) -> Doc { let id = id * 1000; Doc { - id_name: (id - 1).to_string(), + id_name: format!("id_name{:010}", id - 1), id, } } @@ -319,7 +360,8 @@ pub mod tests { NumericOptions::default().set_fast().set_indexed(), ); - let text_field = schema_builder.add_text_field("id_name", STRING | STORED); + let text_field = schema_builder.add_text_field("id_name", STRING | STORED | FAST); + let text_field2 = schema_builder.add_text_field("id_name_fast", STRING | STORED | FAST); let schema = schema_builder.build(); let index = Index::create_in_ram(schema); @@ -338,6 +380,7 @@ pub mod tests { id_f64_field => doc.id as f64, id_i64_field => doc.id as i64, text_field => doc.id_name.to_string(), + text_field2 => doc.id_name.to_string(), )) .unwrap(); } @@ -382,6 +425,24 @@ pub mod tests { let query = gen_query_inclusive("ids", ids[0]..=ids[1]); assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits); + // Text query + { + let test_text_query = |field_name: &str| { + let mut id_names: Vec<&str> = + sample_docs.iter().map(|doc| doc.id_name.as_str()).collect(); + id_names.sort(); + let expected_num_hits = docs + .iter() + .filter(|doc| (id_names[0]..=id_names[1]).contains(&doc.id_name.as_str())) + .count(); + let query = format!("{}:[{} TO {}]", field_name, id_names[0], id_names[1]); + assert_eq!(get_num_hits(query_from_text(&query)), expected_num_hits); + }; + + test_text_query("id_name"); + test_text_query("id_name_fast"); + } + // Exclusive range let expected_num_hits = docs .iter() diff --git a/src/schema/field_type.rs b/src/schema/field_type.rs index 70a453dbf7..aead66deb5 100644 --- a/src/schema/field_type.rs +++ b/src/schema/field_type.rs @@ -201,6 +201,11 @@ impl FieldType { matches!(self, FieldType::IpAddr(_)) } + /// returns true if this is an str field + pub fn is_term(&self) -> bool { + matches!(self, FieldType::Str(_)) + } + /// returns true if this is an date field pub fn is_date(&self) -> bool { matches!(self, FieldType::Date(_)) diff --git a/sstable/src/dictionary.rs b/sstable/src/dictionary.rs index 3e66167fb8..ca33dcfb27 100644 --- a/sstable/src/dictionary.rs +++ b/sstable/src/dictionary.rs @@ -56,6 +56,51 @@ impl Dictionary { } } +fn map_bound(bound: &Bound, transform: impl Fn(&TFrom) -> TTo) -> Bound { + use self::Bound::*; + match bound { + Excluded(ref from_val) => Bound::Excluded(transform(from_val)), + Included(ref from_val) => Bound::Included(transform(from_val)), + Unbounded => Unbounded, + } +} + +fn map_bound_res( + bound: &Bound, + transform: impl Fn(&TFrom) -> io::Result>, +) -> io::Result> { + use self::Bound::*; + Ok(match bound { + Excluded(ref from_val) => transform(from_val)?, + Included(ref from_val) => transform(from_val)?, + Unbounded => Unbounded, + }) +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum TermOrdHit { + /// Exact term ord hit + Exact(TermOrdinal), + /// Next best term ordinal + Next(TermOrdinal), +} + +impl TermOrdHit { + fn into_exact(self) -> Option { + match self { + TermOrdHit::Exact(ord) => Some(ord), + TermOrdHit::Next(_) => None, + } + } + + fn map TermOrdinal>(self, f: F) -> Self { + match self { + TermOrdHit::Exact(ord) => TermOrdHit::Exact(f(ord)), + TermOrdHit::Next(ord) => TermOrdHit::Next(f(ord)), + } + } +} + impl Dictionary { pub fn builder(wrt: W) -> io::Result> { Ok(TSSTable::writer(wrt)) @@ -257,6 +302,19 @@ impl Dictionary { key: K, sstable_delta_reader: &mut DeltaReader, ) -> io::Result> { + self.decode_up_to_or_next(key, sstable_delta_reader) + .map(|hit| hit.into_exact()) + } + /// Decode a DeltaReader up to key, returning the number of terms traversed + /// + /// If the key was not found, returns Ok(None). + /// After calling this function, it is possible to call `DeltaReader::value` to get the + /// associated value. + fn decode_up_to_or_next>( + &self, + key: K, + sstable_delta_reader: &mut DeltaReader, + ) -> io::Result { let mut term_ord = 0; let key_bytes = key.as_ref(); let mut ok_bytes = 0; @@ -265,7 +323,7 @@ impl Dictionary { let suffix = sstable_delta_reader.suffix(); match prefix_len.cmp(&ok_bytes) { - Ordering::Less => return Ok(None), // popped bytes already matched => too far + Ordering::Less => return Ok(TermOrdHit::Next(term_ord)), /* popped bytes already matched => too far */ Ordering::Equal => (), Ordering::Greater => { // the ok prefix is less than current entry prefix => continue to next elem @@ -277,25 +335,26 @@ impl Dictionary { // we have ok_bytes byte of common prefix, check if this key adds more for (key_byte, suffix_byte) in key_bytes[ok_bytes..].iter().zip(suffix) { match suffix_byte.cmp(key_byte) { - Ordering::Less => break, // byte too small - Ordering::Equal => ok_bytes += 1, // new matching byte - Ordering::Greater => return Ok(None), // too far + Ordering::Less => break, // byte too small + Ordering::Equal => ok_bytes += 1, // new matching + // byte + Ordering::Greater => return Ok(TermOrdHit::Next(term_ord)), // too far } } if ok_bytes == key_bytes.len() { if prefix_len + suffix.len() == ok_bytes { - return Ok(Some(term_ord)); + return Ok(TermOrdHit::Exact(term_ord)); } else { // current key is a prefix of current element, not a match - return Ok(None); + return Ok(TermOrdHit::Next(term_ord)); } } term_ord += 1; } - Ok(None) + Ok(TermOrdHit::Next(term_ord)) } /// Returns the ordinal associated with a given term. @@ -312,6 +371,61 @@ impl Dictionary { .map(|opt| opt.map(|ord| ord + first_ordinal)) } + /// Returns the ordinal associated with a given term or its closest next term_id + /// The closest next term_id may not exist. + pub fn term_ord_or_next>(&self, key: K) -> io::Result { + let key_bytes = key.as_ref(); + + let Some(block_addr) = self.sstable_index.get_block_with_key(key_bytes) else { + // TODO: Would be more consistent to return last_term id + 1 + return Ok(TermOrdHit::Next(u64::MAX)); + }; + + let first_ordinal = block_addr.first_ordinal; + let mut sstable_delta_reader = self.sstable_delta_reader_block(block_addr)?; + self.decode_up_to_or_next(key_bytes, &mut sstable_delta_reader) + .map(|opt| opt.map(|ord| ord + first_ordinal)) + } + + /// Converts strings into a Bound range. + /// This does handle several special cases if the term is not exactly in the dictionary. + /// e.g. [bbb, ddd] + /// lower_bound: Bound::Included(aaa) => Included(0) // "Next" term id + /// lower_bound: Bound::Excluded(aaa) => Included(0) // "Next" term id + Change the Bounds + /// lower_bound: Bound::Included(ccc) => Included(1) // "Next" term id + /// lower_bound: Bound::Excluded(ccc) => Included(1) // "Next" term id + Change the Bounds + /// lower_bound: Bound::Included(zzz) => Included(2) // "Next" term id + /// lower_bound: Bound::Excluded(zzz) => Included(2) // "Next" term id + Change the Bounds + /// For zzz we should have some post processing to return an empty query` + /// + /// upper_bound: Bound::Included(aaa) => Excluded(0) // "Next" term id + Change the bounds + /// upper_bound: Bound::Excluded(aaa) => Excluded(0) // "Next" term id + /// upper_bound: Bound::Included(ccc) => Excluded(1) // Next term id + Change the bounds + /// upper_bound: Bound::Excluded(ccc) => Excluded(1) // Next term id + /// upper_bound: Bound::Included(zzz) => Excluded(2) // Next term id + Change the bounds + /// upper_bound: Bound::Excluded(zzz) => Excluded(2) // Next term id + pub fn term_bounds_to_ord>( + &self, + lower_bound: Bound, + upper_bound: Bound, + ) -> io::Result<(Bound, Bound)> { + let lower_bound = map_bound_res(&lower_bound, |start_bound_bytes| { + let ord = self.term_ord_or_next(start_bound_bytes)?; + match ord { + TermOrdHit::Exact(ord) => Ok(map_bound(&lower_bound, |_| ord)), + TermOrdHit::Next(ord) => Ok(Bound::Included(ord)), // Change bounds to included + } + })?; + let upper_bound = map_bound_res(&upper_bound, |end_bound_bytes| { + let ord = self.term_ord_or_next(end_bound_bytes)?; + match ord { + TermOrdHit::Exact(ord) => Ok(map_bound(&upper_bound, |_| ord)), + TermOrdHit::Next(ord) => Ok(Bound::Excluded(ord)), // Change bounds to excluded + } + })?; + Ok((lower_bound, upper_bound)) + } + /// Returns the term associated with a given term ordinal. /// /// Term ordinals are defined as the position of the term in @@ -455,12 +569,13 @@ impl Dictionary { #[cfg(test)] mod tests { - use std::ops::Range; + use std::ops::{Bound, Range}; use std::sync::{Arc, Mutex}; use common::OwnedBytes; use super::Dictionary; + use crate::dictionary::TermOrdHit; use crate::MonotonicU64SSTable; #[derive(Debug)] @@ -524,6 +639,222 @@ mod tests { (dictionary, table) } + #[test] + fn test_term_to_ord_or_next() { + let dict = { + let mut builder = Dictionary::::builder(Vec::new()).unwrap(); + + builder.insert(b"bbb", &1).unwrap(); + builder.insert(b"ddd", &2).unwrap(); + + let table = builder.finish().unwrap(); + let table = Arc::new(PermissionedHandle::new(table)); + let slice = common::file_slice::FileSlice::new(table.clone()); + + Dictionary::::open(slice).unwrap() + }; + + assert_eq!(dict.term_ord_or_next(b"aaa").unwrap(), TermOrdHit::Next(0)); + assert_eq!(dict.term_ord_or_next(b"bbb").unwrap(), TermOrdHit::Exact(0)); + assert_eq!(dict.term_ord_or_next(b"bb").unwrap(), TermOrdHit::Next(0)); + assert_eq!(dict.term_ord_or_next(b"bbbb").unwrap(), TermOrdHit::Next(1)); + assert_eq!(dict.term_ord_or_next(b"dd").unwrap(), TermOrdHit::Next(1)); + assert_eq!(dict.term_ord_or_next(b"ddd").unwrap(), TermOrdHit::Exact(1)); + assert_eq!(dict.term_ord_or_next(b"dddd").unwrap(), TermOrdHit::Next(2)); + + // Shouldn't this be u64::MAX? + assert_eq!( + dict.term_ord_or_next(b"zzzzzzz").unwrap(), + TermOrdHit::Next(2) + ); + } + #[test] + fn test_term_to_ord_or_next_2() { + let dict = { + let mut builder = Dictionary::::builder(Vec::new()).unwrap(); + + let mut term_ord = 0; + builder.insert(b"bbb", &term_ord).unwrap(); + + // Fill blocks in between + for elem in 0..50_000 { + term_ord += 1; + let key = format!("ccccc{elem:05X}").into_bytes(); + builder.insert(&key, &term_ord).unwrap(); + } + + term_ord += 1; + builder.insert(b"eee", &term_ord).unwrap(); + + let table = builder.finish().unwrap(); + let table = Arc::new(PermissionedHandle::new(table)); + let slice = common::file_slice::FileSlice::new(table.clone()); + + Dictionary::::open(slice).unwrap() + }; + + assert_eq!(dict.term_ord(b"bbb").unwrap(), Some(0)); + assert_eq!(dict.term_ord_or_next(b"bbb").unwrap(), TermOrdHit::Exact(0)); + assert_eq!(dict.term_ord_or_next(b"aaa").unwrap(), TermOrdHit::Next(0)); + assert_eq!(dict.term_ord_or_next(b"bb").unwrap(), TermOrdHit::Next(0)); + assert_eq!(dict.term_ord_or_next(b"bbbb").unwrap(), TermOrdHit::Next(1)); + assert_eq!( + dict.term_ord_or_next(b"ee").unwrap(), + TermOrdHit::Next(50001) + ); + assert_eq!( + dict.term_ord_or_next(b"eee").unwrap(), + TermOrdHit::Exact(50001) + ); + assert_eq!( + dict.term_ord_or_next(b"eeee").unwrap(), + TermOrdHit::Next(u64::MAX) + ); + + assert_eq!( + dict.term_ord_or_next(b"zzzzzzz").unwrap(), + TermOrdHit::Next(u64::MAX) + ); + } + + #[test] + fn test_term_bounds_to_ord() { + let dict = { + let mut builder = Dictionary::::builder(Vec::new()).unwrap(); + + builder.insert(b"bbb", &1).unwrap(); + builder.insert(b"ddd", &2).unwrap(); + + let table = builder.finish().unwrap(); + let table = Arc::new(PermissionedHandle::new(table)); + let slice = common::file_slice::FileSlice::new(table.clone()); + + Dictionary::::open(slice).unwrap() + }; + + // Test cases for lower_bound + assert_eq!( + dict.term_bounds_to_ord( + Bound::Included(b"aaa".as_slice()), + Bound::Included(b"ignored") + ) + .unwrap() + .0, + Bound::Included(0) + ); + + assert_eq!( + dict.term_bounds_to_ord( + Bound::Excluded(b"aaa".as_slice()), + Bound::Excluded(b"ignored") + ) + .unwrap() + .0, + Bound::Included(0) + ); + + assert_eq!( + dict.term_bounds_to_ord( + Bound::Included(b"ccc".as_slice()), + Bound::Included(b"ignored") + ) + .unwrap() + .0, + Bound::Included(1) + ); + + assert_eq!( + dict.term_bounds_to_ord( + Bound::Excluded(b"ccc".as_slice()), + Bound::Excluded(b"ignored") + ) + .unwrap() + .0, + Bound::Included(1) + ); + + assert_eq!( + dict.term_bounds_to_ord( + Bound::Included(b"zzz".as_slice()), + Bound::Included(b"ignored") + ) + .unwrap() + .0, + Bound::Included(2) + ); + + assert_eq!( + dict.term_bounds_to_ord( + Bound::Excluded(b"zzz".as_slice()), + Bound::Excluded(b"ignored") + ) + .unwrap() + .0, + Bound::Included(2) + ); + + // Test cases for upper_bound + assert_eq!( + dict.term_bounds_to_ord( + Bound::Included(b"ignored".as_slice()), + Bound::Included(b"ccc") + ) + .unwrap() + .1, + Bound::Excluded(1) + ); + + assert_eq!( + dict.term_bounds_to_ord( + Bound::Excluded(b"ignored".as_slice()), + Bound::Excluded(b"ccc") + ) + .unwrap() + .1, + Bound::Excluded(1) + ); + + assert_eq!( + dict.term_bounds_to_ord( + Bound::Included(b"ignored".as_slice()), + Bound::Included(b"zzz") + ) + .unwrap() + .1, + Bound::Excluded(2) + ); + + assert_eq!( + dict.term_bounds_to_ord( + Bound::Excluded(b"ignored".as_slice()), + Bound::Excluded(b"zzz") + ) + .unwrap() + .1, + Bound::Excluded(2) + ); + + assert_eq!( + dict.term_bounds_to_ord( + Bound::Included(b"ignored".as_slice()), + Bound::Included(b"ddd") + ) + .unwrap() + .1, + Bound::Included(1) + ); + + assert_eq!( + dict.term_bounds_to_ord( + Bound::Excluded(b"ignored".as_slice()), + Bound::Excluded(b"ddd") + ) + .unwrap() + .1, + Bound::Excluded(1) + ); + } + #[test] fn test_ord_term_conversion() { let (dic, slice) = make_test_sstable();