Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
245 changes: 185 additions & 60 deletions datafusion/optimizer/src/decorrelate_predicate_subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
use datafusion_common::{internal_err, plan_err, Column, Result};
use datafusion_expr::expr::{Exists, InSubquery};
use datafusion_expr::expr_rewriter::create_col_from_scalar_expr;
use datafusion_expr::logical_plan::{JoinType, Projection, Subquery};
use datafusion_expr::logical_plan::{
Join as LogicalJoin, JoinType, Projection, Subquery,
};
use datafusion_expr::utils::{conjunction, split_conjunction_owned};
use datafusion_expr::{
exists, in_subquery, lit, not, not_exists, not_in_subquery, BinaryExpr, Expr, Filter,
Expand Down Expand Up @@ -66,82 +68,166 @@ impl OptimizerRule for DecorrelatePredicateSubquery {
})?
.data;

match plan {
LogicalPlan::Filter(filter) => {
if !has_subquery(&filter.predicate) {
return Ok(Transformed::no(LogicalPlan::Filter(filter)));
}
// Handle Filters first (existing behavior)
if let LogicalPlan::Filter(filter) = plan.clone() {
if !has_subquery(&filter.predicate) {
return Ok(Transformed::no(LogicalPlan::Filter(filter)));
}

let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) =
split_conjunction_owned(filter.predicate)
.into_iter()
.partition(has_subquery);
let (with_subqueries, mut other_exprs): (Vec<_>, Vec<_>) =
split_conjunction_owned(filter.predicate)
.into_iter()
.partition(has_subquery);

if with_subqueries.is_empty() {
return internal_err!(
"can not find expected subqueries in DecorrelatePredicateSubquery"
);
}
if with_subqueries.is_empty() {
return internal_err!(
"can not find expected subqueries in DecorrelatePredicateSubquery"
);
}

// iterate through all exists clauses in predicate, turning each into a join
let mut cur_input = Arc::unwrap_or_clone(filter.input);
for subquery_expr in with_subqueries {
match extract_subquery_info(subquery_expr) {
// The subquery expression is at the top level of the filter
SubqueryPredicate::Top(subquery) => {
match build_join_top(
&subquery,
&cur_input,
config.alias_generator(),
)? {
Some(plan) => cur_input = plan,
// If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter
None => other_exprs.push(subquery.expr()),
}
}
// The subquery expression is embedded within another expression
SubqueryPredicate::Embedded(expr) => {
let (plan, expr_without_subqueries) =
rewrite_inner_subqueries(cur_input, expr, config)?;
cur_input = plan;
other_exprs.push(expr_without_subqueries);
// iterate through all exists clauses in predicate, turning each into a join
let mut cur_input = Arc::unwrap_or_clone(filter.input);
for subquery_expr in with_subqueries {
match extract_subquery_info(subquery_expr) {
// The subquery expression is at the top level of the filter
SubqueryPredicate::Top(subquery) => {
match build_join_top(
&subquery,
&cur_input,
config.alias_generator(),
)? {
Some(plan) => cur_input = plan,
// If the subquery can not be converted to a Join, reconstruct the subquery expression and add it to the Filter
None => other_exprs.push(subquery.expr()),
}
}
// The subquery expression is embedded within another expression
SubqueryPredicate::Embedded(expr) => {
let (plan, expr_without_subqueries) =
rewrite_inner_subqueries(cur_input, expr, config)?;
cur_input = plan;
other_exprs.push(expr_without_subqueries);
}
}
}

let expr = conjunction(other_exprs);
if let Some(expr) = expr {
let new_filter = Filter::try_new(expr, Arc::new(cur_input))?;
return Ok(Transformed::yes(LogicalPlan::Filter(new_filter)));
}
return Ok(Transformed::yes(cur_input));
}

let expr = conjunction(other_exprs);
let mut new_plan = cur_input;
if let Some(expr) = expr {
let new_filter = Filter::try_new(expr, Arc::new(new_plan))?;
new_plan = LogicalPlan::Filter(new_filter);
// Additionally handle subqueries embedded in Join.filter expressions
if let LogicalPlan::Join(join) = plan {
if let Some(predicate) = &join.filter {
if has_subquery(predicate) {
let (new_left, new_predicate) = rewrite_inner_subqueries(
Arc::unwrap_or_clone(join.left),
predicate.clone(),
config,
)?;

let new_join = LogicalJoin::try_new(
Arc::new(new_left),
Arc::clone(&join.right),
join.on.clone(),
Some(new_predicate),
join.join_type,
join.join_constraint,
join.null_equals_null,
)?;
return Ok(Transformed::yes(LogicalPlan::Join(new_join)));
}
Ok(Transformed::yes(new_plan))
}
LogicalPlan::Projection(proj) => {
// Only proceed if any projection expression contains a subquery
if !proj.expr.iter().any(has_subquery) {
return Ok(Transformed::no(LogicalPlan::Projection(proj)));
return Ok(Transformed::no(LogicalPlan::Join(join)));
}

// Handle subqueries embedded in Aggregate group/aggregate expressions
if let LogicalPlan::Aggregate(aggregate) = plan {
let mut needs_rewrite = false;
for e in &aggregate.group_expr {
if has_subquery(e) {
needs_rewrite = true;
break;
}
}
if !needs_rewrite {
for e in &aggregate.aggr_expr {
if has_subquery(e) {
needs_rewrite = true;
break;
}
}
}
if !needs_rewrite {
return Ok(Transformed::no(LogicalPlan::Aggregate(aggregate)));
}

let mut cur_input = Arc::unwrap_or_clone(proj.input);
let mut new_exprs = Vec::with_capacity(proj.expr.len());
for e in proj.expr {
let old_name = e.schema_name().to_string();
let (plan_after, rewritten) =
rewrite_inner_subqueries(cur_input, e, config)?;
cur_input = plan_after;
let new_name = rewritten.schema_name().to_string();
let mut cur_input = Arc::unwrap_or_clone(aggregate.input);
let mut new_group_exprs = Vec::with_capacity(aggregate.group_expr.len());
for expr in aggregate.group_expr {
if has_subquery(&expr) {
let (next_input, rewritten_expr) =
rewrite_inner_subqueries(cur_input, expr, config)?;
cur_input = next_input;
new_group_exprs.push(rewritten_expr);
} else {
new_group_exprs.push(expr);
}
}
let mut new_aggr_exprs = Vec::with_capacity(aggregate.aggr_expr.len());
for expr in aggregate.aggr_expr {
if has_subquery(&expr) {
let old_name = expr.schema_name().to_string();
let (next_input, rewritten_expr) =
rewrite_inner_subqueries(cur_input, expr, config)?;
cur_input = next_input;
let new_name = rewritten_expr.schema_name().to_string();
if new_name != old_name {
new_exprs.push(rewritten.alias(old_name));
new_aggr_exprs.push(rewritten_expr.alias(old_name));
} else {
new_exprs.push(rewritten);
new_aggr_exprs.push(rewritten_expr);
}
} else {
new_aggr_exprs.push(expr);
}
let new_proj = Projection::try_new(new_exprs, Arc::new(cur_input))?;
Ok(Transformed::yes(LogicalPlan::Projection(new_proj)))
}
other => Ok(Transformed::no(other)),

let new_plan = LogicalPlanBuilder::from(cur_input)
.aggregate(new_group_exprs, new_aggr_exprs)?
.build()?;
return Ok(Transformed::yes(new_plan));
}

// Handle Projection nodes with subqueries in expressions
if let LogicalPlan::Projection(proj) = plan {
// Only proceed if any projection expression contains a subquery
if !proj.expr.iter().any(has_subquery) {
return Ok(Transformed::no(LogicalPlan::Projection(proj)));
}

let mut cur_input = Arc::unwrap_or_clone(proj.input);
let mut new_exprs = Vec::with_capacity(proj.expr.len());
for e in proj.expr {
let old_name = e.schema_name().to_string();
let (plan_after, rewritten) =
rewrite_inner_subqueries(cur_input, e, config)?;
cur_input = plan_after;
let new_name = rewritten.schema_name().to_string();
if new_name != old_name {
new_exprs.push(rewritten.alias(old_name));
} else {
new_exprs.push(rewritten);
}
}
let new_proj = Projection::try_new(new_exprs, Arc::new(cur_input))?;
return Ok(Transformed::yes(LogicalPlan::Projection(new_proj)));
}

// Other plans unchanged
Ok(Transformed::no(plan))
}

fn name(&self) -> &str {
Expand Down Expand Up @@ -477,6 +563,45 @@ mod tests {
))
}

/// An `IN (subquery)` nested inside a `CASE WHEN` that is itself the argument
/// of an aggregate function must be handled by `DecorrelatePredicateSubquery`
/// when the enclosing plan node is an `Aggregate`.
///
/// The test builds `max(CASE WHEN b IN (<subquery>) THEN 1 ELSE 0 END)` grouped
/// by column `a`, runs the rule once, and checks the indented plan rendering:
/// it must still contain an `Aggregate:` node and must contain the text
/// `Left` (presumably the join type introduced by the rewrite; the exact join
/// variant is not pinned by this test).
#[test]
fn aggregate_case_in_subquery() -> Result<()> {
    use crate::{OptimizerContext, OptimizerRule};
    use datafusion_expr::expr_fn::when;
    use datafusion_functions_aggregate::expr_fn::max as agg_max;

    let scan = test_table_scan_with_name("distinct_source")?;

    // Subquery side: an ungrouped aggregate over the same table that projects
    // only `max(distinct_source.b)`.
    let max_b: Expr = agg_max(col("distinct_source.b"));
    let max_b_subquery = LogicalPlanBuilder::from(scan.clone())
        .aggregate(Vec::<Expr>::new(), vec![max_b])?
        .project(vec![col("max(distinct_source.b)")])?
        .build()?;

    // CASE WHEN b IN (<subquery>) THEN 1 ELSE 0 END
    let membership = in_subquery(col("distinct_source.b"), Arc::new(max_b_subquery));
    let membership_flag = when(membership, lit(1)).otherwise(lit(0))?;

    // Outer plan: GROUP BY a, with the subquery-bearing CASE inside max(...).
    let group_by = vec![col("distinct_source.a").alias("primary_key")];
    let aggregates = vec![
        agg_max(membership_flag).alias("is_in_most_recent_task"),
        agg_max(col("distinct_source.c")).alias("max_timestamp"),
    ];
    let input_plan = LogicalPlanBuilder::from(scan)
        .aggregate(group_by, aggregates)?
        .build()?;

    // Apply the rule directly (a single rewrite call on the top-level plan).
    let rule = DecorrelatePredicateSubquery::new();
    let rewritten = rule.rewrite(input_plan, &OptimizerContext::new())?.data;
    let rendered = rewritten.display_indent().to_string();

    assert!(rendered.contains("Aggregate:"));
    assert!(rendered.contains("Left"));
    Ok(())
}

/// Test for several IN subquery expressions
#[test]
fn in_subquery_multiple() -> Result<()> {
Expand Down
Loading
Loading