Skip to content

Commit

Permalink
[SPARK-50683][SQL] Inline the common expression in With if used once
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
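As the title says: when a common expression defined in a `With` is referenced only once, inline it into the expression instead of extracting it into a separate `Project`.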

### Why are the changes needed?

It simplifies the plan and avoids adding an unnecessary `Project` node.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Unit tests (updated an existing case and added a new one in `RewriteWithExpressionSuite`).

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #49310 from zml1206/with.

Authored-by: zml1206 <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
zml1206 authored and cloud-fan committed Jan 2, 2025
1 parent 721a417 commit 492fcd8
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,15 @@ object RewriteWithExpression extends Rule[LogicalPlan] {

private def applyInternal(p: LogicalPlan): LogicalPlan = {
val inputPlans = p.children
val commonExprIdSet = p.expressions
.flatMap(_.collect { case r: CommonExpressionRef => r.id })
.groupBy(identity)
.transform((_, v) => v.size)
.filter(_._2 > 1)
.keySet
val commonExprsPerChild = Array.fill(inputPlans.length)(mutable.ListBuffer.empty[(Alias, Long)])
var newPlan: LogicalPlan = p.mapExpressions { expr =>
rewriteWithExprAndInputPlans(expr, inputPlans, commonExprsPerChild)
rewriteWithExprAndInputPlans(expr, inputPlans, commonExprsPerChild, commonExprIdSet)
}
val newChildren = inputPlans.zip(commonExprsPerChild).map { case (inputPlan, commonExprs) =>
if (commonExprs.isEmpty) {
Expand All @@ -96,16 +102,17 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
e: Expression,
inputPlans: Seq[LogicalPlan],
commonExprsPerChild: Array[mutable.ListBuffer[(Alias, Long)]],
commonExprIdSet: Set[CommonExpressionId],
isNestedWith: Boolean = false): Expression = {
if (!e.containsPattern(WITH_EXPRESSION)) return e
e match {
// Do not handle nested With in one pass. Leave it to the next rule executor batch.
case w: With if !isNestedWith =>
// Rewrite nested With expressions first
val child = rewriteWithExprAndInputPlans(
w.child, inputPlans, commonExprsPerChild, isNestedWith = true)
w.child, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith = true)
val defs = w.defs.map(rewriteWithExprAndInputPlans(
_, inputPlans, commonExprsPerChild, isNestedWith = true))
_, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith = true))
val refToExpr = mutable.HashMap.empty[CommonExpressionId, Expression]

defs.zipWithIndex.foreach { case (CommonExpressionDef(child, id), index) =>
Expand All @@ -114,7 +121,7 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
"Cannot rewrite canonicalized Common expression definitions")
}

if (CollapseProject.isCheap(child)) {
if (CollapseProject.isCheap(child) || !commonExprIdSet.contains(id)) {
refToExpr(id) = child
} else {
val childPlanIndex = inputPlans.indexWhere(
Expand Down Expand Up @@ -171,7 +178,8 @@ object RewriteWithExpression extends Rule[LogicalPlan] {

case c: ConditionalExpression =>
val newAlwaysEvaluatedInputs = c.alwaysEvaluatedInputs.map(
rewriteWithExprAndInputPlans(_, inputPlans, commonExprsPerChild, isNestedWith))
rewriteWithExprAndInputPlans(
_, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith))
val newExpr = c.withNewAlwaysEvaluatedInputs(newAlwaysEvaluatedInputs)
// Use transformUp to handle nested With.
newExpr.transformUpWithPruning(_.containsPattern(WITH_EXPRESSION)) {
Expand All @@ -185,7 +193,8 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
}

case other => other.mapChildren(
rewriteWithExprAndInputPlans(_, inputPlans, commonExprsPerChild, isNestedWith)
rewriteWithExprAndInputPlans(
_, inputPlans, commonExprsPerChild, commonExprIdSet, isNestedWith)
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class RewriteWithExpressionSuite extends PlanTest {
val commonExprDef2 = CommonExpressionDef(a + a, CommonExpressionId(2))
val ref2 = new CommonExpressionRef(commonExprDef2)
// The inner main expression references the outer expression
val innerExpr2 = With(ref2 + outerRef, Seq(commonExprDef2))
val innerExpr2 = With(ref2 + ref2 + outerRef, Seq(commonExprDef2))
val outerExpr2 = With(outerRef + innerExpr2, Seq(outerCommonExprDef))
comparePlans(
Optimizer.execute(testRelation.select(outerExpr2.as("col"))),
Expand All @@ -152,7 +152,8 @@ class RewriteWithExpressionSuite extends PlanTest {
.select(star(), (a + a).as("_common_expr_2"))
// The final Project contains the final result expression, which references both common
// expressions.
.select(($"_common_expr_0" + ($"_common_expr_2" + $"_common_expr_0")).as("col"))
.select(($"_common_expr_0" +
($"_common_expr_2" + $"_common_expr_2" + $"_common_expr_0")).as("col"))
.analyze
)
}
Expand Down Expand Up @@ -490,4 +491,13 @@ class RewriteWithExpressionSuite extends PlanTest {
val wrongPlan = testRelation.select(expr1.as("c1"), expr3.as("c3")).analyze
intercept[AssertionError](Optimizer.execute(wrongPlan))
}

test("SPARK-50683: inline the common expression in With if used once") {
  // Define `attr + attr` as a common expression and reference it exactly once.
  val attr = testRelation.output.head
  val commonDef = CommonExpressionDef(attr + attr)
  val singleRef = new CommonExpressionRef(commonDef)
  val withExpr = With(singleRef + 1, Seq(commonDef))

  val optimized = Optimizer.execute(testRelation.select(withExpr.as("col")))

  // The definition should be inlined at its single use site, leaving one Project
  // directly over the relation.
  val expected = testRelation.select((attr + attr + 1).as("col"))
  comparePlans(optimized, expected)
}
}

0 comments on commit 492fcd8

Please sign in to comment.