diff --git a/datafusion/physical-expr/benches/in_list.rs b/datafusion/physical-expr/benches/in_list.rs index 778204055bbd..73b40490526f 100644 --- a/datafusion/physical-expr/benches/in_list.rs +++ b/datafusion/physical-expr/benches/in_list.rs @@ -15,7 +15,9 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{Array, ArrayRef, Float32Array, Int32Array, StringArray}; +use arrow::array::{ + Array, ArrayRef, Float32Array, Int32Array, StringArray, StringViewArray, +}; use arrow::datatypes::{Field, Schema}; use arrow::record_batch::RecordBatch; use criterion::{criterion_group, criterion_main, Criterion}; @@ -26,6 +28,7 @@ use rand::prelude::*; use std::hint::black_box; use std::sync::Arc; +/// Measures how long `in_list(col("a"), exprs)` takes to evaluate against a single RecordBatch. fn do_bench(c: &mut Criterion, name: &str, values: ArrayRef, exprs: &[ScalarValue]) { let schema = Schema::new(vec![Field::new("a", values.data_type().clone(), true)]); let exprs = exprs.iter().map(|s| lit(s.clone())).collect(); @@ -37,77 +40,110 @@ fn do_bench(c: &mut Criterion, name: &str, values: ArrayRef, exprs: &[ScalarValu }); } +/// Generates a random alphanumeric string of the specified length. fn random_string(rng: &mut StdRng, len: usize) -> String { let value = rng.sample_iter(&Alphanumeric).take(len).collect(); String::from_utf8(value).unwrap() } -fn do_benches( +const IN_LIST_LENGTHS: [usize; 4] = [3, 6, 8, 100]; +const NULL_PERCENTS: [f64; 2] = [0., 0.2]; +const STRING_LENGTHS: [usize; 3] = [3, 12, 100]; +const ARRAY_LENGTH: usize = 1024; + +/// Runs in_list benchmarks for a string array type across all list-size × null-ratio × string-length combinations. +fn bench_string_type( c: &mut Criterion, - array_length: usize, - in_list_length: usize, - null_percent: f64, -) { - let mut rng = StdRng::seed_from_u64(120320); - for string_length in [5, 10, 20] { - let values: StringArray = (0..array_length) - .map(|_| { - rng.random_bool(null_percent) - .then(|| random_string(&mut rng, string_length)) - }) - .collect(); + rng: &mut StdRng, + make_scalar: fn(String) -> ScalarValue, +) where + A: Array + FromIterator> + 'static, +{ + for in_list_length in IN_LIST_LENGTHS { + for null_percent in NULL_PERCENTS { + for string_length in STRING_LENGTHS { + let values: A = (0..ARRAY_LENGTH) + .map(|_| { + rng.random_bool(1.0 - null_percent) + .then(|| random_string(rng, string_length)) + }) + .collect(); + let values: ArrayRef = Arc::new(values); - let in_list: Vec<_> = (0..in_list_length) - .map(|_| ScalarValue::from(random_string(&mut rng, string_length))) - .collect(); + let in_list: Vec<_> = (0..in_list_length) + .map(|_| make_scalar(random_string(rng, string_length))) + .collect(); - do_bench( - c, - &format!( - "in_list_utf8({string_length}) ({array_length}, {null_percent}) IN ({in_list_length}, 0)" - ), - Arc::new(values), - &in_list, - ) + do_bench( + c, + &format!( + "in_list/{}/list={in_list_length}/nulls={}%/str={string_length}", + values.data_type(), + (null_percent * 100.0) as u32 + ), + values, + &in_list, + ) + } + } } +} - let values: Float32Array = (0..array_length) - .map(|_| rng.random_bool(null_percent).then(|| rng.random())) - .collect(); +/// Runs in_list benchmarks for a numeric array type across all list-size × null-ratio combinations. +fn bench_numeric_type( + c: &mut Criterion, + rng: &mut StdRng, + mut gen_value: impl FnMut(&mut StdRng) -> T, + make_scalar: fn(T) -> ScalarValue, +) where + A: Array + FromIterator> + 'static, +{ + for in_list_length in IN_LIST_LENGTHS { + for null_percent in NULL_PERCENTS { + let values: A = (0..ARRAY_LENGTH) + .map(|_| rng.random_bool(1.0 - null_percent).then(|| gen_value(rng))) + .collect(); + let values: ArrayRef = Arc::new(values); - let in_list: Vec<_> = (0..in_list_length) - .map(|_| ScalarValue::Float32(Some(rng.random()))) - .collect(); + let in_list: Vec<_> = (0..in_list_length) + .map(|_| make_scalar(gen_value(rng))) + .collect(); - do_bench( - c, - &format!("in_list_f32 ({array_length}, {null_percent}) IN ({in_list_length}, 0)"), - Arc::new(values), - &in_list, - ); + do_bench( + c, + &format!( + "in_list/{}/list={in_list_length}/nulls={}%", + values.data_type(), + (null_percent * 100.0) as u32 + ), + values, + &in_list, + ); + } + } +} - let values: Int32Array = (0..array_length) - .map(|_| rng.random_bool(null_percent).then(|| rng.random())) - .collect(); +/// Entry point: registers in_list benchmarks for Utf8, Utf8View, Float32, and Int32 arrays. +fn criterion_benchmark(c: &mut Criterion) { + let mut rng = StdRng::seed_from_u64(120320); - let in_list: Vec<_> = (0..in_list_length) - .map(|_| ScalarValue::Int32(Some(rng.random()))) - .collect(); + // Benchmarks for string array types (Utf8, Utf8View) + bench_string_type::(c, &mut rng, |s| ScalarValue::Utf8(Some(s))); + bench_string_type::(c, &mut rng, |s| ScalarValue::Utf8View(Some(s))); - do_bench( + // Benchmarks for numeric types + bench_numeric_type::( c, - &format!("in_list_i32 ({array_length}, {null_percent}) IN ({in_list_length}, 0)"), - Arc::new(values), - &in_list, - ) -} - -fn criterion_benchmark(c: &mut Criterion) { - for in_list_length in [1, 3, 10, 100] { - for null_percent in [0., 0.2] { - do_benches(c, 1024, in_list_length, null_percent) - } - } + &mut rng, + |rng| rng.random(), + |v| ScalarValue::Float32(Some(v)), + ); + bench_numeric_type::( + c, + &mut rng, + |rng| rng.random(), + |v| ScalarValue::Int32(Some(v)), + ); } criterion_group!(benches, criterion_benchmark); diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 95029c1efe74..1da2c1f180a5 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -26,7 +26,7 @@ use crate::physical_expr::physical_exprs_bag_equal; use crate::PhysicalExpr; use arrow::array::*; -use arrow::buffer::BooleanBuffer; +use arrow::buffer::{BooleanBuffer, NullBuffer, ScalarBuffer}; use arrow::compute::kernels::boolean::{not, or_kleene}; use arrow::compute::{take, SortOptions}; use arrow::datatypes::*; @@ -87,13 +87,7 @@ impl StaticFilter for ArrayStaticFilter { /// Checks if values in `v` are contained in the `in_array` using this hash set for lookup. fn contains(&self, v: &dyn Array, negated: bool) -> Result { - // Null type comparisons always return null (SQL three-valued logic) - if v.data_type() == &DataType::Null - || self.in_array.data_type() == &DataType::Null - { - return Ok(BooleanArray::from(vec![None; v.len()])); - } - + // Handle dictionary arrays downcast_dictionary_array! { v => { let values_contains = self.contains(v.values().as_ref(), negated)?; @@ -103,9 +97,53 @@ impl StaticFilter for ArrayStaticFilter { _ => {} } + // Null type comparisons always return null (SQL three-valued logic) + if v.data_type() == &DataType::Null + || self.in_array.data_type() == &DataType::Null + { + return Ok(BooleanArray::new( + BooleanBuffer::new_unset(v.len()), + Some(NullBuffer::new_null(v.len())), + )); + } + + // Early exit: empty haystack means nothing can match + if self.map.is_empty() { + return if self.in_array.null_count() > 0 { + // Haystack has only nulls -> result is all null + Ok(BooleanArray::new( + BooleanBuffer::new_unset(v.len()), + Some(NullBuffer::new_null(v.len())), + )) + } else { + // Haystack is truly empty -> no matches + Ok(BooleanArray::from(vec![negated; v.len()])) + }; + } + + let haystack_has_nulls = self.in_array.null_count() != 0; + + // Fast path: no nulls in needle or haystack - skip all null handling + if v.null_count() == 0 && !haystack_has_nulls { + return with_hashes([v], &self.state, |hashes| { + let cmp = make_comparator(v, &self.in_array, SortOptions::default())?; + let values: BooleanBuffer = (0..v.len()) + .map(|i| { + let contains = self + .map + .raw_entry() + .from_hash(hashes[i], |idx| cmp(i, *idx).is_eq()) + .is_some(); + contains != negated + }) + .collect(); + Ok(BooleanArray::new(values, None)) + }); + } + + // Slow path: handle nulls let needle_nulls = v.logical_nulls(); let needle_nulls = needle_nulls.as_ref(); - let haystack_has_nulls = self.in_array.null_count() != 0; with_hashes([v], &self.state, |hashes| { let cmp = make_comparator(v, &self.in_array, SortOptions::default())?; @@ -116,11 +154,10 @@ impl StaticFilter for ArrayStaticFilter { return None; } - let hash = hashes[i]; let contains = self .map .raw_entry() - .from_hash(hash, |idx| cmp(i, *idx).is_eq()) + .from_hash(hashes[i], |idx| cmp(i, *idx).is_eq()) .is_some(); match contains { @@ -134,27 +171,253 @@ impl StaticFilter for ArrayStaticFilter { } } +// ============================================================================= +// STATIC FILTER ARCHITECTURE +// ============================================================================= +// +// This module provides optimized membership testing for SQL `IN` expressions. +// +// ## Filter Selection +// +// ```text +// instantiate_static_filter(array) +// │ +// ├─ Small primitives (1-8 bytes): +// │ ├─► BranchlessFilter (≤16 elements, branchless OR-chain) +// │ ├─► PrimitiveFilter (17-32, binary search) +// │ └─► PrimitiveFilter (>32, hash set) +// │ +// ├─ Large primitives (16 bytes, e.g., Decimal128): +// │ ├─► BranchlessFilter (≤6 elements) +// │ └─► PrimitiveFilter (>6, hash set) +// │ +// ├─ Utf8View (short strings ≤12 bytes): +// │ └─► Reinterpret as i128, then use Decimal128 filters +// │ +// └─► ArrayStaticFilter (fallback for complex types) +// ``` +// +// ## Type Normalization +// +// For equality comparison, only the bit pattern matters. We normalize types +// to reduce implementations: +// - Signed integers → Unsigned equivalents (Int32 → UInt32) +// - Floats → Unsigned equivalents (Float64 → UInt64) +// - Short Utf8View → Decimal128 (16-byte inline representation) +// +// This is implemented via `TransformingFilter`, which wraps an inner filter +// and transforms input arrays before lookup. + +// ============================================================================= +// LOOKUP STRATEGY THRESHOLDS (tuned via microbenchmarks) +// ============================================================================= + +/// Maximum list size for branchless lookup on small primitives (1-8 bytes). +const SMALL_BRANCHLESS_MAX: usize = 16; + +/// Maximum list size for binary search on small primitives (1-8 bytes). +const SMALL_BINARY_MAX: usize = 32; + +/// Maximum list size for branchless lookup on 16-byte types (Decimal128, short Utf8View). +const LARGE_BRANCHLESS_MAX: usize = 6; + +/// Maximum length for inline strings in Utf8View (stored in 16-byte view). +const UTF8VIEW_INLINE_LEN: usize = 12; + +// ============================================================================= +// FILTER STRATEGY SELECTION +// ============================================================================= + +/// The lookup strategy to use for a given data type and list size. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum FilterStrategy { + Branchless, + Binary, + Hashed, + Generic, +} + +/// Determines the optimal lookup strategy based on data type and list size. +fn select_strategy(dt: &DataType, len: usize) -> FilterStrategy { + let (max_branchless, max_binary) = match dt.primitive_width() { + Some(1 | 2 | 4 | 8) => (SMALL_BRANCHLESS_MAX, SMALL_BINARY_MAX), + Some(16) => (LARGE_BRANCHLESS_MAX, LARGE_BRANCHLESS_MAX), // skip binary for 16-byte + _ => return FilterStrategy::Generic, + }; + if len <= max_branchless { + FilterStrategy::Branchless + } else if len <= max_binary { + FilterStrategy::Binary + } else { + FilterStrategy::Hashed + } +} + +// ============================================================================= +// FILTER INSTANTIATION +// ============================================================================= + +/// Creates the optimal static filter for the given array. fn instantiate_static_filter( in_array: ArrayRef, ) -> Result> { - match in_array.data_type() { - DataType::Int32 => Ok(Arc::new(Int32StaticFilter::try_new(&in_array)?)), - _ => { - /* fall through to generic implementation */ - Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)) + let len = in_array.len(); + let dt = in_array.data_type(); + + // Special case: Utf8View with short strings can be reinterpreted as i128 + if matches!(dt, DataType::Utf8View) && utf8view_all_short_strings(in_array.as_ref()) { + return create_utf8view_filter(&in_array, |arr| { + if len <= LARGE_BRANCHLESS_MAX { + instantiate_branchless_filter_for_type::(arr) + } else { + Ok(Arc::new( + PrimitiveFilter::>::try_new(&arr)?, + )) + } + }); + } + + match select_strategy(dt, len) { + FilterStrategy::Branchless => dispatch_filter(&in_array, dispatch_branchless), + FilterStrategy::Binary => dispatch_filter(&in_array, dispatch_sorted), + FilterStrategy::Hashed => dispatch_filter(&in_array, dispatch_hashed), + FilterStrategy::Generic => Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)), + } +} + +/// Generic filter dispatcher with fallback to ArrayStaticFilter. +fn dispatch_filter( + in_array: &ArrayRef, + dispatch: F, +) -> Result> +where + F: Fn(&ArrayRef) -> Option>>, +{ + dispatch(in_array).unwrap_or_else(|| { + Ok(Arc::new(ArrayStaticFilter::try_new(Arc::clone(in_array))?)) + }) +} + +// ============================================================================= +// TYPE DISPATCH +// ============================================================================= +// +// Dispatches filter creation to the appropriate primitive type. +// - Unsigned types: use directly +// - Signed/Float types: reinterpret as unsigned (same bit pattern for equality) + +/// Dispatch macro that routes to the appropriate type-specific filter creation. +/// Uses $direct for unsigned types and $reinterpret for signed/float types. +macro_rules! dispatch_primitive { + ($arr:expr, $direct:ident, $reinterpret:ident) => { + match $arr.data_type() { + DataType::UInt8 => Some($direct::($arr)), + DataType::UInt16 => Some($direct::($arr)), + DataType::UInt32 => Some($direct::($arr)), + DataType::UInt64 => Some($direct::($arr)), + DataType::Int8 => Some($reinterpret::($arr)), + DataType::Int16 => Some($reinterpret::($arr)), + DataType::Int32 => Some($reinterpret::($arr)), + DataType::Int64 => Some($reinterpret::($arr)), + DataType::Float32 => Some($reinterpret::($arr)), + DataType::Float64 => Some($reinterpret::($arr)), + _ => None, } + }; +} + +fn dispatch_branchless( + arr: &ArrayRef, +) -> Option>> { + fn direct( + arr: &ArrayRef, + ) -> Result> + where + T::Native: Copy + Default + PartialEq + Send + Sync + 'static, + { + instantiate_branchless_filter_for_type::(Arc::clone(arr)) + } + fn reinterpret( + arr: &ArrayRef, + ) -> Result> + where + D::Native: Copy + Default + PartialEq + Send + Sync + 'static, + { + create_reinterpreting_filter::(arr, |a| { + instantiate_branchless_filter_for_type::(a) + }) + } + dispatch_primitive!(arr, direct, reinterpret) +} + +fn dispatch_sorted( + arr: &ArrayRef, +) -> Option>> { + fn direct( + arr: &ArrayRef, + ) -> Result> + where + T::Native: Ord + Send + Sync + 'static, + { + Ok(Arc::new( + PrimitiveFilter::>::try_new(arr)?, + )) + } + fn reinterpret( + arr: &ArrayRef, + ) -> Result> + where + D::Native: Ord + Send + Sync + 'static, + { + create_reinterpreting_filter::(arr, |a| { + Ok(Arc::new( + PrimitiveFilter::>::try_new(&a)?, + )) + }) + } + dispatch_primitive!(arr, direct, reinterpret) +} + +fn dispatch_hashed( + arr: &ArrayRef, +) -> Option>> { + if matches!(arr.data_type(), DataType::Decimal128(_, _)) { + return Some( + PrimitiveFilter::>::try_new(arr) + .map(|f| Arc::new(f) as _), + ); + } + fn direct( + arr: &ArrayRef, + ) -> Result> + where + T::Native: Hash + Eq + Send + Sync + 'static, + { + Ok(Arc::new( + PrimitiveFilter::>::try_new(arr)?, + )) + } + fn reinterpret( + arr: &ArrayRef, + ) -> Result> + where + D::Native: Hash + Eq + Send + Sync + 'static, + { + create_reinterpreting_filter::(arr, |a| { + Ok(Arc::new( + PrimitiveFilter::>::try_new(&a)?, + )) + }) } + dispatch_primitive!(arr, direct, reinterpret) } +// ============================================================================= +// GENERIC ARRAY FILTER (fallback for unsupported types) +// ============================================================================= + impl ArrayStaticFilter { - /// Computes a [`StaticFilter`] for the provided [`Array`] if there - /// are nulls present or there are more than the configured number of - /// elements. - /// - /// Note: This is split into a separate function as higher-rank trait bounds currently - /// cause type inference to misbehave fn try_new(in_array: ArrayRef) -> Result { - // Null type has no natural order - return empty hash set if in_array.data_type() == &DataType::Null { return Ok(ArrayStaticFilter { in_array, @@ -168,7 +431,6 @@ impl ArrayStaticFilter { with_hashes([&in_array], &state, |hashes| -> Result<()> { let cmp = make_comparator(&in_array, &in_array, SortOptions::default())?; - let insert_value = |idx| { let hash = hashes[idx]; if let RawEntryMut::Vacant(v) = map @@ -178,7 +440,6 @@ impl ArrayStaticFilter { v.insert_with_hasher(hash, idx, (), |x| hashes[*x]); } }; - match in_array.nulls() { Some(nulls) => { BitIndexIterator::new(nulls.validity(), nulls.offset(), nulls.len()) @@ -186,7 +447,6 @@ impl ArrayStaticFilter { } None => (0..in_array.len()).for_each(insert_value), } - Ok(()) })?; @@ -198,66 +458,343 @@ impl ArrayStaticFilter { } } -struct Int32StaticFilter { +// ============================================================================= +// RESULT BUILDER FOR IN LIST OPERATIONS +// ============================================================================= +// +// Truth table for (needle_nulls, haystack_has_nulls, negated): +// (Some, true, false) → values: valid & contains, nulls: valid & contains +// (None, true, false) → values: contains, nulls: contains +// (Some, true, true) → values: valid ^ (valid & contains), nulls: valid & contains +// (None, true, true) → values: !contains, nulls: contains +// (Some, false, false) → values: valid & contains, nulls: valid +// (Some, false, true) → values: valid & !contains, nulls: valid +// (None, false, false) → values: contains, nulls: none +// (None, false, true) → values: !contains, nulls: none + +#[inline] +fn build_in_list_result( + len: usize, + needle_nulls: Option<&NullBuffer>, + haystack_has_nulls: bool, + negated: bool, + contains: C, +) -> BooleanArray +where + C: Fn(usize) -> bool, +{ + let contains_buf = BooleanBuffer::collect_bool(len, &contains); + build_result_from_contains(needle_nulls, haystack_has_nulls, negated, contains_buf) +} + +#[inline] +fn build_result_from_contains( + needle_nulls: Option<&NullBuffer>, + haystack_has_nulls: bool, + negated: bool, + contains_buf: BooleanBuffer, +) -> BooleanArray { + match (needle_nulls, haystack_has_nulls, negated) { + (Some(v), true, false) => { + let buf = v.inner() & &contains_buf; + BooleanArray::new(buf.clone(), Some(NullBuffer::new(buf))) + } + (None, true, false) => { + BooleanArray::new(contains_buf.clone(), Some(NullBuffer::new(contains_buf))) + } + (Some(v), true, true) => { + let nulls = v.inner() & &contains_buf; + BooleanArray::new(v.inner() ^ &nulls, Some(NullBuffer::new(nulls))) + } + (None, true, true) => { + BooleanArray::new(!&contains_buf, Some(NullBuffer::new(contains_buf))) + } + (Some(v), false, false) => { + BooleanArray::new(v.inner() & &contains_buf, Some(v.clone())) + } + (Some(v), false, true) => { + BooleanArray::new(v.inner() & &(!&contains_buf), Some(v.clone())) + } + (None, false, false) => BooleanArray::new(contains_buf, None), + (None, false, true) => BooleanArray::new(!&contains_buf, None), + } +} + +// ============================================================================= +// LOOKUP STRATEGY TRAIT AND IMPLEMENTATIONS +// ============================================================================= + +trait LookupStrategy: Send + Sync { + fn new(values: Vec) -> Self; + fn contains(&self, value: &T) -> bool; +} + +struct SortedLookup(Vec); + +impl LookupStrategy for SortedLookup { + fn new(mut values: Vec) -> Self { + values.sort_unstable(); + values.dedup(); + Self(values) + } + #[inline] + fn contains(&self, value: &T) -> bool { + self.0.binary_search(value).is_ok() + } +} + +struct HashedLookup(HashSet); + +impl LookupStrategy for HashedLookup { + fn new(values: Vec) -> Self { + Self(values.into_iter().collect()) + } + #[inline] + fn contains(&self, value: &T) -> bool { + self.0.contains(value) + } +} + +// ============================================================================= +// DICTIONARY ARRAY HANDLING +// ============================================================================= + +macro_rules! handle_dictionary { + ($self:ident, $v:ident, $negated:ident) => { + downcast_dictionary_array! { + $v => { + let values_contains = $self.contains($v.values().as_ref(), $negated)?; + let result = take(&values_contains, $v.keys(), None)?; + return Ok(downcast_array(result.as_ref())) + } + _ => {} + } + }; +} + +// ============================================================================= +// UNIFIED PRIMITIVE FILTER +// ============================================================================= + +struct PrimitiveFilter { null_count: usize, - values: HashSet, + lookup: S, + _phantom: std::marker::PhantomData T>, } -impl Int32StaticFilter { +impl> PrimitiveFilter { fn try_new(in_array: &ArrayRef) -> Result { - let in_array = in_array - .as_primitive_opt::() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast array"))?; + let arr = in_array.as_primitive_opt::().ok_or_else(|| { + exec_datafusion_err!( + "PrimitiveFilter: expected {} array", + std::any::type_name::() + ) + })?; + Ok(Self { + null_count: arr.null_count(), + lookup: S::new(arr.iter().flatten().collect()), + _phantom: std::marker::PhantomData, + }) + } +} + +impl StaticFilter for PrimitiveFilter +where + T: ArrowPrimitiveType + 'static, + T::Native: Send + Sync + 'static, + S: LookupStrategy + 'static, +{ + fn null_count(&self) -> usize { + self.null_count + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + handle_dictionary!(self, v, negated); + let v = v.as_primitive_opt::().ok_or_else(|| { + exec_datafusion_err!( + "PrimitiveFilter: expected {} array", + std::any::type_name::() + ) + })?; + let values = v.values(); + Ok(build_in_list_result( + v.len(), + v.nulls(), + self.null_count > 0, + negated, + |i| self.lookup.contains(unsafe { values.get_unchecked(i) }), + )) + } +} + +// ============================================================================= +// BRANCHLESS FILTER (Const Generic for Small Lists) +// ============================================================================= - let mut values = HashSet::with_capacity(in_array.len()); - let null_count = in_array.null_count(); +struct BranchlessFilter { + null_count: usize, + values: [T::Native; N], +} - for v in in_array.iter().flatten() { - values.insert(v); +impl BranchlessFilter +where + T::Native: Copy + Default + PartialEq, +{ + fn try_new(in_array: &ArrayRef) -> Option> { + let in_array = in_array.as_primitive_opt::()?; + let non_null_count = in_array.len() - in_array.null_count(); + if non_null_count != N { + return None; } + let values: Vec<_> = in_array.iter().flatten().collect(); + let mut arr = [T::Native::default(); N]; + arr.copy_from_slice(&values); + Some(Ok(Self { + null_count: in_array.null_count(), + values: arr, + })) + } - Ok(Self { null_count, values }) + #[inline(always)] + fn check(&self, needle: T::Native) -> bool { + self.values + .iter() + .fold(false, |acc, &v| acc | (v == needle)) } } -impl StaticFilter for Int32StaticFilter { +impl StaticFilter for BranchlessFilter +where + T::Native: Copy + Default + PartialEq + Send + Sync, +{ fn null_count(&self) -> usize { self.null_count } fn contains(&self, v: &dyn Array, negated: bool) -> Result { - let v = v - .as_primitive_opt::() - .ok_or_else(|| exec_datafusion_err!("Failed to downcast array"))?; - - let result = match (v.null_count() > 0, negated) { - (true, false) => { - // has nulls, not negated" - BooleanArray::from_iter( - v.iter().map(|value| Some(self.values.contains(&value?))), - ) - } - (true, true) => { - // has nulls, negated - BooleanArray::from_iter( - v.iter().map(|value| Some(!self.values.contains(&value?))), - ) - } - (false, false) => { - //no null, not negated - BooleanArray::from_iter( - v.values().iter().map(|value| self.values.contains(value)), - ) - } - (false, true) => { - // no null, negated - BooleanArray::from_iter( - v.values().iter().map(|value| !self.values.contains(value)), - ) - } + handle_dictionary!(self, v, negated); + let v = v.as_primitive_opt::().ok_or_else(|| { + exec_datafusion_err!("Failed to downcast array to primitive type") + })?; + let input_values = v.values(); + Ok(build_in_list_result( + v.len(), + v.nulls(), + self.null_count > 0, + negated, + #[inline(always)] + |i| self.check(unsafe { *input_values.get_unchecked(i) }), + )) + } +} + +fn instantiate_branchless_filter_for_type( + in_array: ArrayRef, +) -> Result> +where + T::Native: Copy + Default + PartialEq + Send + Sync + 'static, +{ + macro_rules! try_branchless { + ($($n:literal),*) => { + $(if let Some(Ok(f)) = BranchlessFilter::::try_new(&in_array) { + return Ok(Arc::new(f)); + })* }; - Ok(result) } + try_branchless!(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16); + Ok(Arc::new(ArrayStaticFilter::try_new(in_array)?)) +} + +// ============================================================================= +// TRANSFORMING FILTER (Unified Type Reinterpretation) +// ============================================================================= + +struct TransformingFilter { + inner: Arc, + transform: F, +} + +impl StaticFilter for TransformingFilter +where + F: Fn(&dyn Array) -> ArrayRef + Send + Sync, +{ + fn null_count(&self) -> usize { + self.inner.null_count() + } + + fn contains(&self, v: &dyn Array, negated: bool) -> Result { + handle_dictionary!(self, v, negated); + self.inner.contains((self.transform)(v).as_ref(), negated) + } +} + +/// Creates a TransformingFilter that reinterprets arrays from type S to type D. +fn create_reinterpreting_filter( + in_array: &ArrayRef, + create_inner: F, +) -> Result> +where + S: ArrowPrimitiveType + 'static, + D: ArrowPrimitiveType + 'static, + F: FnOnce(ArrayRef) -> Result>, +{ + let reinterpreted = reinterpret_primitive::(in_array.as_ref()); + let inner = create_inner(reinterpreted)?; + Ok(Arc::new(TransformingFilter { + inner, + transform: |v: &dyn Array| reinterpret_primitive::(v), + })) +} + +/// Creates a TransformingFilter for Utf8View arrays reinterpreted as Decimal128. +fn create_utf8view_filter( + in_array: &ArrayRef, + create_inner: F, +) -> Result> +where + F: FnOnce(ArrayRef) -> Result>, +{ + let reinterpreted = reinterpret_utf8view_as_decimal128(in_array.as_ref()); + let inner = create_inner(reinterpreted)?; + Ok(Arc::new(TransformingFilter { + inner, + transform: reinterpret_utf8view_as_decimal128, + })) +} + +// ============================================================================= +// PRIMITIVE TYPE REINTERPRETATION +// ============================================================================= + +#[inline] +fn reinterpret_primitive( + array: &dyn Array, +) -> ArrayRef { + let source = array.as_primitive::(); + let buffer: ScalarBuffer = source.values().inner().clone().into(); + Arc::new(PrimitiveArray::::new(buffer, source.nulls().cloned())) +} + +// ============================================================================= +// UTF8VIEW REINTERPRETATION (short strings ≤12 bytes → Decimal128) +// ============================================================================= + +#[inline] +fn utf8view_all_short_strings(array: &dyn Array) -> bool { + let sv = array.as_string_view(); + sv.views().iter().enumerate().all(|(i, &view)| { + !sv.is_valid(i) || (view as u32) as usize <= UTF8VIEW_INLINE_LEN + }) +} + +#[inline] +fn reinterpret_utf8view_as_decimal128(array: &dyn Array) -> ArrayRef { + let sv = array.as_string_view(); + let buffer: ScalarBuffer = sv.views().inner().clone().into(); + Arc::new(PrimitiveArray::::new( + buffer, + sv.nulls().cloned(), + )) } /// Evaluates the list of expressions into an array, flattening any dictionaries @@ -414,8 +951,12 @@ impl PhysicalExpr for InListExpr { if scalar.is_null() { // SQL three-valued logic: null IN (...) is always null // The code below would handle this correctly but this is a faster path + let nulls = NullBuffer::new_null(num_rows); return Ok(ColumnarValue::Array(Arc::new( - BooleanArray::from(vec![None; num_rows]), + BooleanArray::new( + BooleanBuffer::new_unset(num_rows), + Some(nulls), + ), ))); } // Use a 1 row array to avoid code duplication/branching @@ -426,12 +967,15 @@ impl PhysicalExpr for InListExpr { // Broadcast the single result to all rows // Must check is_null() to preserve NULL values (SQL three-valued logic) if result_array.is_null(0) { - BooleanArray::from(vec![None; num_rows]) + let nulls = NullBuffer::new_null(num_rows); + BooleanArray::new( + BooleanBuffer::new_unset(num_rows), + Some(nulls), + ) + } else if result_array.value(0) { + BooleanArray::new(BooleanBuffer::new_set(num_rows), None) } else { - BooleanArray::from_iter(std::iter::repeat_n( - result_array.value(0), - num_rows, - )) + BooleanArray::new(BooleanBuffer::new_unset(num_rows), None) } } } @@ -572,11 +1116,8 @@ pub fn in_list( // Try to create a static filter for constant expressions let static_filter = try_evaluate_constant_list(&list, schema) - .and_then(ArrayStaticFilter::try_new) - .ok() - .map(|static_filter| { - Arc::new(static_filter) as Arc - }); + .and_then(instantiate_static_filter) + .ok(); Ok(Arc::new(InListExpr::new( expr, @@ -1028,6 +1569,612 @@ mod tests { Ok(()) } + #[test] + fn in_list_int8() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int8, true)]); + let a = Int8Array::from(vec![Some(0), Some(2), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in (0, 1)" + let list = vec![lit(0i8), lit(1i8)]; + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1)" + let list = vec![lit(0i8), lit(1i8)]; + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in (0, 1, NULL)" + let list = vec![lit(0i8), lit(1i8), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1, NULL)" + let list = vec![lit(0i8), lit(1i8), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_int16() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int16, true)]); + let a = Int16Array::from(vec![Some(0), Some(2), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in (0, 1)" + let list = vec![lit(0i16), lit(1i16)]; + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1)" + let list = vec![lit(0i16), lit(1i16)]; + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in (0, 1, NULL)" + let list = vec![lit(0i16), lit(1i16), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1, NULL)" + let list = vec![lit(0i16), lit(1i16), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_int32() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let a = Int32Array::from(vec![Some(0), Some(2), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in (0, 1)" + let list = vec![lit(0i32), lit(1i32)]; + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1)" + let list = vec![lit(0i32), lit(1i32)]; + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in (0, 1, NULL)" + let list = vec![lit(0i32), lit(1i32), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1, NULL)" + let list = vec![lit(0i32), lit(1i32), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_uint8() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::UInt8, true)]); + let a = UInt8Array::from(vec![Some(0), Some(2), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in (0, 1)" + let list = vec![lit(0u8), lit(1u8)]; + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1)" + let list = vec![lit(0u8), lit(1u8)]; + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in (0, 1, NULL)" + let list = vec![lit(0u8), lit(1u8), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1, NULL)" + let list = vec![lit(0u8), lit(1u8), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_uint16() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::UInt16, true)]); + let a = UInt16Array::from(vec![Some(0), Some(2), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in (0, 1)" + let list = vec![lit(0u16), lit(1u16)]; + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1)" + let list = vec![lit(0u16), lit(1u16)]; + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in (0, 1, NULL)" + let list = vec![lit(0u16), lit(1u16), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1, NULL)" + let list = vec![lit(0u16), lit(1u16), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_uint32() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::UInt32, true)]); + let a = UInt32Array::from(vec![Some(0), Some(2), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in (0, 1)" + let list = vec![lit(0u32), lit(1u32)]; + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1)" + let list = vec![lit(0u32), lit(1u32)]; + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in (0, 1, NULL)" + let list = vec![lit(0u32), lit(1u32), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1, NULL)" + let list = vec![lit(0u32), lit(1u32), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_uint64() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::UInt64, true)]); + let a = UInt64Array::from(vec![Some(0), Some(2), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in (0, 1)" + let list = vec![lit(0u64), lit(1u64)]; + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1)" + let list = vec![lit(0u64), lit(1u64)]; + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in (0, 1, NULL)" + let list = vec![lit(0u64), lit(1u64), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in (0, 1, NULL)" + let list = vec![lit(0u64), lit(1u64), lit(ScalarValue::Null)]; + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_large_utf8() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::LargeUtf8, true)]); + let a = LargeStringArray::from(vec![Some("a"), Some("d"), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in ("a", "b")" + let list = vec![lit("a"), lit("b")]; + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in ("a", "b")" + let list = vec![lit("a"), lit("b")]; + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in ("a", "b", null)" + let list = vec![lit("a"), lit("b"), lit(ScalarValue::LargeUtf8(None))]; + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in ("a", "b", null)" + let list = vec![lit("a"), lit("b"), lit(ScalarValue::LargeUtf8(None))]; + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_utf8_view() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8View, true)]); + let a = StringViewArray::from(vec![Some("a"), Some("d"), None]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in ("a", "b")" + let list = vec![lit("a"), lit("b")]; + in_list!( + batch, + list, + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in ("a", "b")" + let list = vec![lit("a"), lit("b")]; + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in ("a", "b", null)" + let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8View(None))]; + in_list!( + batch, + list, + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in ("a", "b", null)" + let list = vec![lit("a"), lit("b"), lit(ScalarValue::Utf8View(None))]; + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_large_binary() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::LargeBinary, true)]); + let a = LargeBinaryArray::from(vec![ + Some([1, 2, 3].as_slice()), + Some([1, 2, 2].as_slice()), + None, + ]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in ([1, 2, 3], [4, 5, 6])" + let list = vec![lit([1, 2, 3].as_slice()), lit([4, 5, 6].as_slice())]; + in_list!( + batch, + list.clone(), + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in ([1, 2, 3], [4, 5, 6])" + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in ([1, 2, 3], [4, 5, 6], null)" + let list = vec![ + lit([1, 2, 3].as_slice()), + lit([4, 5, 6].as_slice()), + lit(ScalarValue::LargeBinary(None)), + ]; + in_list!( + batch, + list.clone(), + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in ([1, 2, 3], [4, 5, 6], null)" + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + + #[test] + fn in_list_binary_view() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::BinaryView, true)]); + let a = BinaryViewArray::from(vec![ + Some([1, 2, 3].as_slice()), + Some([1, 2, 2].as_slice()), + None, + ]); + let col_a = col("a", &schema)?; + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(a)])?; + + // expression: "a in ([1, 2, 3], [4, 5, 6])" + let list = vec![lit([1, 2, 3].as_slice()), lit([4, 5, 6].as_slice())]; + in_list!( + batch, + list.clone(), + &false, + vec![Some(true), Some(false), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in ([1, 2, 3], [4, 5, 6])" + in_list!( + batch, + list, + &true, + vec![Some(false), Some(true), None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a in ([1, 2, 3], [4, 5, 6], null)" + let list = vec![ + lit([1, 2, 3].as_slice()), + lit([4, 5, 6].as_slice()), + lit(ScalarValue::BinaryView(None)), + ]; + in_list!( + batch, + list.clone(), + &false, + vec![Some(true), None, None], + Arc::clone(&col_a), + &schema + ); + + // expression: "a not in ([1, 2, 3], [4, 5, 6], null)" + in_list!( + batch, + list, + &true, + vec![Some(false), None, None], + Arc::clone(&col_a), + &schema + ); + + Ok(()) + } + #[test] fn in_list_date64() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Date64, true)]);