From 13286e9df6ca3410ac8671d5f1ce6a8d564905a2 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 20 Nov 2025 08:07:38 +0800 Subject: [PATCH 1/9] add specialized InList implementations for common scalar types --- .../physical-expr/src/expressions/in_list.rs | 375 ++++++++++++++++-- 1 file changed, 339 insertions(+), 36 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 95029c1efe74..4c6d7c54f54a 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -138,9 +138,29 @@ fn instantiate_static_filter( in_array: ArrayRef, ) -> Result> { match in_array.data_type() { + // Integer primitive types + DataType::Int8 => Ok(Arc::new(Int8StaticFilter::try_new(&in_array)?)), + DataType::Int16 => Ok(Arc::new(Int16StaticFilter::try_new(&in_array)?)), DataType::Int32 => Ok(Arc::new(Int32StaticFilter::try_new(&in_array)?)), + DataType::Int64 => Ok(Arc::new(Int64StaticFilter::try_new(&in_array)?)), + DataType::UInt8 => Ok(Arc::new(UInt8StaticFilter::try_new(&in_array)?)), + DataType::UInt16 => Ok(Arc::new(UInt16StaticFilter::try_new(&in_array)?)), + DataType::UInt32 => Ok(Arc::new(UInt32StaticFilter::try_new(&in_array)?)), + DataType::UInt64 => Ok(Arc::new(UInt64StaticFilter::try_new(&in_array)?)), + // Boolean + DataType::Boolean => Ok(Arc::new(BooleanStaticFilter::try_new(&in_array)?)), + // String types + DataType::Utf8 => Ok(Arc::new(Utf8StaticFilter::try_new(&in_array)?)), + DataType::LargeUtf8 => Ok(Arc::new(LargeUtf8StaticFilter::try_new(&in_array)?)), + DataType::Utf8View => Ok(Arc::new(Utf8ViewStaticFilter::try_new(&in_array)?)), + // Binary types + DataType::Binary => Ok(Arc::new(BinaryStaticFilter::try_new(&in_array)?)), + DataType::LargeBinary => { + Ok(Arc::new(LargeBinaryStaticFilter::try_new(&in_array)?)) + } + DataType::BinaryView => 
Ok(Arc::new(BinaryViewStaticFilter::try_new(&in_array)?)), _ => { - /* fall through to generic implementation */ + /* fall through to generic implementation for unsupported types (Float32/Float64, Struct, etc.) */ Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)) } } @@ -198,18 +218,135 @@ impl ArrayStaticFilter { } } -struct Int32StaticFilter { +// Macro to generate specialized StaticFilter implementations for primitive types +macro_rules! primitive_static_filter { + ($Name:ident, $ArrowType:ty) => { + struct $Name { + null_count: usize, + values: HashSet<<$ArrowType as ArrowPrimitiveType>::Native>, + } + + impl $Name { + fn try_new(in_array: &ArrayRef) -> Result { + let in_array = in_array + .as_primitive_opt::<$ArrowType>() + .ok_or_else(|| exec_datafusion_err!("Failed to downcast array"))?; + + let mut values = HashSet::with_capacity(in_array.len()); + let null_count = in_array.null_count(); + + for v in in_array.iter().flatten() { + values.insert(v); + } + + Ok(Self { null_count, values }) + } + } + + impl StaticFilter for $Name { + fn null_count(&self) -> usize { + self.null_count + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + // Handle dictionary arrays by recursing on the values + downcast_dictionary_array! { + v => { + let values_contains = self.contains(v.values().as_ref(), negated)?; + let result = take(&values_contains, v.keys(), None)?; + return Ok(downcast_array(result.as_ref())) + } + _ => {} + } + + let v = v + .as_primitive_opt::<$ArrowType>() + .ok_or_else(|| exec_datafusion_err!("Failed to downcast array"))?; + + let haystack_has_nulls = self.null_count > 0; + + let result = match (v.null_count() > 0, haystack_has_nulls, negated) { + (true, _, false) | (false, true, false) => { + // Either needle or haystack has nulls, not negated + BooleanArray::from_iter(v.iter().map(|value| { + match value { + // SQL three-valued logic: null IN (...) 
is always null + None => None, + Some(v) => { + if self.values.contains(&v) { + Some(true) + } else if haystack_has_nulls { + // value not in set, but set has nulls -> null + None + } else { + Some(false) + } + } + } + })) + } + (true, _, true) | (false, true, true) => { + // Either needle or haystack has nulls, negated + BooleanArray::from_iter(v.iter().map(|value| { + match value { + // SQL three-valued logic: null NOT IN (...) is always null + None => None, + Some(v) => { + if self.values.contains(&v) { + Some(false) + } else if haystack_has_nulls { + // value not in set, but set has nulls -> null + None + } else { + Some(true) + } + } + } + })) + } + (false, false, false) => { + // no nulls anywhere, not negated + BooleanArray::from_iter( + v.values().iter().map(|value| self.values.contains(value)), + ) + } + (false, false, true) => { + // no nulls anywhere, negated + BooleanArray::from_iter( + v.values().iter().map(|value| !self.values.contains(value)), + ) + } + }; + Ok(result) + } + } + }; +} + +// Generate specialized filters for all integer primitive types +// Note: Float32 and Float64 are excluded because they don't implement Hash/Eq due to NaN +primitive_static_filter!(Int8StaticFilter, Int8Type); +primitive_static_filter!(Int16StaticFilter, Int16Type); +primitive_static_filter!(Int32StaticFilter, Int32Type); +primitive_static_filter!(Int64StaticFilter, Int64Type); +primitive_static_filter!(UInt8StaticFilter, UInt8Type); +primitive_static_filter!(UInt16StaticFilter, UInt16Type); +primitive_static_filter!(UInt32StaticFilter, UInt32Type); +primitive_static_filter!(UInt64StaticFilter, UInt64Type); + +// Boolean static filter +struct BooleanStaticFilter { null_count: usize, - values: HashSet, + values: HashSet, } -impl Int32StaticFilter { +impl BooleanStaticFilter { fn try_new(in_array: &ArrayRef) -> Result { let in_array = in_array - .as_primitive_opt::() + .as_boolean_opt() .ok_or_else(|| exec_datafusion_err!("Failed to downcast array"))?; - let mut values 
= HashSet::with_capacity(in_array.len()); + let mut values = HashSet::with_capacity(in_array.len().min(2)); let null_count = in_array.null_count(); for v in in_array.iter().flatten() { @@ -220,46 +357,215 @@ impl Int32StaticFilter { } } -impl StaticFilter for Int32StaticFilter { +impl StaticFilter for BooleanStaticFilter { fn null_count(&self) -> usize { self.null_count } fn contains(&self, v: &dyn Array, negated: bool) -> Result { + // Handle dictionary arrays by recursing on the values + downcast_dictionary_array! { + v => { + let values_contains = self.contains(v.values().as_ref(), negated)?; + let result = take(&values_contains, v.keys(), None)?; + return Ok(downcast_array(result.as_ref())) + } + _ => {} + } + let v = v - .as_primitive_opt::() + .as_boolean_opt() .ok_or_else(|| exec_datafusion_err!("Failed to downcast array"))?; - let result = match (v.null_count() > 0, negated) { - (true, false) => { - // has nulls, not negated" - BooleanArray::from_iter( - v.iter().map(|value| Some(self.values.contains(&value?))), - ) - } - (true, true) => { - // has nulls, negated - BooleanArray::from_iter( - v.iter().map(|value| Some(!self.values.contains(&value?))), - ) - } - (false, false) => { - //no null, not negated - BooleanArray::from_iter( - v.values().iter().map(|value| self.values.contains(value)), - ) + let haystack_has_nulls = self.null_count > 0; + + let result = match (v.null_count() > 0, haystack_has_nulls, negated) { + (true, _, false) | (false, true, false) => { + BooleanArray::from_iter(v.iter().map(|value| match value { + None => None, + Some(v) => { + if self.values.contains(&v) { + Some(true) + } else if haystack_has_nulls { + None + } else { + Some(false) + } + } + })) } - (false, true) => { - // no null, negated - BooleanArray::from_iter( - v.values().iter().map(|value| !self.values.contains(value)), - ) + (true, _, true) | (false, true, true) => { + BooleanArray::from_iter(v.iter().map(|value| match value { + None => None, + Some(v) => { + if 
self.values.contains(&v) { + Some(false) + } else if haystack_has_nulls { + None + } else { + Some(true) + } + } + })) } + (false, false, false) => BooleanArray::from_iter( + v.values().iter().map(|value| self.values.contains(&value)), + ), + (false, false, true) => BooleanArray::from_iter( + v.values().iter().map(|value| !self.values.contains(&value)), + ), }; Ok(result) } } +// Macro to generate static filter implementations for string and binary types +// This eliminates ~550 lines of duplicated code across 6 implementations +macro_rules! define_static_filter { + ( + $name:ident, + $value_type:ty, + |$arr_param:ident| $downcast:expr, + $convert:ident + ) => { + struct $name { + null_count: usize, + values: HashSet<$value_type>, + } + + impl $name { + fn try_new(in_array: &ArrayRef) -> Result { + let $arr_param = in_array; + let in_array = $downcast + .ok_or_else(|| exec_datafusion_err!("Failed to downcast array"))?; + + let mut values = HashSet::with_capacity(in_array.len()); + let null_count = in_array.null_count(); + + for v in in_array.iter().flatten() { + values.insert(v.$convert()); + } + + Ok(Self { null_count, values }) + } + } + + impl StaticFilter for $name { + fn null_count(&self) -> usize { + self.null_count + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + // Handle dictionary arrays by recursing on the values + downcast_dictionary_array! 
{ + v => { + let values_contains = self.contains(v.values().as_ref(), negated)?; + let result = take(&values_contains, v.keys(), None)?; + return Ok(downcast_array(result.as_ref())) + } + _ => {} + } + + let $arr_param = v; + let v = $downcast + .ok_or_else(|| exec_datafusion_err!("Failed to downcast array"))?; + + let haystack_has_nulls = self.null_count > 0; + + let result = match (v.null_count() > 0, haystack_has_nulls, negated) { + (true, _, false) | (false, true, false) => { + BooleanArray::from_iter(v.iter().map(|value| { + match value { + None => None, + Some(v) => { + if self.values.contains(v) { + Some(true) + } else if haystack_has_nulls { + None + } else { + Some(false) + } + } + } + })) + } + (true, _, true) | (false, true, true) => { + BooleanArray::from_iter(v.iter().map(|value| { + match value { + None => None, + Some(v) => { + if self.values.contains(v) { + Some(false) + } else if haystack_has_nulls { + None + } else { + Some(true) + } + } + } + })) + } + (false, false, false) => { + BooleanArray::from_iter( + v.iter().map(|value| self.values.contains(value.unwrap())), + ) + } + (false, false, true) => { + BooleanArray::from_iter( + v.iter().map(|value| !self.values.contains(value.unwrap())), + ) + } + }; + Ok(result) + } + } + }; +} + +// String static filters +define_static_filter!( + Utf8StaticFilter, + String, + |arr| arr.as_string_opt::(), + to_string +); + +define_static_filter!( + LargeUtf8StaticFilter, + String, + |arr| arr.as_string_opt::(), + to_string +); + +define_static_filter!( + Utf8ViewStaticFilter, + String, + |arr| arr.as_string_view_opt(), + to_string +); + +// Binary static filters +define_static_filter!( + BinaryStaticFilter, + Vec, + |arr| arr.as_binary_opt::(), + to_vec +); + +define_static_filter!( + LargeBinaryStaticFilter, + Vec, + |arr| arr.as_binary_opt::(), + to_vec +); + +define_static_filter!( + BinaryViewStaticFilter, + Vec, + |arr| arr.as_binary_view_opt(), + to_vec +); + /// Evaluates the list of expressions into an 
array, flattening any dictionaries fn evaluate_list( list: &[Arc], @@ -572,11 +878,8 @@ pub fn in_list( // Try to create a static filter for constant expressions let static_filter = try_evaluate_constant_list(&list, schema) - .and_then(ArrayStaticFilter::try_new) - .ok() - .map(|static_filter| { - Arc::new(static_filter) as Arc - }); + .and_then(instantiate_static_filter) + .ok(); Ok(Arc::new(InListExpr::new( expr, From 3763dd0646f52a8ec76ff7fcc81173600bda265a Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 20 Nov 2025 08:17:52 +0800 Subject: [PATCH 2/9] add tests --- .../physical-expr/src/expressions/in_list.rs | 610 +++++++++++++++++- 1 file changed, 608 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 4c6d7c54f54a..89144988f9ec 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -507,12 +507,12 @@ macro_rules! 
define_static_filter { } (false, false, false) => { BooleanArray::from_iter( - v.iter().map(|value| self.values.contains(value.unwrap())), + v.iter().map(|value| self.values.contains(value.expect("null_count is 0"))), ) } (false, false, true) => { BooleanArray::from_iter( - v.iter().map(|value| !self.values.contains(value.unwrap())), + v.iter().map(|value| !self.values.contains(value.expect("null_count is 0"))), ) } }; @@ -1331,6 +1331,612 @@ mod tests { Ok(()) } + #[test] + fn in_list_int8() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int8, true)]); + let a = Int8Array::from(vec![Some(0), Some(2), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in (0, 1)" + let list = vec![lit(0i8), lit(1i8)]; + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1)" + let list = vec![lit(0i8), lit(1i8)]; + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in (0, 1, NULL)" + let list = vec![lit(0i8), lit(1i8), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1, NULL)" + let list = vec![lit(0i8), lit(1i8), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_int16() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int16, true)]); + let a = Int16Array::from(vec![Some(0), Some(2), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in (0, 1)" + let list = vec![lit(0i16), lit(1i16)]; + in_list!( + batch, + list, + &false, + 
vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1)" + let list = vec![lit(0i16), lit(1i16)]; + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in (0, 1, NULL)" + let list = vec![lit(0i16), lit(1i16), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1, NULL)" + let list = vec![lit(0i16), lit(1i16), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_int32() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let a = Int32Array::from(vec![Some(0), Some(2), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in (0, 1)" + let list = vec![lit(0i32), lit(1i32)]; + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1)" + let list = vec![lit(0i32), lit(1i32)]; + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in (0, 1, NULL)" + let list = vec![lit(0i32), lit(1i32), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1, NULL)" + let list = vec![lit(0i32), lit(1i32), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_uint8() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::UInt8, true)]); + let a = UInt8Array::from(vec![Some(0), 
Some(2), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in (0, 1)" + let list = vec![lit(0u8), lit(1u8)]; + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1)" + let list = vec![lit(0u8), lit(1u8)]; + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in (0, 1, NULL)" + let list = vec![lit(0u8), lit(1u8), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1, NULL)" + let list = vec![lit(0u8), lit(1u8), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_uint16() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::UInt16, true)]); + let a = UInt16Array::from(vec![Some(0), Some(2), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in (0, 1)" + let list = vec![lit(0u16), lit(1u16)]; + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1)" + let list = vec![lit(0u16), lit(1u16)]; + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in (0, 1, NULL)" + let list = vec![lit(0u16), lit(1u16), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1, NULL)" + let list = vec![lit(0u16), lit(1u16), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &true, + 
vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_uint32() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::UInt32, true)]); + let a = UInt32Array::from(vec![Some(0), Some(2), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in (0, 1)" + let list = vec![lit(0u32), lit(1u32)]; + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1)" + let list = vec![lit(0u32), lit(1u32)]; + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in (0, 1, NULL)" + let list = vec![lit(0u32), lit(1u32), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1, NULL)" + let list = vec![lit(0u32), lit(1u32), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_uint64() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::UInt64, true)]); + let a = UInt64Array::from(vec![Some(0), Some(2), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in (0, 1)" + let list = vec![lit(0u64), lit(1u64)]; + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1)" + let list = vec![lit(0u64), lit(1u64)]; + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in (0, 1, NULL)" + let list = vec![lit(0u64), lit(1u64), lit(ScalarValue::Null)]; + 
in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1, NULL)" + let list = vec![lit(0u64), lit(1u64), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_large_utf8() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::LargeUtf8, true)]); + let a = LargeStringArray::from(vec![Some("a"), Some("d"), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in ("a", "b")" + let list = vec![lit("a"), lit("b")]; + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in ("a", "b")" + let list = vec![lit("a"), lit("b")]; + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in ("a", "b", null)" + let list = vec![lit("a"), lit("b"), lit(ScalarValue::LargeUtf8(None))]; + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in ("a", "b", null)" + let list = vec![lit("a"), lit("b"), lit(ScalarValue::LargeUtf8(None))]; + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_utf8_view() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8View, true)]); + let a = StringViewArray::from(vec![Some("a"), Some("d"), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in ("a", "b")" + let list = vec![lit("a"), lit("b")]; + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + 
&schema + ); + + // expression: "a not in ("a", "b")" + let list = vec![lit("a"), lit("b")]; + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in ("a", "b", null)" + let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8View(None))]; + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in ("a", "b", null)" + let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8View(None))]; + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_large_binary() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::LargeBinary, true)]); + let a = LargeBinaryArray::from(vec![ + Some([1, 2, 3].as_slice()), + Some([1, 2, 2].as_slice()), + None, + ]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in ([1, 2, 3], [4, 5, 6])" + let list = vec![lit([1, 2, 3].as_slice()), lit([4, 5, 6].as_slice())]; + in_list!( + batch, + list.clone(), + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in ([1, 2, 3], [4, 5, 6])" + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in ([1, 2, 3], [4, 5, 6], null)" + let list = vec![ + lit([1, 2, 3].as_slice()), + lit([4, 5, 6].as_slice()), + lit(ScalarValue::LargeBinary(None)), + ]; + in_list!( + batch, + list.clone(), + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in ([1, 2, 3], [4, 5, 6], null)" + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_binary_view() -> Result<()> { + let schema = 
Schema::new(vec![Field::new("a", DataType::BinaryView, true)]); + let a = BinaryViewArray::from(vec![ + Some([1, 2, 3].as_slice()), + Some([1, 2, 2].as_slice()), + None, + ]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in ([1, 2, 3], [4, 5, 6])" + let list = vec![lit([1, 2, 3].as_slice()), lit([4, 5, 6].as_slice())]; + in_list!( + batch, + list.clone(), + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in ([1, 2, 3], [4, 5, 6])" + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in ([1, 2, 3], [4, 5, 6], null)" + let list = vec![ + lit([1, 2, 3].as_slice()), + lit([4, 5, 6].as_slice()), + lit(ScalarValue::BinaryView(None)), + ]; + in_list!( + batch, + list.clone(), + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in ([1, 2, 3], [4, 5, 6], null)" + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + #[test] fn in_list_date64() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Date64, true)]); From 263377e03314af636b1b0c8f8e956fd3969bbdbf Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 20 Nov 2025 16:46:26 +0800 Subject: [PATCH 3/9] store hashes --- .../physical-expr/src/expressions/in_list.rs | 176 +++++++----------- 1 file changed, 68 insertions(+), 108 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 89144988f9ec..f2dc9295740e 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -419,34 +419,54 @@ impl StaticFilter for BooleanStaticFilter { } } -// Macro to 
generate static filter implementations for string and binary types -// This eliminates ~550 lines of duplicated code across 6 implementations -macro_rules! define_static_filter { - ( - $name:ident, - $value_type:ty, - |$arr_param:ident| $downcast:expr, - $convert:ident - ) => { +// Macro to generate hash-based static filter implementations for string and binary types +// This avoids copying string/binary data by storing only the original array and hash indices +macro_rules! define_hash_based_static_filter { + ($name:ident, |$arr_param:ident| $downcast:expr) => { struct $name { + in_array: ArrayRef, + state: RandomState, + map: HashMap, null_count: usize, - values: HashSet<$value_type>, } impl $name { fn try_new(in_array: &ArrayRef) -> Result { - let $arr_param = in_array; - let in_array = $downcast - .ok_or_else(|| exec_datafusion_err!("Failed to downcast array"))?; - - let mut values = HashSet::with_capacity(in_array.len()); let null_count = in_array.null_count(); + let in_array_clone = Arc::clone(in_array); + let state = RandomState::new(); + let mut map: HashMap = HashMap::with_hasher(()); + + with_hashes([in_array.as_ref()], &state, |hashes| -> Result<()> { + let cmp = make_comparator(in_array, in_array, SortOptions::default())?; + + let insert_value = |idx| { + let hash = hashes[idx]; + if let RawEntryMut::Vacant(v) = map + .raw_entry_mut() + .from_hash(hash, |x| cmp(*x, idx).is_eq()) + { + v.insert_with_hasher(hash, idx, (), |x| hashes[*x]); + } + }; - for v in in_array.iter().flatten() { - values.insert(v.$convert()); - } + match in_array.nulls() { + Some(nulls) => { + BitIndexIterator::new(nulls.validity(), nulls.offset(), nulls.len()) + .for_each(insert_value) + } + None => (0..in_array.len()).for_each(insert_value), + } - Ok(Self { null_count, values }) + Ok(()) + })?; + + Ok(Self { + in_array: in_array_clone, + state, + map, + null_count, + }) } } @@ -466,105 +486,45 @@ macro_rules! 
define_static_filter { _ => {} } - let $arr_param = v; - let v = $downcast - .ok_or_else(|| exec_datafusion_err!("Failed to downcast array"))?; - let haystack_has_nulls = self.null_count > 0; - let result = match (v.null_count() > 0, haystack_has_nulls, negated) { - (true, _, false) | (false, true, false) => { - BooleanArray::from_iter(v.iter().map(|value| { - match value { - None => None, - Some(v) => { - if self.values.contains(v) { - Some(true) - } else if haystack_has_nulls { - None - } else { - Some(false) - } - } - } - })) - } - (true, _, true) | (false, true, true) => { - BooleanArray::from_iter(v.iter().map(|value| { - match value { - None => None, - Some(v) => { - if self.values.contains(v) { - Some(false) - } else if haystack_has_nulls { - None - } else { - Some(true) - } - } - } - })) - } - (false, false, false) => { - BooleanArray::from_iter( - v.iter().map(|value| self.values.contains(value.expect("null_count is 0"))), - ) - } - (false, false, true) => { - BooleanArray::from_iter( - v.iter().map(|value| !self.values.contains(value.expect("null_count is 0"))), - ) - } - }; - Ok(result) + // Use hash-based lookup with verification + with_hashes([v], &self.state, |hashes| { + let cmp = make_comparator(v, &self.in_array, SortOptions::default())?; + + Ok(BooleanArray::from_iter((0..v.len()).map(|i| { + if v.is_null(i) { + return None; + } + + let hash = hashes[i]; + let contains = self + .map + .raw_entry() + .from_hash(hash, |idx| cmp(i, *idx).is_eq()) + .is_some(); + + match contains { + true => Some(!negated), + false if haystack_has_nulls => None, + false => Some(negated), + } + }))) + }) } } }; } // String static filters -define_static_filter!( - Utf8StaticFilter, - String, - |arr| arr.as_string_opt::(), - to_string -); - -define_static_filter!( - LargeUtf8StaticFilter, - String, - |arr| arr.as_string_opt::(), - to_string -); - -define_static_filter!( - Utf8ViewStaticFilter, - String, - |arr| arr.as_string_view_opt(), - to_string -); 
+define_hash_based_static_filter!(Utf8StaticFilter, |arr| arr.as_string_opt::()); +define_hash_based_static_filter!(LargeUtf8StaticFilter, |arr| arr.as_string_opt::()); +define_hash_based_static_filter!(Utf8ViewStaticFilter, |arr| arr.as_string_view_opt()); // Binary static filters -define_static_filter!( - BinaryStaticFilter, - Vec, - |arr| arr.as_binary_opt::(), - to_vec -); - -define_static_filter!( - LargeBinaryStaticFilter, - Vec, - |arr| arr.as_binary_opt::(), - to_vec -); - -define_static_filter!( - BinaryViewStaticFilter, - Vec, - |arr| arr.as_binary_view_opt(), - to_vec -); +define_hash_based_static_filter!(BinaryStaticFilter, |arr| arr.as_binary_opt::()); +define_hash_based_static_filter!(LargeBinaryStaticFilter, |arr| arr.as_binary_opt::()); +define_hash_based_static_filter!(BinaryViewStaticFilter, |arr| arr.as_binary_view_opt()); /// Evaluates the list of expressions into an array, flattening any dictionaries fn evaluate_list( From b48ca3d0679c849c9094518eda1f8c7317878c43 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 20 Nov 2025 18:14:21 +0800 Subject: [PATCH 4/9] remove string types --- .../physical-expr/src/expressions/in_list.rs | 117 ------------------ 1 file changed, 117 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index f2dc9295740e..3d228bd5e2e7 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -149,16 +149,6 @@ fn instantiate_static_filter( DataType::UInt64 => Ok(Arc::new(UInt64StaticFilter::try_new(&in_array)?)), // Boolean DataType::Boolean => Ok(Arc::new(BooleanStaticFilter::try_new(&in_array)?)), - // String types - DataType::Utf8 => Ok(Arc::new(Utf8StaticFilter::try_new(&in_array)?)), - DataType::LargeUtf8 => Ok(Arc::new(LargeUtf8StaticFilter::try_new(&in_array)?)), - DataType::Utf8View => 
Ok(Arc::new(Utf8ViewStaticFilter::try_new(&in_array)?)), - // Binary types - DataType::Binary => Ok(Arc::new(BinaryStaticFilter::try_new(&in_array)?)), - DataType::LargeBinary => { - Ok(Arc::new(LargeBinaryStaticFilter::try_new(&in_array)?)) - } - DataType::BinaryView => Ok(Arc::new(BinaryViewStaticFilter::try_new(&in_array)?)), _ => { /* fall through to generic implementation for unsupported types (Float32/Float64, Struct, etc.) */ Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)) @@ -419,113 +409,6 @@ impl StaticFilter for BooleanStaticFilter { } } -// Macro to generate hash-based static filter implementations for string and binary types -// This avoids copying string/binary data by storing only the original array and hash indices -macro_rules! define_hash_based_static_filter { - ($name:ident, |$arr_param:ident| $downcast:expr) => { - struct $name { - in_array: ArrayRef, - state: RandomState, - map: HashMap, - null_count: usize, - } - - impl $name { - fn try_new(in_array: &ArrayRef) -> Result { - let null_count = in_array.null_count(); - let in_array_clone = Arc::clone(in_array); - let state = RandomState::new(); - let mut map: HashMap = HashMap::with_hasher(()); - - with_hashes([in_array.as_ref()], &state, |hashes| -> Result<()> { - let cmp = make_comparator(in_array, in_array, SortOptions::default())?; - - let insert_value = |idx| { - let hash = hashes[idx]; - if let RawEntryMut::Vacant(v) = map - .raw_entry_mut() - .from_hash(hash, |x| cmp(*x, idx).is_eq()) - { - v.insert_with_hasher(hash, idx, (), |x| hashes[*x]); - } - }; - - match in_array.nulls() { - Some(nulls) => { - BitIndexIterator::new(nulls.validity(), nulls.offset(), nulls.len()) - .for_each(insert_value) - } - None => (0..in_array.len()).for_each(insert_value), - } - - Ok(()) - })?; - - Ok(Self { - in_array: in_array_clone, - state, - map, - null_count, - }) - } - } - - impl StaticFilter for $name { - fn null_count(&self) -> usize { - self.null_count - } - - fn contains(&self, v: &dyn Array, 
negated: bool) -> Result { - // Handle dictionary arrays by recursing on the values - downcast_dictionary_array! { - v => { - let values_contains = self.contains(v.values().as_ref(), negated)?; - let result = take(&values_contains, v.keys(), None)?; - return Ok(downcast_array(result.as_ref())) - } - _ => {} - } - - let haystack_has_nulls = self.null_count > 0; - - // Use hash-based lookup with verification - with_hashes([v], &self.state, |hashes| { - let cmp = make_comparator(v, &self.in_array, SortOptions::default())?; - - Ok(BooleanArray::from_iter((0..v.len()).map(|i| { - if v.is_null(i) { - return None; - } - - let hash = hashes[i]; - let contains = self - .map - .raw_entry() - .from_hash(hash, |idx| cmp(i, *idx).is_eq()) - .is_some(); - - match contains { - true => Some(!negated), - false if haystack_has_nulls => None, - false => Some(negated), - } - }))) - }) - } - } - }; -} - -// String static filters -define_hash_based_static_filter!(Utf8StaticFilter, |arr| arr.as_string_opt::()); -define_hash_based_static_filter!(LargeUtf8StaticFilter, |arr| arr.as_string_opt::()); -define_hash_based_static_filter!(Utf8ViewStaticFilter, |arr| arr.as_string_view_opt()); - -// Binary static filters -define_hash_based_static_filter!(BinaryStaticFilter, |arr| arr.as_binary_opt::()); -define_hash_based_static_filter!(LargeBinaryStaticFilter, |arr| arr.as_binary_opt::()); -define_hash_based_static_filter!(BinaryViewStaticFilter, |arr| arr.as_binary_view_opt()); - /// Evaluates the list of expressions into an array, flattening any dictionaries fn evaluate_list( list: &[Arc], From e983b64347409f9ee761ebed19c651edb714f5e9 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 20 Nov 2025 22:32:53 +0800 Subject: [PATCH 5/9] Apply suggestions from code review Co-authored-by: Martin Grigorov --- datafusion/physical-expr/src/expressions/in_list.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git 
a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 3d228bd5e2e7..afaa9ae93f87 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -220,7 +220,7 @@ macro_rules! primitive_static_filter { fn try_new(in_array: &ArrayRef) -> Result { let in_array = in_array .as_primitive_opt::<$ArrowType>() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast array"))?; + .ok_or_else(|| exec_datafusion_err!(format!("Failed to downcast an array to a '{}' array", stringify!($ArrowType))))?; let mut values = HashSet::with_capacity(in_array.len()); let null_count = in_array.null_count(); @@ -251,7 +251,7 @@ macro_rules! primitive_static_filter { let v = v .as_primitive_opt::<$ArrowType>() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast array"))?; + .ok_or_else(|| exec_datafusion_err!(format!("Failed to downcast an array to a '{}' array", stringify!($ArrowType))))?; let haystack_has_nulls = self.null_count > 0; @@ -334,7 +334,7 @@ impl BooleanStaticFilter { fn try_new(in_array: &ArrayRef) -> Result { let in_array = in_array .as_boolean_opt() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast array"))?; + .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a boolean array"))?; let mut values = HashSet::with_capacity(in_array.len().min(2)); let null_count = in_array.null_count(); @@ -365,7 +365,7 @@ impl StaticFilter for BooleanStaticFilter { let v = v .as_boolean_opt() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast array"))?; + .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a boolean array"))?; let haystack_has_nulls = self.null_count > 0; From ff302de6ad1762bf7c908c7e449e5a77f37d52a3 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 21 Nov 2025 07:22:20 +0800 Subject: [PATCH 6/9] remove boolean specialization --- 
.../physical-expr/src/expressions/in_list.rs | 87 ------------------- 1 file changed, 87 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index afaa9ae93f87..5f6b9ba68406 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -147,8 +147,6 @@ fn instantiate_static_filter( DataType::UInt16 => Ok(Arc::new(UInt16StaticFilter::try_new(&in_array)?)), DataType::UInt32 => Ok(Arc::new(UInt32StaticFilter::try_new(&in_array)?)), DataType::UInt64 => Ok(Arc::new(UInt64StaticFilter::try_new(&in_array)?)), - // Boolean - DataType::Boolean => Ok(Arc::new(BooleanStaticFilter::try_new(&in_array)?)), _ => { /* fall through to generic implementation for unsupported types (Float32/Float64, Struct, etc.) */ Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)) @@ -324,91 +322,6 @@ primitive_static_filter!(UInt16StaticFilter, UInt16Type); primitive_static_filter!(UInt32StaticFilter, UInt32Type); primitive_static_filter!(UInt64StaticFilter, UInt64Type); -// Boolean static filter -struct BooleanStaticFilter { - null_count: usize, - values: HashSet, -} - -impl BooleanStaticFilter { - fn try_new(in_array: &ArrayRef) -> Result { - let in_array = in_array - .as_boolean_opt() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a boolean array"))?; - - let mut values = HashSet::with_capacity(in_array.len().min(2)); - let null_count = in_array.null_count(); - - for v in in_array.iter().flatten() { - values.insert(v); - } - - Ok(Self { null_count, values }) - } -} - -impl StaticFilter for BooleanStaticFilter { - fn null_count(&self) -> usize { - self.null_count - } - - fn contains(&self, v: &dyn Array, negated: bool) -> Result { - // Handle dictionary arrays by recursing on the values - downcast_dictionary_array! 
{ - v => { - let values_contains = self.contains(v.values().as_ref(), negated)?; - let result = take(&values_contains, v.keys(), None)?; - return Ok(downcast_array(result.as_ref())) - } - _ => {} - } - - let v = v - .as_boolean_opt() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a boolean array"))?; - - let haystack_has_nulls = self.null_count > 0; - - let result = match (v.null_count() > 0, haystack_has_nulls, negated) { - (true, _, false) | (false, true, false) => { - BooleanArray::from_iter(v.iter().map(|value| match value { - None => None, - Some(v) => { - if self.values.contains(&v) { - Some(true) - } else if haystack_has_nulls { - None - } else { - Some(false) - } - } - })) - } - (true, _, true) | (false, true, true) => { - BooleanArray::from_iter(v.iter().map(|value| match value { - None => None, - Some(v) => { - if self.values.contains(&v) { - Some(false) - } else if haystack_has_nulls { - None - } else { - Some(true) - } - } - })) - } - (false, false, false) => BooleanArray::from_iter( - v.values().iter().map(|value| self.values.contains(&value)), - ), - (false, false, true) => BooleanArray::from_iter( - v.values().iter().map(|value| !self.values.contains(&value)), - ), - }; - Ok(result) - } -} - /// Evaluates the list of expressions into an array, flattening any dictionaries fn evaluate_list( list: &[Arc], From 1e4782f1cde36a83398002a4730ff96123b547a3 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Sat, 22 Nov 2025 16:43:47 +0800 Subject: [PATCH 7/9] fix, use BooleanBuffer --- .../physical-expr/src/expressions/in_list.rs | 51 ++++++++++++------- 1 file changed, 34 insertions(+), 17 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 5f6b9ba68406..51daa073efa1 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -26,7 
+26,7 @@ use crate::physical_expr::physical_exprs_bag_equal; use crate::PhysicalExpr; use arrow::array::*; -use arrow::buffer::BooleanBuffer; +use arrow::buffer::{BooleanBuffer, NullBuffer}; use arrow::compute::kernels::boolean::{not, or_kleene}; use arrow::compute::{take, SortOptions}; use arrow::datatypes::*; @@ -91,7 +91,12 @@ impl StaticFilter for ArrayStaticFilter { if v.data_type() == &DataType::Null || self.in_array.data_type() == &DataType::Null { - return Ok(BooleanArray::from(vec![None; v.len()])); + // return Ok(BooleanArray::new(vec![None; v.len()])); + let nulls = NullBuffer::new_null(v.len()); + return Ok(BooleanArray::new( + BooleanBuffer::new_unset(v.len()), + Some(nulls), + )); } downcast_dictionary_array! { @@ -218,7 +223,7 @@ macro_rules! primitive_static_filter { fn try_new(in_array: &ArrayRef) -> Result { let in_array = in_array .as_primitive_opt::<$ArrowType>() - .ok_or_else(|| exec_datafusion_err!(format!("Failed to downcast an array to a '{}' array", stringify!($ArrowType))))?; + .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; let mut values = HashSet::with_capacity(in_array.len()); let null_count = in_array.null_count(); @@ -249,7 +254,7 @@ macro_rules! primitive_static_filter { let v = v .as_primitive_opt::<$ArrowType>() - .ok_or_else(|| exec_datafusion_err!(format!("Failed to downcast an array to a '{}' array", stringify!($ArrowType))))?; + .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; let haystack_has_nulls = self.null_count > 0; @@ -294,15 +299,20 @@ macro_rules! 
primitive_static_filter { } (false, false, false) => { // no nulls anywhere, not negated - BooleanArray::from_iter( - v.values().iter().map(|value| self.values.contains(value)), - ) + let values = v.values(); + let mut builder = BooleanBufferBuilder::new(values.len()); + for value in values.iter() { + builder.append(self.values.contains(value)); + } + BooleanArray::new(builder.finish(), None) } (false, false, true) => { - // no nulls anywhere, negated - BooleanArray::from_iter( - v.values().iter().map(|value| !self.values.contains(value)), - ) + let values = v.values(); + let mut builder = BooleanBufferBuilder::new(values.len()); + for value in values.iter() { + builder.append(!self.values.contains(value)); + } + BooleanArray::new(builder.finish(), None) } }; Ok(result) @@ -476,8 +486,12 @@ impl PhysicalExpr for InListExpr { if scalar.is_null() { // SQL three-valued logic: null IN (...) is always null // The code below would handle this correctly but this is a faster path + let nulls = NullBuffer::new_null(num_rows); return Ok(ColumnarValue::Array(Arc::new( - BooleanArray::from(vec![None; num_rows]), + BooleanArray::new( + BooleanBuffer::new_unset(num_rows), + Some(nulls), + ), ))); } // Use a 1 row array to avoid code duplication/branching @@ -488,12 +502,15 @@ impl PhysicalExpr for InListExpr { // Broadcast the single result to all rows // Must check is_null() to preserve NULL values (SQL three-valued logic) if result_array.is_null(0) { - BooleanArray::from(vec![None; num_rows]) + let nulls = NullBuffer::new_null(num_rows); + BooleanArray::new( + BooleanBuffer::new_unset(num_rows), + Some(nulls), + ) + } else if result_array.value(0) { + BooleanArray::new(BooleanBuffer::new_set(num_rows), None) } else { - BooleanArray::from_iter(std::iter::repeat_n( - result_array.value(0), - num_rows, - )) + BooleanArray::new(BooleanBuffer::new_unset(num_rows), None) } } } From d299a91d2d3236b945e0816b09cb19ad14151cef Mon Sep 17 00:00:00 2001 From: Geoffrey Claude Date: Mon, 8 
Dec 2025 20:27:01 +0100 Subject: [PATCH 8/9] Short InList Optimization (#46) --- datafusion/physical-expr/benches/in_list.rs | 169 ++++-- .../physical-expr/src/expressions/in_list.rs | 543 ++++++++++++++---- 2 files changed, 547 insertions(+), 165 deletions(-) diff --git a/datafusion/physical-expr/benches/in_list.rs b/datafusion/physical-expr/benches/in_list.rs index 778204055bbd..664bc2341074 100644 --- a/datafusion/physical-expr/benches/in_list.rs +++ b/datafusion/physical-expr/benches/in_list.rs @@ -15,7 +15,9 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{Array, ArrayRef, Float32Array, Int32Array, StringArray}; +use arrow::array::{ + Array, ArrayRef, Float32Array, Int32Array, StringArray, StringViewArray, +}; use arrow::datatypes::{Field, Schema}; use arrow::record_batch::RecordBatch; use criterion::{criterion_group, criterion_main, Criterion}; @@ -23,9 +25,11 @@ use datafusion_common::ScalarValue; use datafusion_physical_expr::expressions::{col, in_list, lit}; use rand::distr::Alphanumeric; use rand::prelude::*; +use std::any::TypeId; use std::hint::black_box; use std::sync::Arc; +/// Measures how long `in_list(col("a"), exprs)` takes to evaluate against a single RecordBatch. fn do_bench(c: &mut Criterion, name: &str, values: ArrayRef, exprs: &[ScalarValue]) { let schema = Schema::new(vec![Field::new("a", values.data_type().clone(), true)]); let exprs = exprs.iter().map(|s| lit(s.clone())).collect(); @@ -37,78 +41,129 @@ fn do_bench(c: &mut Criterion, name: &str, values: ArrayRef, exprs: &[ScalarValu }); } +/// Generates a random alphanumeric string of the specified length. 
fn random_string(rng: &mut StdRng, len: usize) -> String { let value = rng.sample_iter(&Alphanumeric).take(len).collect(); String::from_utf8(value).unwrap() } -fn do_benches( - c: &mut Criterion, - array_length: usize, - in_list_length: usize, - null_percent: f64, -) { - let mut rng = StdRng::seed_from_u64(120320); - for string_length in [5, 10, 20] { - let values: StringArray = (0..array_length) - .map(|_| { - rng.random_bool(null_percent) - .then(|| random_string(&mut rng, string_length)) - }) - .collect(); - - let in_list: Vec<_> = (0..in_list_length) - .map(|_| ScalarValue::from(random_string(&mut rng, string_length))) - .collect(); - - do_bench( - c, - &format!( - "in_list_utf8({string_length}) ({array_length}, {null_percent}) IN ({in_list_length}, 0)" - ), - Arc::new(values), - &in_list, - ) +const IN_LIST_LENGTHS: [usize; 3] = [3, 8, 100]; +const NULL_PERCENTS: [f64; 2] = [0., 0.2]; +const STRING_LENGTHS: [usize; 3] = [3, 12, 100]; +const ARRAY_LENGTH: usize = 1024; + +/// Returns a friendly type name for the array type. +fn array_type_name() -> &'static str { + let id = TypeId::of::(); + if id == TypeId::of::() { + "Utf8" + } else if id == TypeId::of::() { + "Utf8View" + } else if id == TypeId::of::() { + "Float32" + } else if id == TypeId::of::() { + "Int32" + } else { + "Unknown" } +} - let values: Float32Array = (0..array_length) - .map(|_| rng.random_bool(null_percent).then(|| rng.random())) - .collect(); +/// Builds a benchmark name from array type, list size, and null percentage. +fn bench_name(in_list_length: usize, null_percent: f64) -> String { + format!( + "in_list/{}/list={in_list_length}/nulls={}%", + array_type_name::(), + (null_percent * 100.0) as u32 + ) +} - let in_list: Vec<_> = (0..in_list_length) - .map(|_| ScalarValue::Float32(Some(rng.random()))) - .collect(); +/// Runs in_list benchmarks for a string array type across all list-size × null-ratio × string-length combinations. 
+fn bench_string_type( + c: &mut Criterion, + rng: &mut StdRng, + make_scalar: fn(String) -> ScalarValue, +) where + A: Array + FromIterator> + 'static, +{ + for in_list_length in IN_LIST_LENGTHS { + for null_percent in NULL_PERCENTS { + for string_length in STRING_LENGTHS { + let values: A = (0..ARRAY_LENGTH) + .map(|_| { + rng.random_bool(1.0 - null_percent) + .then(|| random_string(rng, string_length)) + }) + .collect(); - do_bench( - c, - &format!("in_list_f32 ({array_length}, {null_percent}) IN ({in_list_length}, 0)"), - Arc::new(values), - &in_list, - ); + let in_list: Vec<_> = (0..in_list_length) + .map(|_| make_scalar(random_string(rng, string_length))) + .collect(); - let values: Int32Array = (0..array_length) - .map(|_| rng.random_bool(null_percent).then(|| rng.random())) - .collect(); + do_bench( + c, + &format!( + "{}/str={string_length}", + bench_name::(in_list_length, null_percent) + ), + Arc::new(values), + &in_list, + ) + } + } + } +} - let in_list: Vec<_> = (0..in_list_length) - .map(|_| ScalarValue::Int32(Some(rng.random()))) - .collect(); +/// Runs in_list benchmarks for a numeric array type across all list-size × null-ratio combinations. 
+fn bench_numeric_type( + c: &mut Criterion, + rng: &mut StdRng, + mut gen_value: impl FnMut(&mut StdRng) -> T, + make_scalar: fn(T) -> ScalarValue, +) where + A: Array + FromIterator> + 'static, +{ + for in_list_length in IN_LIST_LENGTHS { + for null_percent in NULL_PERCENTS { + let values: A = (0..ARRAY_LENGTH) + .map(|_| rng.random_bool(1.0 - null_percent).then(|| gen_value(rng))) + .collect(); - do_bench( - c, - &format!("in_list_i32 ({array_length}, {null_percent}) IN ({in_list_length}, 0)"), - Arc::new(values), - &in_list, - ) -} + let in_list: Vec<_> = (0..in_list_length) + .map(|_| make_scalar(gen_value(rng))) + .collect(); -fn criterion_benchmark(c: &mut Criterion) { - for in_list_length in [1, 3, 10, 100] { - for null_percent in [0., 0.2] { - do_benches(c, 1024, in_list_length, null_percent) + do_bench( + c, + &bench_name::(in_list_length, null_percent), + Arc::new(values), + &in_list, + ); } } } +/// Entry point: registers in_list benchmarks for Utf8, Utf8View, Float32, and Int32 arrays. 
+fn criterion_benchmark(c: &mut Criterion) { + let mut rng = StdRng::seed_from_u64(120320); + + // Benchmarks for string array types (Utf8, Utf8View) + bench_string_type::(c, &mut rng, |s| ScalarValue::Utf8(Some(s))); + bench_string_type::(c, &mut rng, |s| ScalarValue::Utf8View(Some(s))); + + // Benchmarks for numeric types + bench_numeric_type::( + c, + &mut rng, + |rng| rng.random(), + |v| ScalarValue::Float32(Some(v)), + ); + bench_numeric_type::( + c, + &mut rng, + |rng| rng.random(), + |v| ScalarValue::Int32(Some(v)), + ); +} + criterion_group!(benches, criterion_benchmark); criterion_main!(benches); diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 51daa073efa1..78eaf6fadf4b 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -141,21 +141,58 @@ impl StaticFilter for ArrayStaticFilter { fn instantiate_static_filter( in_array: ArrayRef, +) -> Result> { + if in_array.len() <= SORTED_LOOKUP_MAX_LEN { + instantiate_sorted_filter(in_array) + } else { + instantiate_hashed_filter(in_array) + } +} + +/// Sorted filter using binary search. Best for small lists (≤8 elements). 
+fn instantiate_sorted_filter( + in_array: ArrayRef, ) -> Result> { match in_array.data_type() { - // Integer primitive types - DataType::Int8 => Ok(Arc::new(Int8StaticFilter::try_new(&in_array)?)), - DataType::Int16 => Ok(Arc::new(Int16StaticFilter::try_new(&in_array)?)), - DataType::Int32 => Ok(Arc::new(Int32StaticFilter::try_new(&in_array)?)), - DataType::Int64 => Ok(Arc::new(Int64StaticFilter::try_new(&in_array)?)), - DataType::UInt8 => Ok(Arc::new(UInt8StaticFilter::try_new(&in_array)?)), - DataType::UInt16 => Ok(Arc::new(UInt16StaticFilter::try_new(&in_array)?)), - DataType::UInt32 => Ok(Arc::new(UInt32StaticFilter::try_new(&in_array)?)), - DataType::UInt64 => Ok(Arc::new(UInt64StaticFilter::try_new(&in_array)?)), - _ => { - /* fall through to generic implementation for unsupported types (Float32/Float64, Struct, etc.) */ - Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)) - } + DataType::Int8 => Ok(Arc::new(Int8SortedFilter::try_new(&in_array)?)), + DataType::Int16 => Ok(Arc::new(Int16SortedFilter::try_new(&in_array)?)), + DataType::Int32 => Ok(Arc::new(Int32SortedFilter::try_new(&in_array)?)), + DataType::Int64 => Ok(Arc::new(Int64SortedFilter::try_new(&in_array)?)), + DataType::UInt8 => Ok(Arc::new(UInt8SortedFilter::try_new(&in_array)?)), + DataType::UInt16 => Ok(Arc::new(UInt16SortedFilter::try_new(&in_array)?)), + DataType::UInt32 => Ok(Arc::new(UInt32SortedFilter::try_new(&in_array)?)), + DataType::UInt64 => Ok(Arc::new(UInt64SortedFilter::try_new(&in_array)?)), + DataType::Float32 => Ok(Arc::new(Float32SortedFilter::try_new(&in_array)?)), + DataType::Float64 => Ok(Arc::new(Float64SortedFilter::try_new(&in_array)?)), + DataType::Utf8View => match Utf8ViewSortedFilter::try_new(&in_array) { + Some(Ok(filter)) => Ok(Arc::new(filter)), + Some(Err(e)) => Err(e), + None => Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)), + }, + _ => Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)), + } +} + +/// Hashed filter using HashSet. 
Best for larger lists (>8 elements). +fn instantiate_hashed_filter( + in_array: ArrayRef, +) -> Result> { + match in_array.data_type() { + DataType::Int8 => Ok(Arc::new(Int8HashedFilter::try_new(&in_array)?)), + DataType::Int16 => Ok(Arc::new(Int16HashedFilter::try_new(&in_array)?)), + DataType::Int32 => Ok(Arc::new(Int32HashedFilter::try_new(&in_array)?)), + DataType::Int64 => Ok(Arc::new(Int64HashedFilter::try_new(&in_array)?)), + DataType::UInt8 => Ok(Arc::new(UInt8HashedFilter::try_new(&in_array)?)), + DataType::UInt16 => Ok(Arc::new(UInt16HashedFilter::try_new(&in_array)?)), + DataType::UInt32 => Ok(Arc::new(UInt32HashedFilter::try_new(&in_array)?)), + DataType::UInt64 => Ok(Arc::new(UInt64HashedFilter::try_new(&in_array)?)), + // Floats don't implement Hash, fall through to generic + DataType::Utf8View => match Utf8ViewHashedFilter::try_new(&in_array) { + Some(Ok(filter)) => Ok(Arc::new(filter)), + Some(Err(e)) => Err(e), + None => Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)), + }, + _ => Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)), } } @@ -211,26 +248,184 @@ impl ArrayStaticFilter { } } -// Macro to generate specialized StaticFilter implementations for primitive types -macro_rules! primitive_static_filter { - ($Name:ident, $ArrowType:ty) => { +/// Threshold for switching from sorted Vec (binary search) to HashSet. +/// For small lists, binary search has better cache locality and lower overhead. +/// This is the maximum list size for using sorted lookup (binary search). +/// Lists with more elements use hash lookup instead. +const SORTED_LOOKUP_MAX_LEN: usize = 8; + +/// Helper to build a BooleanArray result for IN list operations. +/// Handles SQL three-valued logic for NULL values.
+/// +/// # Arguments +/// * `len` - Number of elements in the needle array +/// * `needle_nulls` - Optional validity buffer from the needle array +/// * `haystack_has_nulls` - Whether the IN list contains NULL values +/// * `negated` - Whether this is a NOT IN operation +/// * `contains` - Closure that returns whether needle[i] is found in the haystack +#[inline] +fn build_in_list_result( + len: usize, + needle_nulls: Option<&NullBuffer>, + haystack_has_nulls: bool, + negated: bool, + contains: C, +) -> BooleanArray +where + C: Fn(usize) -> bool, +{ + // Use collect_bool for all paths - it's vectorized and faster than element-by-element append. + // Match on (needle_has_nulls, haystack_has_nulls, negated) to specialize each case. + match (needle_nulls, haystack_has_nulls, negated) { + // Haystack has nulls: result is NULL when not found (might match the NULL) + // values_buf == nulls_buf, so compute once and clone + (Some(validity), true, false) => { + let contains_buf = BooleanBuffer::collect_bool(len, |i| contains(i)); + let buf = validity.inner() & &contains_buf; + BooleanArray::new(buf.clone(), Some(NullBuffer::new(buf))) + } + (None, true, false) => { + let buf = BooleanBuffer::collect_bool(len, |i| contains(i)); + BooleanArray::new(buf.clone(), Some(NullBuffer::new(buf))) + } + (Some(validity), true, true) => { + // Compute nulls_buf via SIMD AND, then derive values_buf via XOR. + // Uses identity: A & !B = A ^ (A & B) to get values from nulls. 
+ let contains_buf = BooleanBuffer::collect_bool(len, |i| contains(i)); + let nulls_buf = validity.inner() & &contains_buf; + let values_buf = validity.inner() ^ &nulls_buf; // valid & !contains + BooleanArray::new(values_buf, Some(NullBuffer::new(nulls_buf))) + } + (None, true, true) => { + // No needle nulls, but haystack has nulls + let contains_buf = BooleanBuffer::collect_bool(len, |i| contains(i)); + let values_buf = !&contains_buf; + BooleanArray::new(values_buf, Some(NullBuffer::new(contains_buf))) + } + // Only needle has nulls: nulls_buf is just validity (reuse it directly!) + (Some(validity), false, false) => { + let contains_buf = BooleanBuffer::collect_bool(len, |i| contains(i)); + let values_buf = validity.inner() & &contains_buf; + BooleanArray::new(values_buf, Some(validity.clone())) + } + (Some(validity), false, true) => { + let contains_buf = BooleanBuffer::collect_bool(len, |i| contains(i)); + let values_buf = validity.inner() & &(!&contains_buf); + BooleanArray::new(values_buf, Some(validity.clone())) + } + // No nulls anywhere: no validity buffer needed + (None, false, false) => { + let buf = BooleanBuffer::collect_bool(len, |i| contains(i)); + BooleanArray::new(buf, None) + } + (None, false, true) => { + let buf = BooleanBuffer::collect_bool(len, |i| !contains(i)); + BooleanArray::new(buf, None) + } + } +} + +/// Sorted lookup using binary search. Best for small lists (≤ 8 elements). +struct SortedLookup(Vec); + +impl SortedLookup { + fn new(mut values: Vec) -> Self { + values.sort_unstable(); + values.dedup(); + Self(values) + } + + #[inline] + fn contains(&self, value: &T) -> bool { + self.0.binary_search(value).is_ok() + } +} + +/// Sorted lookup for f32 using total_cmp (floats don't implement Ord due to NaN).
+struct F32SortedLookup(Vec); + +impl F32SortedLookup { + fn new(mut values: Vec) -> Self { + values.sort_unstable_by(|a, b| a.total_cmp(b)); + values.dedup_by(|a, b| a.total_cmp(b).is_eq()); + Self(values) + } + + #[inline] + fn contains(&self, value: &f32) -> bool { + self.0 + .binary_search_by(|probe| probe.total_cmp(value)) + .is_ok() + } +} + +/// Sorted lookup for f64 using total_cmp (floats don't implement Ord due to NaN). +struct F64SortedLookup(Vec); + +impl F64SortedLookup { + fn new(mut values: Vec) -> Self { + values.sort_unstable_by(|a, b| a.total_cmp(b)); + values.dedup_by(|a, b| a.total_cmp(b).is_eq()); + Self(values) + } + + #[inline] + fn contains(&self, value: &f64) -> bool { + self.0 + .binary_search_by(|probe| probe.total_cmp(value)) + .is_ok() + } +} + +/// Hash-based lookup. Best for larger lists (> 8 elements). +struct HashedLookup(HashSet); + +impl HashedLookup { + fn new(values: Vec) -> Self { + Self(values.into_iter().collect()) + } + + #[inline] + fn contains(&self, value: &T) -> bool { + self.0.contains(value) + } +} + +/// Helper macro for dictionary array handling in StaticFilter::contains +/// This pattern is the same across all filter implementations +macro_rules! handle_dictionary { + ($self:ident, $v:ident, $negated:ident) => { + downcast_dictionary_array! { + $v => { + let values_contains = $self.contains($v.values().as_ref(), $negated)?; + let result = take(&values_contains, $v.keys(), None)?; + return Ok(downcast_array(result.as_ref())) + } + _ => {} + } + }; +} + +// Base macro to generate sorted StaticFilter with explicit lookup type. +macro_rules!
sorted_static_filter_impl { + ($Name:ident, $ArrowType:ty, $LookupType:ty) => { struct $Name { null_count: usize, - values: HashSet<<$ArrowType as ArrowPrimitiveType>::Native>, + values: $LookupType, } impl $Name { fn try_new(in_array: &ArrayRef) -> Result { - let in_array = in_array - .as_primitive_opt::<$ArrowType>() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; + let in_array = + in_array.as_primitive_opt::<$ArrowType>().ok_or_else(|| { + exec_datafusion_err!( + "Failed to downcast an array to a '{}' array", + stringify!($ArrowType) + ) + })?; - let mut values = HashSet::with_capacity(in_array.len()); let null_count = in_array.null_count(); - - for v in in_array.iter().flatten() { - values.insert(v); - } + let values = <$LookupType>::new(in_array.iter().flatten().collect()); Ok(Self { null_count, values }) } @@ -242,95 +437,227 @@ macro_rules! primitive_static_filter { } fn contains(&self, v: &dyn Array, negated: bool) -> Result { - // Handle dictionary arrays by recursing on the values - downcast_dictionary_array! 
{ - v => { - let values_contains = self.contains(v.values().as_ref(), negated)?; - let result = take(&values_contains, v.keys(), None)?; - return Ok(downcast_array(result.as_ref())) - } - _ => {} - } + handle_dictionary!(self, v, negated); + + let v = v.as_primitive_opt::<$ArrowType>().ok_or_else(|| { + exec_datafusion_err!( + "Failed to downcast an array to a '{}' array", + stringify!($ArrowType) + ) + })?; + + let values = v.values(); + Ok(build_in_list_result( + v.len(), + v.nulls(), + self.null_count > 0, + negated, + // SAFETY: i < len is guaranteed by build_in_list_result + |i| self.values.contains(unsafe { values.get_unchecked(i) }), + )) + } + } + }; +} - let v = v - .as_primitive_opt::<$ArrowType>() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast an array to a '{}' array", stringify!($ArrowType)))?; - - let haystack_has_nulls = self.null_count > 0; - - let result = match (v.null_count() > 0, haystack_has_nulls, negated) { - (true, _, false) | (false, true, false) => { - // Either needle or haystack has nulls, not negated - BooleanArray::from_iter(v.iter().map(|value| { - match value { - // SQL three-valued logic: null IN (...) is always null - None => None, - Some(v) => { - if self.values.contains(&v) { - Some(true) - } else if haystack_has_nulls { - // value not in set, but set has nulls -> null - None - } else { - Some(false) - } - } - } - })) - } - (true, _, true) | (false, true, true) => { - // Either needle or haystack has nulls, negated - BooleanArray::from_iter(v.iter().map(|value| { - match value { - // SQL three-valued logic: null NOT IN (...) 
is always null - None => None, - Some(v) => { - if self.values.contains(&v) { - Some(false) - } else if haystack_has_nulls { - // value not in set, but set has nulls -> null - None - } else { - Some(true) - } - } - } - })) - } - (false, false, false) => { - // no nulls anywhere, not negated - let values = v.values(); - let mut builder = BooleanBufferBuilder::new(values.len()); - for value in values.iter() { - builder.append(self.values.contains(value)); - } - BooleanArray::new(builder.finish(), None) - } - (false, false, true) => { - let values = v.values(); - let mut builder = BooleanBufferBuilder::new(values.len()); - for value in values.iter() { - builder.append(!self.values.contains(value)); - } - BooleanArray::new(builder.finish(), None) - } - }; - Ok(result) +// Convenience macro for integer types (derives SortedLookup from ArrowType). +macro_rules! sorted_static_filter { + ($Name:ident, $ArrowType:ty) => { + sorted_static_filter_impl!( + $Name, + $ArrowType, + SortedLookup<<$ArrowType as ArrowPrimitiveType>::Native> + ); + }; +} + +// Macro to generate hashed StaticFilter for primitive types using HashedLookup. +macro_rules! 
hashed_static_filter { + ($Name:ident, $ArrowType:ty) => { + struct $Name { + null_count: usize, + values: HashedLookup<<$ArrowType as ArrowPrimitiveType>::Native>, + } + + impl $Name { + fn try_new(in_array: &ArrayRef) -> Result { + let in_array = + in_array.as_primitive_opt::<$ArrowType>().ok_or_else(|| { + exec_datafusion_err!( + "Failed to downcast an array to a '{}' array", + stringify!($ArrowType) + ) + })?; + + let null_count = in_array.null_count(); + let values = HashedLookup::new(in_array.iter().flatten().collect()); + + Ok(Self { null_count, values }) + } + } + + impl StaticFilter for $Name { + fn null_count(&self) -> usize { + self.null_count + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + handle_dictionary!(self, v, negated); + + let v = v.as_primitive_opt::<$ArrowType>().ok_or_else(|| { + exec_datafusion_err!( + "Failed to downcast an array to a '{}' array", + stringify!($ArrowType) + ) + })?; + + let values = v.values(); + Ok(build_in_list_result( + v.len(), + v.nulls(), + self.null_count > 0, + negated, + // SAFETY: i < len is guaranteed by build_in_list_result + |i| self.values.contains(unsafe { values.get_unchecked(i) }), + )) } } }; } -// Generate specialized filters for all integer primitive types -// Note: Float32 and Float64 are excluded because they don't implement Hash/Eq due to NaN -primitive_static_filter!(Int8StaticFilter, Int8Type); -primitive_static_filter!(Int16StaticFilter, Int16Type); -primitive_static_filter!(Int32StaticFilter, Int32Type); -primitive_static_filter!(Int64StaticFilter, Int64Type); -primitive_static_filter!(UInt8StaticFilter, UInt8Type); -primitive_static_filter!(UInt16StaticFilter, UInt16Type); -primitive_static_filter!(UInt32StaticFilter, UInt32Type); -primitive_static_filter!(UInt64StaticFilter, UInt64Type); +// Generate specialized filters for integer types (sorted for small lists, hashed for large). 
+sorted_static_filter!(Int8SortedFilter, Int8Type); +sorted_static_filter!(Int16SortedFilter, Int16Type); +sorted_static_filter!(Int32SortedFilter, Int32Type); +sorted_static_filter!(Int64SortedFilter, Int64Type); +sorted_static_filter!(UInt8SortedFilter, UInt8Type); +sorted_static_filter!(UInt16SortedFilter, UInt16Type); +sorted_static_filter!(UInt32SortedFilter, UInt32Type); +sorted_static_filter!(UInt64SortedFilter, UInt64Type); + +hashed_static_filter!(Int8HashedFilter, Int8Type); +hashed_static_filter!(Int16HashedFilter, Int16Type); +hashed_static_filter!(Int32HashedFilter, Int32Type); +hashed_static_filter!(Int64HashedFilter, Int64Type); +hashed_static_filter!(UInt8HashedFilter, UInt8Type); +hashed_static_filter!(UInt16HashedFilter, UInt16Type); +hashed_static_filter!(UInt32HashedFilter, UInt32Type); +hashed_static_filter!(UInt64HashedFilter, UInt64Type); + +// Float types: sorted only (floats don't implement Hash/Eq due to NaN). +sorted_static_filter_impl!(Float32SortedFilter, Float32Type, F32SortedLookup); +sorted_static_filter_impl!(Float64SortedFilter, Float64Type, F64SortedLookup); + +/// Maximum length for inline strings in Utf8View. +/// Strings ≤12 bytes are stored entirely inline in the u128 view. +const UTF8VIEW_INLINE_LEN: usize = 12; + +/// Extract string length from a StringView u128 representation +/// Layout: bytes 0-3 = length (u32 little-endian), bytes 4-15 = inline data +#[inline] +fn view_len(view: u128) -> usize { + (view as u32) as usize +} + +/// Returns (null_count, views) if all non-null strings are ≤12 bytes, otherwise None. 
+fn collect_short_string_views(in_array: &ArrayRef) -> Option<(usize, Vec)> { + let in_array = in_array.as_string_view_opt()?; + let raw_views = in_array.views(); + + // Check that all non-null strings are ≤12 bytes (inline) + for i in 0..in_array.len() { + if in_array.is_valid(i) && view_len(raw_views[i]) > UTF8VIEW_INLINE_LEN { + return None; // Has long strings, use generic filter + } + } + + let views: Vec = (0..in_array.len()) + .filter(|&i| in_array.is_valid(i)) + .map(|i| raw_views[i]) + .collect(); + + Some((in_array.null_count(), views)) +} + +/// Sorted filter for Utf8View when all values are short (≤12 bytes inline). +/// Uses binary search over sorted raw u128 views. Best for small lists. +struct Utf8ViewSortedFilter { + null_count: usize, + values: SortedLookup, +} + +impl Utf8ViewSortedFilter { + fn try_new(in_array: &ArrayRef) -> Option> { + let (null_count, views) = collect_short_string_views(in_array)?; + Some(Ok(Self { + null_count, + values: SortedLookup::new(views), + })) + } +} + +impl StaticFilter for Utf8ViewSortedFilter { + fn null_count(&self) -> usize { + self.null_count + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + handle_dictionary!(self, v, negated); + + let v = v.as_string_view_opt().ok_or_else(|| { + exec_datafusion_err!("Failed to downcast array to StringViewArray") + })?; + + let views = v.views(); + Ok(build_in_list_result( + v.len(), + v.nulls(), + self.null_count > 0, + negated, + |i| self.values.contains(&views[i]), + )) + } +} + +/// Hashed filter for Utf8View when all values are short (≤12 bytes inline). +/// Uses hash lookup over u128 views. Best for large lists. 
+struct Utf8ViewHashedFilter { + null_count: usize, + values: HashedLookup, +} + +impl Utf8ViewHashedFilter { + fn try_new(in_array: &ArrayRef) -> Option> { + let (null_count, views) = collect_short_string_views(in_array)?; + Some(Ok(Self { + null_count, + values: HashedLookup::new(views), + })) + } +} + +impl StaticFilter for Utf8ViewHashedFilter { + fn null_count(&self) -> usize { + self.null_count + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + handle_dictionary!(self, v, negated); + + let v = v.as_string_view_opt().ok_or_else(|| { + exec_datafusion_err!("Failed to downcast array to StringViewArray") + })?; + + let views = v.views(); + Ok(build_in_list_result( + v.len(), + v.nulls(), + self.null_count > 0, + negated, + |i| self.values.contains(&views[i]), + )) + } +} /// Evaluates the list of expressions into an array, flattening any dictionaries fn evaluate_list( From 0c68f1de93eb38dfafc3d65494a4a2aec0e41a7a Mon Sep 17 00:00:00 2001 From: Geoffrey Claude Date: Tue, 9 Dec 2025 16:59:34 +0100 Subject: [PATCH 9/9] perf(in_list): optimize IN expression with branchless and type-normalized filters Introduce multi-strategy filter selection (branchless/binary/hash) based on list size and data type. Add type reinterpretation to reduce implementations and fast paths for null-free evaluation. 
--- datafusion/physical-expr/benches/in_list.rs | 45 +- .../physical-expr/src/expressions/in_list.rs | 878 ++++++++++-------- 2 files changed, 521 insertions(+), 402 deletions(-) diff --git a/datafusion/physical-expr/benches/in_list.rs b/datafusion/physical-expr/benches/in_list.rs index 664bc2341074..73b40490526f 100644 --- a/datafusion/physical-expr/benches/in_list.rs +++ b/datafusion/physical-expr/benches/in_list.rs @@ -25,7 +25,6 @@ use datafusion_common::ScalarValue; use datafusion_physical_expr::expressions::{col, in_list, lit}; use rand::distr::Alphanumeric; use rand::prelude::*; -use std::any::TypeId; use std::hint::black_box; use std::sync::Arc; @@ -47,36 +46,11 @@ fn random_string(rng: &mut StdRng, len: usize) -> String { String::from_utf8(value).unwrap() } -const IN_LIST_LENGTHS: [usize; 3] = [3, 8, 100]; +const IN_LIST_LENGTHS: [usize; 4] = [3, 6, 8, 100]; const NULL_PERCENTS: [f64; 2] = [0., 0.2]; const STRING_LENGTHS: [usize; 3] = [3, 12, 100]; const ARRAY_LENGTH: usize = 1024; -/// Returns a friendly type name for the array type. -fn array_type_name() -> &'static str { - let id = TypeId::of::(); - if id == TypeId::of::() { - "Utf8" - } else if id == TypeId::of::() { - "Utf8View" - } else if id == TypeId::of::() { - "Float32" - } else if id == TypeId::of::() { - "Int32" - } else { - "Unknown" - } -} - -/// Builds a benchmark name from array type, list size, and null percentage. -fn bench_name(in_list_length: usize, null_percent: f64) -> String { - format!( - "in_list/{}/list={in_list_length}/nulls={}%", - array_type_name::(), - (null_percent * 100.0) as u32 - ) -} - /// Runs in_list benchmarks for a string array type across all list-size × null-ratio × string-length combinations. 
fn bench_string_type( c: &mut Criterion, @@ -94,6 +68,7 @@ fn bench_string_type( .then(|| random_string(rng, string_length)) }) .collect(); + let values: ArrayRef = Arc::new(values); let in_list: Vec<_> = (0..in_list_length) .map(|_| make_scalar(random_string(rng, string_length))) @@ -102,10 +77,11 @@ fn bench_string_type( do_bench( c, &format!( - "{}/str={string_length}", - bench_name::(in_list_length, null_percent) + "in_list/{}/list={in_list_length}/nulls={}%/str={string_length}", + values.data_type(), + (null_percent * 100.0) as u32 ), - Arc::new(values), + values, &in_list, ) } @@ -127,6 +103,7 @@ fn bench_numeric_type( let values: A = (0..ARRAY_LENGTH) .map(|_| rng.random_bool(1.0 - null_percent).then(|| gen_value(rng))) .collect(); + let values: ArrayRef = Arc::new(values); let in_list: Vec<_> = (0..in_list_length) .map(|_| make_scalar(gen_value(rng))) @@ -134,8 +111,12 @@ fn bench_numeric_type( do_bench( c, - &bench_name::(in_list_length, null_percent), - Arc::new(values), + &format!( + "in_list/{}/list={in_list_length}/nulls={}%", + values.data_type(), + (null_percent * 100.0) as u32 + ), + values, &in_list, ); } diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 78eaf6fadf4b..1da2c1f180a5 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -26,7 +26,7 @@ use crate::physical_expr::physical_exprs_bag_equal; use crate::PhysicalExpr; use arrow::array::*; -use arrow::buffer::{BooleanBuffer, NullBuffer}; +use arrow::buffer::{BooleanBuffer, NullBuffer, ScalarBuffer}; use arrow::compute::kernels::boolean::{not, or_kleene}; use arrow::compute::{take, SortOptions}; use arrow::datatypes::*; @@ -87,30 +87,63 @@ impl StaticFilter for ArrayStaticFilter { /// Checks if values in `v` are contained in the `in_array` using this hash set for lookup. 
fn contains(&self, v: &dyn Array, negated: bool) -> Result { + // Handle dictionary arrays + downcast_dictionary_array! { + v => { + let values_contains = self.contains(v.values().as_ref(), negated)?; + let result = take(&values_contains, v.keys(), None)?; + return Ok(downcast_array(result.as_ref())) + } + _ => {} + } + // Null type comparisons always return null (SQL three-valued logic) if v.data_type() == &DataType::Null || self.in_array.data_type() == &DataType::Null { - // return Ok(BooleanArray::new(vec![None; v.len()])); - let nulls = NullBuffer::new_null(v.len()); return Ok(BooleanArray::new( BooleanBuffer::new_unset(v.len()), - Some(nulls), + Some(NullBuffer::new_null(v.len())), )); } - downcast_dictionary_array! { - v => { - let values_contains = self.contains(v.values().as_ref(), negated)?; - let result = take(&values_contains, v.keys(), None)?; - return Ok(downcast_array(result.as_ref())) - } - _ => {} + // Early exit: empty haystack means nothing can match + if self.map.is_empty() { + return if self.in_array.null_count() > 0 { + // Haystack has only nulls -> result is all null + Ok(BooleanArray::new( + BooleanBuffer::new_unset(v.len()), + Some(NullBuffer::new_null(v.len())), + )) + } else { + // Haystack is truly empty -> no matches + Ok(BooleanArray::from(vec![negated; v.len()])) + }; + } + + let haystack_has_nulls = self.in_array.null_count() != 0; + + // Fast path: no nulls in needle or haystack - skip all null handling + if v.null_count() == 0 && !haystack_has_nulls { + return with_hashes([v], &self.state, |hashes| { + let cmp = make_comparator(v, &self.in_array, SortOptions::default())?; + let values: BooleanBuffer = (0..v.len()) + .map(|i| { + let contains = self + .map + .raw_entry() + .from_hash(hashes[i], |idx| cmp(i, *idx).is_eq()) + .is_some(); + contains != negated + }) + .collect(); + Ok(BooleanArray::new(values, None)) + }); } + // Slow path: handle nulls let needle_nulls = v.logical_nulls(); let needle_nulls = needle_nulls.as_ref(); - let 
haystack_has_nulls = self.in_array.null_count() != 0; with_hashes([v], &self.state, |hashes| { let cmp = make_comparator(v, &self.in_array, SortOptions::default())?; @@ -121,11 +154,10 @@ impl StaticFilter for ArrayStaticFilter { return None; } - let hash = hashes[i]; let contains = self .map .raw_entry() - .from_hash(hash, |idx| cmp(i, *idx).is_eq()) + .from_hash(hashes[i], |idx| cmp(i, *idx).is_eq()) .is_some(); match contains { @@ -139,72 +171,253 @@ impl StaticFilter for ArrayStaticFilter { } } -fn instantiate_static_filter( - in_array: ArrayRef, -) -> Result> { - if in_array.len() <= SORTED_LOOKUP_MAX_LEN { - instantiate_sorted_filter(in_array) +// ============================================================================= +// STATIC FILTER ARCHITECTURE +// ============================================================================= +// +// This module provides optimized membership testing for SQL `IN` expressions. +// +// ## Filter Selection +// +// ```text +// instantiate_static_filter(array) +// │ +// ├─ Small primitives (1-8 bytes): +// │ ├─► BranchlessFilter (≤16 elements, branchless OR-chain) +// │ ├─► PrimitiveFilter (17-32, binary search) +// │ └─► PrimitiveFilter (>32, hash set) +// │ +// ├─ Large primitives (16 bytes, e.g., Decimal128): +// │ ├─► BranchlessFilter (≤6 elements) +// │ └─► PrimitiveFilter (>6, hash set) +// │ +// ├─ Utf8View (short strings ≤12 bytes): +// │ └─► Reinterpret as i128, then use Decimal128 filters +// │ +// └─► ArrayStaticFilter (fallback for complex types) +// ``` +// +// ## Type Normalization +// +// For equality comparison, only the bit pattern matters. We normalize types +// to reduce implementations: +// - Signed integers → Unsigned equivalents (Int32 → UInt32) +// - Floats → Unsigned equivalents (Float64 → UInt64) +// - Short Utf8View → Decimal128 (16-byte inline representation) +// +// This is implemented via `TransformingFilter`, which wraps an inner filter +// and transforms input arrays before lookup. 
+ +// ============================================================================= +// LOOKUP STRATEGY THRESHOLDS (tuned via microbenchmarks) +// ============================================================================= + +/// Maximum list size for branchless lookup on small primitives (1-8 bytes). +const SMALL_BRANCHLESS_MAX: usize = 16; + +/// Maximum list size for binary search on small primitives (1-8 bytes). +const SMALL_BINARY_MAX: usize = 32; + +/// Maximum list size for branchless lookup on 16-byte types (Decimal128, short Utf8View). +const LARGE_BRANCHLESS_MAX: usize = 6; + +/// Maximum length for inline strings in Utf8View (stored in 16-byte view). +const UTF8VIEW_INLINE_LEN: usize = 12; + +// ============================================================================= +// FILTER STRATEGY SELECTION +// ============================================================================= + +/// The lookup strategy to use for a given data type and list size. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum FilterStrategy { + Branchless, + Binary, + Hashed, + Generic, +} + +/// Determines the optimal lookup strategy based on data type and list size. +fn select_strategy(dt: &DataType, len: usize) -> FilterStrategy { + let (max_branchless, max_binary) = match dt.primitive_width() { + Some(1 | 2 | 4 | 8) => (SMALL_BRANCHLESS_MAX, SMALL_BINARY_MAX), + Some(16) => (LARGE_BRANCHLESS_MAX, LARGE_BRANCHLESS_MAX), // skip binary for 16-byte + _ => return FilterStrategy::Generic, + }; + if len <= max_branchless { + FilterStrategy::Branchless + } else if len <= max_binary { + FilterStrategy::Binary } else { - instantiate_hashed_filter(in_array) + FilterStrategy::Hashed } } -/// Sorted filter using binary search. Best for small lists (≤8 elements). 
-fn instantiate_sorted_filter( +// ============================================================================= +// FILTER INSTANTIATION +// ============================================================================= + +/// Creates the optimal static filter for the given array. +fn instantiate_static_filter( in_array: ArrayRef, ) -> Result> { - match in_array.data_type() { - DataType::Int8 => Ok(Arc::new(Int8SortedFilter::try_new(&in_array)?)), - DataType::Int16 => Ok(Arc::new(Int16SortedFilter::try_new(&in_array)?)), - DataType::Int32 => Ok(Arc::new(Int32SortedFilter::try_new(&in_array)?)), - DataType::Int64 => Ok(Arc::new(Int64SortedFilter::try_new(&in_array)?)), - DataType::UInt8 => Ok(Arc::new(UInt8SortedFilter::try_new(&in_array)?)), - DataType::UInt16 => Ok(Arc::new(UInt16SortedFilter::try_new(&in_array)?)), - DataType::UInt32 => Ok(Arc::new(UInt32SortedFilter::try_new(&in_array)?)), - DataType::UInt64 => Ok(Arc::new(UInt64SortedFilter::try_new(&in_array)?)), - DataType::Float32 => Ok(Arc::new(Float32SortedFilter::try_new(&in_array)?)), - DataType::Float64 => Ok(Arc::new(Float64SortedFilter::try_new(&in_array)?)), - DataType::Utf8View => match Utf8ViewSortedFilter::try_new(&in_array) { - Some(Ok(filter)) => Ok(Arc::new(filter)), - Some(Err(e)) => Err(e), - None => Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)), - }, - _ => Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)), + let len = in_array.len(); + let dt = in_array.data_type(); + + // Special case: Utf8View with short strings can be reinterpreted as i128 + if matches!(dt, DataType::Utf8View) && utf8view_all_short_strings(in_array.as_ref()) { + return create_utf8view_filter(&in_array, |arr| { + if len <= LARGE_BRANCHLESS_MAX { + instantiate_branchless_filter_for_type::(arr) + } else { + Ok(Arc::new( + PrimitiveFilter::>::try_new(&arr)?, + )) + } + }); + } + + match select_strategy(dt, len) { + FilterStrategy::Branchless => dispatch_filter(&in_array, dispatch_branchless), + FilterStrategy::Binary 
=> dispatch_filter(&in_array, dispatch_sorted), + FilterStrategy::Hashed => dispatch_filter(&in_array, dispatch_hashed), + FilterStrategy::Generic => Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)), } } -/// Hashed filter using HashSet. Best for larger lists (>8 elements). -fn instantiate_hashed_filter( - in_array: ArrayRef, -) -> Result> { - match in_array.data_type() { - DataType::Int8 => Ok(Arc::new(Int8HashedFilter::try_new(&in_array)?)), - DataType::Int16 => Ok(Arc::new(Int16HashedFilter::try_new(&in_array)?)), - DataType::Int32 => Ok(Arc::new(Int32HashedFilter::try_new(&in_array)?)), - DataType::Int64 => Ok(Arc::new(Int64HashedFilter::try_new(&in_array)?)), - DataType::UInt8 => Ok(Arc::new(UInt8HashedFilter::try_new(&in_array)?)), - DataType::UInt16 => Ok(Arc::new(UInt16HashedFilter::try_new(&in_array)?)), - DataType::UInt32 => Ok(Arc::new(UInt32HashedFilter::try_new(&in_array)?)), - DataType::UInt64 => Ok(Arc::new(UInt64HashedFilter::try_new(&in_array)?)), - // Floats don't implement Hash, fall through to generic - DataType::Utf8View => match Utf8ViewHashedFilter::try_new(&in_array) { - Some(Ok(filter)) => Ok(Arc::new(filter)), - Some(Err(e)) => Err(e), - None => Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)), - }, - _ => Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)), +/// Generic filter dispatcher with fallback to ArrayStaticFilter. +fn dispatch_filter( + in_array: &ArrayRef, + dispatch: F, +) -> Result> +where + F: Fn(&ArrayRef) -> Option>>, +{ + dispatch(in_array).unwrap_or_else(|| { + Ok(Arc::new(ArrayStaticFilter::try_new(Arc::clone(in_array))?)) + }) +} + +// ============================================================================= +// TYPE DISPATCH +// ============================================================================= +// +// Dispatches filter creation to the appropriate primitive type. 
+// - Unsigned types: use directly +// - Signed/Float types: reinterpret as unsigned (same bit pattern for equality) + +/// Dispatch macro that routes to the appropriate type-specific filter creation. +/// Uses $direct for unsigned types and $reinterpret for signed/float types. +macro_rules! dispatch_primitive { + ($arr:expr, $direct:ident, $reinterpret:ident) => { + match $arr.data_type() { + DataType::UInt8 => Some($direct::($arr)), + DataType::UInt16 => Some($direct::($arr)), + DataType::UInt32 => Some($direct::($arr)), + DataType::UInt64 => Some($direct::($arr)), + DataType::Int8 => Some($reinterpret::($arr)), + DataType::Int16 => Some($reinterpret::($arr)), + DataType::Int32 => Some($reinterpret::($arr)), + DataType::Int64 => Some($reinterpret::($arr)), + DataType::Float32 => Some($reinterpret::($arr)), + DataType::Float64 => Some($reinterpret::($arr)), + _ => None, + } + }; +} + +fn dispatch_branchless( + arr: &ArrayRef, +) -> Option>> { + fn direct( + arr: &ArrayRef, + ) -> Result> + where + T::Native: Copy + Default + PartialEq + Send + Sync + 'static, + { + instantiate_branchless_filter_for_type::(Arc::clone(arr)) + } + fn reinterpret( + arr: &ArrayRef, + ) -> Result> + where + D::Native: Copy + Default + PartialEq + Send + Sync + 'static, + { + create_reinterpreting_filter::(arr, |a| { + instantiate_branchless_filter_for_type::(a) + }) + } + dispatch_primitive!(arr, direct, reinterpret) +} + +fn dispatch_sorted( + arr: &ArrayRef, +) -> Option>> { + fn direct( + arr: &ArrayRef, + ) -> Result> + where + T::Native: Ord + Send + Sync + 'static, + { + Ok(Arc::new( + PrimitiveFilter::>::try_new(arr)?, + )) + } + fn reinterpret( + arr: &ArrayRef, + ) -> Result> + where + D::Native: Ord + Send + Sync + 'static, + { + create_reinterpreting_filter::(arr, |a| { + Ok(Arc::new( + PrimitiveFilter::>::try_new(&a)?, + )) + }) + } + dispatch_primitive!(arr, direct, reinterpret) +} + +fn dispatch_hashed( + arr: &ArrayRef, +) -> Option>> { + if matches!(arr.data_type(), 
DataType::Decimal128(_, _)) { + return Some( + PrimitiveFilter::>::try_new(arr) + .map(|f| Arc::new(f) as _), + ); + } + fn direct( + arr: &ArrayRef, + ) -> Result> + where + T::Native: Hash + Eq + Send + Sync + 'static, + { + Ok(Arc::new( + PrimitiveFilter::>::try_new(arr)?, + )) + } + fn reinterpret( + arr: &ArrayRef, + ) -> Result> + where + D::Native: Hash + Eq + Send + Sync + 'static, + { + create_reinterpreting_filter::(arr, |a| { + Ok(Arc::new( + PrimitiveFilter::>::try_new(&a)?, + )) + }) } + dispatch_primitive!(arr, direct, reinterpret) } +// ============================================================================= +// GENERIC ARRAY FILTER (fallback for unsupported types) +// ============================================================================= + impl ArrayStaticFilter { - /// Computes a [`StaticFilter`] for the provided [`Array`] if there - /// are nulls present or there are more than the configured number of - /// elements. - /// - /// Note: This is split into a separate function as higher-rank trait bounds currently - /// cause type inference to misbehave fn try_new(in_array: ArrayRef) -> Result { - // Null type has no natural order - return empty hash set if in_array.data_type() == &DataType::Null { return Ok(ArrayStaticFilter { in_array, @@ -218,7 +431,6 @@ impl ArrayStaticFilter { with_hashes([&in_array], &state, |hashes| -> Result<()> { let cmp = make_comparator(&in_array, &in_array, SortOptions::default())?; - let insert_value = |idx| { let hash = hashes[idx]; if let RawEntryMut::Vacant(v) = map @@ -228,7 +440,6 @@ impl ArrayStaticFilter { v.insert_with_hasher(hash, idx, (), |x| hashes[*x]); } }; - match in_array.nulls() { Some(nulls) => { BitIndexIterator::new(nulls.validity(), nulls.offset(), nulls.len()) @@ -236,7 +447,6 @@ impl ArrayStaticFilter { } None => (0..in_array.len()).for_each(insert_value), } - Ok(()) })?; @@ -248,21 +458,20 @@ impl ArrayStaticFilter { } } -/// Threshold for switching from sorted Vec (binary search) to 
HashSet -/// For small lists, binary search has better cache locality and lower overhead -/// Maximum list size for using sorted lookup (binary search). -/// Lists with more elements use hash lookup instead. -const SORTED_LOOKUP_MAX_LEN: usize = 8; +// ============================================================================= +// RESULT BUILDER FOR IN LIST OPERATIONS +// ============================================================================= +// +// Truth table for (needle_nulls, haystack_has_nulls, negated): +// (Some, true, false) → values: valid & contains, nulls: valid & contains +// (None, true, false) → values: contains, nulls: contains +// (Some, true, true) → values: valid ^ (valid & contains), nulls: valid & contains +// (None, true, true) → values: !contains, nulls: contains +// (Some, false, false) → values: valid & contains, nulls: valid +// (Some, false, true) → values: valid & !contains, nulls: valid +// (None, false, false) → values: contains, nulls: none +// (None, false, true) → values: !contains, nulls: none -/// Helper to build a BooleanArray result for IN list operations. -/// Handles SQL three-valued logic for NULL values. -/// -/// # Arguments -/// * `len` - Number of elements in the needle array -/// * `needle_nulls` - Optional validity buffer from the needle array -/// * `haystack_has_nulls` - Whether the IN list contains NULL values -/// * `negated` - Whether this is a NOT IN operation -/// * `contains` - Closure that returns whether needle[i] is found in the haystack #[inline] fn build_in_list_result( len: usize, @@ -274,125 +483,82 @@ fn build_in_list_result( where C: Fn(usize) -> bool, { - // Use collect_bool for all paths - it's vectorized and faster than element-by-element append. - // Match on (needle_has_nulls, haystack_has_nulls, negated) to specialize each case. 
+ let contains_buf = BooleanBuffer::collect_bool(len, &contains); + build_result_from_contains(needle_nulls, haystack_has_nulls, negated, contains_buf) +} + +#[inline] +fn build_result_from_contains( + needle_nulls: Option<&NullBuffer>, + haystack_has_nulls: bool, + negated: bool, + contains_buf: BooleanBuffer, +) -> BooleanArray { match (needle_nulls, haystack_has_nulls, negated) { - // Haystack has nulls: result is NULL when not found (might match the NULL) - // values_buf == nulls_buf, so compute once and clone - (Some(validity), true, false) => { - let contains_buf = BooleanBuffer::collect_bool(len, |i| contains(i)); - let buf = validity.inner() & &contains_buf; + (Some(v), true, false) => { + let buf = v.inner() & &contains_buf; BooleanArray::new(buf.clone(), Some(NullBuffer::new(buf))) } (None, true, false) => { - let buf = BooleanBuffer::collect_bool(len, |i| contains(i)); - BooleanArray::new(buf.clone(), Some(NullBuffer::new(buf))) + BooleanArray::new(contains_buf.clone(), Some(NullBuffer::new(contains_buf))) } - (Some(validity), true, true) => { - // Compute nulls_buf via SIMD AND, then derive values_buf via XOR. - // Uses identity: A & !B = A ^ (A & B) to get values from nulls. - let contains_buf = BooleanBuffer::collect_bool(len, |i| contains(i)); - let nulls_buf = validity.inner() & &contains_buf; - let values_buf = validity.inner() ^ &nulls_buf; // valid & !contains - BooleanArray::new(values_buf, Some(NullBuffer::new(nulls_buf))) + (Some(v), true, true) => { + let nulls = v.inner() & &contains_buf; + BooleanArray::new(v.inner() ^ &nulls, Some(NullBuffer::new(nulls))) } (None, true, true) => { - // No needle nulls, but haystack has nulls - let contains_buf = BooleanBuffer::collect_bool(len, |i| contains(i)); - let values_buf = !&contains_buf; - BooleanArray::new(values_buf, Some(NullBuffer::new(contains_buf))) - } - // Only needle has nulls: nulls_buf is just validity (reuse it directly!) 
- (Some(validity), false, false) => { - let contains_buf = BooleanBuffer::collect_bool(len, |i| contains(i)); - let values_buf = validity.inner() & &contains_buf; - BooleanArray::new(values_buf, Some(validity.clone())) - } - (Some(validity), false, true) => { - let contains_buf = BooleanBuffer::collect_bool(len, |i| contains(i)); - let values_buf = validity.inner() & &(!&contains_buf); - BooleanArray::new(values_buf, Some(validity.clone())) + BooleanArray::new(!&contains_buf, Some(NullBuffer::new(contains_buf))) } - // No nulls anywhere: no validity buffer needed - (None, false, false) => { - let buf = BooleanBuffer::collect_bool(len, |i| contains(i)); - BooleanArray::new(buf, None) + (Some(v), false, false) => { + BooleanArray::new(v.inner() & &contains_buf, Some(v.clone())) } - (None, false, true) => { - let buf = BooleanBuffer::collect_bool(len, |i| !contains(i)); - BooleanArray::new(buf, None) + (Some(v), false, true) => { + BooleanArray::new(v.inner() & &(!&contains_buf), Some(v.clone())) } + (None, false, false) => BooleanArray::new(contains_buf, None), + (None, false, true) => BooleanArray::new(!&contains_buf, None), } } -/// Sorted lookup using binary search. Best for small lists (< 8 elements). +// ============================================================================= +// LOOKUP STRATEGY TRAIT AND IMPLEMENTATIONS +// ============================================================================= + +trait LookupStrategy: Send + Sync { + fn new(values: Vec) -> Self; + fn contains(&self, value: &T) -> bool; +} + struct SortedLookup(Vec); -impl SortedLookup { +impl LookupStrategy for SortedLookup { fn new(mut values: Vec) -> Self { values.sort_unstable(); values.dedup(); Self(values) } - #[inline] fn contains(&self, value: &T) -> bool { self.0.binary_search(value).is_ok() } } -/// Sorted lookup for f32 using total_cmp (floats don't implement Ord due to NaN). 
-struct F32SortedLookup(Vec); - -impl F32SortedLookup { - fn new(mut values: Vec) -> Self { - values.sort_unstable_by(|a, b| a.total_cmp(b)); - values.dedup_by(|a, b| a.total_cmp(b).is_eq()); - Self(values) - } - - #[inline] - fn contains(&self, value: &f32) -> bool { - self.0 - .binary_search_by(|probe| probe.total_cmp(value)) - .is_ok() - } -} - -/// Sorted lookup for f64 using total_cmp (floats don't implement Ord due to NaN). -struct F64SortedLookup(Vec); - -impl F64SortedLookup { - fn new(mut values: Vec) -> Self { - values.sort_unstable_by(|a, b| a.total_cmp(b)); - values.dedup_by(|a, b| a.total_cmp(b).is_eq()); - Self(values) - } - - #[inline] - fn contains(&self, value: &f64) -> bool { - self.0 - .binary_search_by(|probe| probe.total_cmp(value)) - .is_ok() - } -} - -/// Hash-based lookup. Best for larger lists (>= 8 elements). struct HashedLookup(HashSet); -impl HashedLookup { +impl LookupStrategy for HashedLookup { fn new(values: Vec) -> Self { Self(values.into_iter().collect()) } - #[inline] fn contains(&self, value: &T) -> bool { self.0.contains(value) } } -/// Helper macro for dictionary array handling in StaticFilter::contains -/// This pattern is the same across all filter implementations +// ============================================================================= +// DICTIONARY ARRAY HANDLING +// ============================================================================= + macro_rules! handle_dictionary { ($self:ident, $v:ident, $negated:ident) => { downcast_dictionary_array! { @@ -406,259 +572,231 @@ macro_rules! handle_dictionary { }; } -// Base macro to generate sorted StaticFilter with explicit lookup type. -macro_rules! 
sorted_static_filter_impl { - ($Name:ident, $ArrowType:ty, $LookupType:ty) => { - struct $Name { - null_count: usize, - values: $LookupType, - } - - impl $Name { - fn try_new(in_array: &ArrayRef) -> Result { - let in_array = - in_array.as_primitive_opt::<$ArrowType>().ok_or_else(|| { - exec_datafusion_err!( - "Failed to downcast an array to a '{}' array", - stringify!($ArrowType) - ) - })?; - - let null_count = in_array.null_count(); - let values = <$LookupType>::new(in_array.iter().flatten().collect()); - - Ok(Self { null_count, values }) - } - } - - impl StaticFilter for $Name { - fn null_count(&self) -> usize { - self.null_count - } - - fn contains(&self, v: &dyn Array, negated: bool) -> Result { - handle_dictionary!(self, v, negated); - - let v = v.as_primitive_opt::<$ArrowType>().ok_or_else(|| { - exec_datafusion_err!( - "Failed to downcast an array to a '{}' array", - stringify!($ArrowType) - ) - })?; - - let values = v.values(); - Ok(build_in_list_result( - v.len(), - v.nulls(), - self.null_count > 0, - negated, - // SAFETY: i < len is guaranteed by build_in_list_result - |i| self.values.contains(unsafe { values.get_unchecked(i) }), - )) - } - } - }; -} - -// Convenience macro for integer types (derives SortedLookup from ArrowType). -macro_rules! sorted_static_filter { - ($Name:ident, $ArrowType:ty) => { - sorted_static_filter_impl!( - $Name, - $ArrowType, - SortedLookup<<$ArrowType as ArrowPrimitiveType>::Native> - ); - }; -} - -// Macro to generate hashed StaticFilter for primitive types using HashedLookup. -macro_rules! 
hashed_static_filter { - ($Name:ident, $ArrowType:ty) => { - struct $Name { - null_count: usize, - values: HashedLookup<<$ArrowType as ArrowPrimitiveType>::Native>, - } - - impl $Name { - fn try_new(in_array: &ArrayRef) -> Result { - let in_array = - in_array.as_primitive_opt::<$ArrowType>().ok_or_else(|| { - exec_datafusion_err!( - "Failed to downcast an array to a '{}' array", - stringify!($ArrowType) - ) - })?; - - let null_count = in_array.null_count(); - let values = HashedLookup::new(in_array.iter().flatten().collect()); - - Ok(Self { null_count, values }) - } - } - - impl StaticFilter for $Name { - fn null_count(&self) -> usize { - self.null_count - } - - fn contains(&self, v: &dyn Array, negated: bool) -> Result { - handle_dictionary!(self, v, negated); - - let v = v.as_primitive_opt::<$ArrowType>().ok_or_else(|| { - exec_datafusion_err!( - "Failed to downcast an array to a '{}' array", - stringify!($ArrowType) - ) - })?; - - let values = v.values(); - Ok(build_in_list_result( - v.len(), - v.nulls(), - self.null_count > 0, - negated, - // SAFETY: i < len is guaranteed by build_in_list_result - |i| self.values.contains(unsafe { values.get_unchecked(i) }), - )) - } - } - }; -} - -// Generate specialized filters for integer types (sorted for small lists, hashed for large). 
-sorted_static_filter!(Int8SortedFilter, Int8Type); -sorted_static_filter!(Int16SortedFilter, Int16Type); -sorted_static_filter!(Int32SortedFilter, Int32Type); -sorted_static_filter!(Int64SortedFilter, Int64Type); -sorted_static_filter!(UInt8SortedFilter, UInt8Type); -sorted_static_filter!(UInt16SortedFilter, UInt16Type); -sorted_static_filter!(UInt32SortedFilter, UInt32Type); -sorted_static_filter!(UInt64SortedFilter, UInt64Type); - -hashed_static_filter!(Int8HashedFilter, Int8Type); -hashed_static_filter!(Int16HashedFilter, Int16Type); -hashed_static_filter!(Int32HashedFilter, Int32Type); -hashed_static_filter!(Int64HashedFilter, Int64Type); -hashed_static_filter!(UInt8HashedFilter, UInt8Type); -hashed_static_filter!(UInt16HashedFilter, UInt16Type); -hashed_static_filter!(UInt32HashedFilter, UInt32Type); -hashed_static_filter!(UInt64HashedFilter, UInt64Type); - -// Float types: sorted only (floats don't implement Hash/Eq due to NaN). -sorted_static_filter_impl!(Float32SortedFilter, Float32Type, F32SortedLookup); -sorted_static_filter_impl!(Float64SortedFilter, Float64Type, F64SortedLookup); - -/// Maximum length for inline strings in Utf8View. -/// Strings ≤12 bytes are stored entirely inline in the u128 view. -const UTF8VIEW_INLINE_LEN: usize = 12; - -/// Extract string length from a StringView u128 representation -/// Layout: bytes 0-3 = length (u32 little-endian), bytes 4-15 = inline data -#[inline] -fn view_len(view: u128) -> usize { - (view as u32) as usize -} - -/// Returns (null_count, views) if all non-null strings are ≤12 bytes, otherwise None. 
-fn collect_short_string_views(in_array: &ArrayRef) -> Option<(usize, Vec)> { - let in_array = in_array.as_string_view_opt()?; - let raw_views = in_array.views(); - - // Check that all non-null strings are ≤12 bytes (inline) - for i in 0..in_array.len() { - if in_array.is_valid(i) && view_len(raw_views[i]) > UTF8VIEW_INLINE_LEN { - return None; // Has long strings, use generic filter - } - } - - let views: Vec = (0..in_array.len()) - .filter(|&i| in_array.is_valid(i)) - .map(|i| raw_views[i]) - .collect(); +// ============================================================================= +// UNIFIED PRIMITIVE FILTER +// ============================================================================= - Some((in_array.null_count(), views)) -} - -/// Sorted filter for Utf8View when all values are short (≤12 bytes inline). -/// Uses binary search over sorted raw u128 views. Best for small lists. -struct Utf8ViewSortedFilter { +struct PrimitiveFilter { null_count: usize, - values: SortedLookup, + lookup: S, + _phantom: std::marker::PhantomData T>, } -impl Utf8ViewSortedFilter { - fn try_new(in_array: &ArrayRef) -> Option> { - let (null_count, views) = collect_short_string_views(in_array)?; - Some(Ok(Self { - null_count, - values: SortedLookup::new(views), - })) +impl> PrimitiveFilter { + fn try_new(in_array: &ArrayRef) -> Result { + let arr = in_array.as_primitive_opt::().ok_or_else(|| { + exec_datafusion_err!( + "PrimitiveFilter: expected {} array", + std::any::type_name::() + ) + })?; + Ok(Self { + null_count: arr.null_count(), + lookup: S::new(arr.iter().flatten().collect()), + _phantom: std::marker::PhantomData, + }) } } -impl StaticFilter for Utf8ViewSortedFilter { +impl StaticFilter for PrimitiveFilter +where + T: ArrowPrimitiveType + 'static, + T::Native: Send + Sync + 'static, + S: LookupStrategy + 'static, +{ fn null_count(&self) -> usize { self.null_count } fn contains(&self, v: &dyn Array, negated: bool) -> Result { handle_dictionary!(self, v, negated); - - let v 
= v.as_string_view_opt().ok_or_else(|| { - exec_datafusion_err!("Failed to downcast array to StringViewArray") + let v = v.as_primitive_opt::().ok_or_else(|| { + exec_datafusion_err!( + "PrimitiveFilter: expected {} array", + std::any::type_name::() + ) })?; - - let views = v.views(); + let values = v.values(); Ok(build_in_list_result( v.len(), v.nulls(), self.null_count > 0, negated, - |i| self.values.contains(&views[i]), + |i| self.lookup.contains(unsafe { values.get_unchecked(i) }), )) } } -/// Hashed filter for Utf8View when all values are short (≤12 bytes inline). -/// Uses hash lookup over u128 views. Best for large lists. -struct Utf8ViewHashedFilter { +// ============================================================================= +// BRANCHLESS FILTER (Const Generic for Small Lists) +// ============================================================================= + +struct BranchlessFilter { null_count: usize, - values: HashedLookup, + values: [T::Native; N], } -impl Utf8ViewHashedFilter { +impl BranchlessFilter +where + T::Native: Copy + Default + PartialEq, +{ fn try_new(in_array: &ArrayRef) -> Option> { - let (null_count, views) = collect_short_string_views(in_array)?; + let in_array = in_array.as_primitive_opt::()?; + let non_null_count = in_array.len() - in_array.null_count(); + if non_null_count != N { + return None; + } + let values: Vec<_> = in_array.iter().flatten().collect(); + let mut arr = [T::Native::default(); N]; + arr.copy_from_slice(&values); Some(Ok(Self { - null_count, - values: HashedLookup::new(views), + null_count: in_array.null_count(), + values: arr, })) } + + #[inline(always)] + fn check(&self, needle: T::Native) -> bool { + self.values + .iter() + .fold(false, |acc, &v| acc | (v == needle)) + } } -impl StaticFilter for Utf8ViewHashedFilter { +impl StaticFilter for BranchlessFilter +where + T::Native: Copy + Default + PartialEq + Send + Sync, +{ fn null_count(&self) -> usize { self.null_count } fn contains(&self, v: &dyn Array, 
negated: bool) -> Result { handle_dictionary!(self, v, negated); - - let v = v.as_string_view_opt().ok_or_else(|| { - exec_datafusion_err!("Failed to downcast array to StringViewArray") + let v = v.as_primitive_opt::().ok_or_else(|| { + exec_datafusion_err!("Failed to downcast array to primitive type") })?; - - let views = v.views(); + let input_values = v.values(); Ok(build_in_list_result( v.len(), v.nulls(), self.null_count > 0, negated, - |i| self.values.contains(&views[i]), + #[inline(always)] + |i| self.check(unsafe { *input_values.get_unchecked(i) }), )) } } +fn instantiate_branchless_filter_for_type( + in_array: ArrayRef, +) -> Result> +where + T::Native: Copy + Default + PartialEq + Send + Sync + 'static, +{ + macro_rules! try_branchless { + ($($n:literal),*) => { + $(if let Some(Ok(f)) = BranchlessFilter::::try_new(&in_array) { + return Ok(Arc::new(f)); + })* + }; + } + try_branchless!(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16); + Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)) +} + +// ============================================================================= +// TRANSFORMING FILTER (Unified Type Reinterpretation) +// ============================================================================= + +struct TransformingFilter { + inner: Arc, + transform: F, +} + +impl StaticFilter for TransformingFilter +where + F: Fn(&dyn Array) -> ArrayRef + Send + Sync, +{ + fn null_count(&self) -> usize { + self.inner.null_count() + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + handle_dictionary!(self, v, negated); + self.inner.contains((self.transform)(v).as_ref(), negated) + } +} + +/// Creates a TransformingFilter that reinterprets arrays from type S to type D. 
+fn create_reinterpreting_filter( + in_array: &ArrayRef, + create_inner: F, +) -> Result> +where + S: ArrowPrimitiveType + 'static, + D: ArrowPrimitiveType + 'static, + F: FnOnce(ArrayRef) -> Result>, +{ + let reinterpreted = reinterpret_primitive::(in_array.as_ref()); + let inner = create_inner(reinterpreted)?; + Ok(Arc::new(TransformingFilter { + inner, + transform: |v: &dyn Array| reinterpret_primitive::(v), + })) +} + +/// Creates a TransformingFilter for Utf8View arrays reinterpreted as Decimal128. +fn create_utf8view_filter( + in_array: &ArrayRef, + create_inner: F, +) -> Result> +where + F: FnOnce(ArrayRef) -> Result>, +{ + let reinterpreted = reinterpret_utf8view_as_decimal128(in_array.as_ref()); + let inner = create_inner(reinterpreted)?; + Ok(Arc::new(TransformingFilter { + inner, + transform: reinterpret_utf8view_as_decimal128, + })) +} + +// ============================================================================= +// PRIMITIVE TYPE REINTERPRETATION +// ============================================================================= + +#[inline] +fn reinterpret_primitive( + array: &dyn Array, +) -> ArrayRef { + let source = array.as_primitive::(); + let buffer: ScalarBuffer = source.values().inner().clone().into(); + Arc::new(PrimitiveArray::::new(buffer, source.nulls().cloned())) +} + +// ============================================================================= +// UTF8VIEW REINTERPRETATION (short strings ≤12 bytes → Decimal128) +// ============================================================================= + +#[inline] +fn utf8view_all_short_strings(array: &dyn Array) -> bool { + let sv = array.as_string_view(); + sv.views().iter().enumerate().all(|(i, &view)| { + !sv.is_valid(i) || (view as u32) as usize <= UTF8VIEW_INLINE_LEN + }) +} + +#[inline] +fn reinterpret_utf8view_as_decimal128(array: &dyn Array) -> ArrayRef { + let sv = array.as_string_view(); + let buffer: ScalarBuffer = sv.views().inner().clone().into(); + 
Arc::new(PrimitiveArray::::new( + buffer, + sv.nulls().cloned(), + )) +} + /// Evaluates the list of expressions into an array, flattening any dictionaries fn evaluate_list( list: &[Arc],