[SPARK-50636][SQL] Extending CTESubstitution.scala to make it aware of recursion

### What changes were proposed in this pull request?

1. Self-contained changes to CTESubstitution.scala that make CTE substitution and resolution aware of recursion, plus new error conditions for incorrect usage of the RECURSIVE keyword.
2. Introduction of the RECURSIVE keyword to the lexer and parser, together with the additions a new keyword requires: registering RECURSIVE in the keywords.sql tests, the ANSI compliance keyword documentation, and the hive-thriftserver keyword list.

More information about recursive CTEs and the follow-up changes still to be merged: https://docs.google.com/document/d/1qcEJxqoXcr5cSt6HgIQjWQSqhfkSaVYkoDHsg5oxXp4/edit

### Why are the changes needed?

To add support for recursive CTEs, i.e. common table expressions that can reference themselves.

### Does this PR introduce _any_ user-facing change?

Yes. The RECURSIVE keyword is introduced in this PR.
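
As a minimal sketch of the syntax this keyword enables (the name `numbers` is illustrative; per the design doc, full recursive execution is completed in follow-up changes, so nothing beyond parsing and analysis is claimed here):

```sql
-- Illustrative recursive CTE: generate the numbers 1 through 5.
WITH RECURSIVE numbers (n) AS (
  SELECT 1
  UNION ALL
  SELECT n + 1 FROM numbers WHERE n < 5
)
SELECT * FROM numbers;
```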

### How was this patch tested?

Work in progress: additional tests covering this change will be added soon.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #49232 from milanisvet/milanrcte2continue.

Lead-authored-by: Milan Cupac <[email protected]>
Co-authored-by: Nemanja Petrovic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
2 people authored and cloud-fan committed Jan 3, 2025
1 parent b210f42 commit 580b3c0
Showing 29 changed files with 311 additions and 204 deletions.
12 changes: 12 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -4146,6 +4146,18 @@
],
"sqlState" : "38000"
},
"RECURSIVE_CTE_IN_LEGACY_MODE" : {
"message" : [
"Recursive definitions cannot be used in legacy CTE precedence mode (spark.sql.legacy.ctePrecedencePolicy=LEGACY)."
],
"sqlState" : "42836"
},
"RECURSIVE_CTE_WHEN_INLINING_IS_FORCED" : {
"message" : [
"Recursive definitions cannot be used when CTE inlining is forced."
],
"sqlState" : "42836"
},
"RECURSIVE_PROTOBUF_SCHEMA" : {
"message" : [
"Found recursive reference in Protobuf schema, which can not be processed by Spark by default: <fieldDescriptor>. try setting the option `recursive.fields.max.depth` 1 to 10. Going beyond 10 levels of recursion is not allowed."
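
A hedged illustration of the first error condition added above, assuming the legacy precedence policy named in its message is in effect:

```sql
-- Assumption: with the legacy CTE precedence policy, any WITH RECURSIVE query is rejected
-- during analysis with RECURSIVE_CTE_IN_LEGACY_MODE.
SET spark.sql.legacy.ctePrecedencePolicy=LEGACY;

WITH RECURSIVE r (n) AS (
  SELECT 1
  UNION ALL
  SELECT n + 1 FROM r WHERE n < 3
)
SELECT * FROM r;  -- expected to fail with RECURSIVE_CTE_IN_LEGACY_MODE
```
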
1 change: 1 addition & 0 deletions docs/sql-ref-ansi-compliance.md
@@ -651,6 +651,7 @@ Below is a list of all the keywords in Spark SQL.
|RECORDREADER|non-reserved|non-reserved|non-reserved|
|RECORDWRITER|non-reserved|non-reserved|non-reserved|
|RECOVER|non-reserved|non-reserved|non-reserved|
|RECURSIVE|reserved|non-reserved|reserved|
|REDUCE|non-reserved|non-reserved|non-reserved|
|REFERENCES|reserved|non-reserved|reserved|
|REFRESH|non-reserved|non-reserved|non-reserved|
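
A hedged reading of the table row above: RECURSIVE stays non-reserved under the default dialect but becomes reserved under ANSI mode, so (assuming reserved keywords are enforced via `spark.sql.ansi.enabled` and `spark.sql.ansi.enforceReservedKeywords`) only the quoted form should keep working there:

```sql
-- Default mode: RECURSIVE is non-reserved and can be used as a plain identifier.
SELECT 1 AS recursive;

-- ANSI mode with reserved keywords enforced: the unquoted form should be rejected,
-- while the backquoted identifier remains valid.
SELECT 1 AS `recursive`;
```
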
@@ -365,6 +365,7 @@ REAL: 'REAL';
RECORDREADER: 'RECORDREADER';
RECORDWRITER: 'RECORDWRITER';
RECOVER: 'RECOVER';
RECURSIVE: 'RECURSIVE';
REDUCE: 'REDUCE';
REFERENCES: 'REFERENCES';
REFRESH: 'REFRESH';
@@ -509,7 +509,7 @@ describeColName
;

ctes
: WITH namedQuery (COMMA namedQuery)*
: WITH RECURSIVE? namedQuery (COMMA namedQuery)*
;

namedQuery
@@ -2118,6 +2118,7 @@ nonReserved
| RECORDREADER
| RECORDWRITER
| RECOVER
| RECURSIVE
| REDUCE
| REFERENCES
| REFRESH
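
For illustration, the optional keyword in the `ctes` rule above attaches to the WITH clause as a whole rather than to individual named queries, so a clause with several definitions parses with a single RECURSIVE (the names below are illustrative):

```sql
WITH RECURSIVE
  base (n) AS (SELECT 0),
  seq (n) AS (
    SELECT n FROM base
    UNION ALL
    SELECT n + 1 FROM seq WHERE n < 5
  )
SELECT * FROM seq;
```
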
@@ -149,10 +149,15 @@ object CTESubstitution extends Rule[LogicalPlan] {
plan: LogicalPlan,
cteDefs: ArrayBuffer[CTERelationDef]): LogicalPlan = {
plan.resolveOperatorsUp {
case UnresolvedWith(child, relations, _) =>
val resolvedCTERelations =
resolveCTERelations(relations, isLegacy = true, forceInline = false, Seq.empty, cteDefs)
substituteCTE(child, alwaysInline = true, resolvedCTERelations)
case cte @ UnresolvedWith(child, relations, allowRecursion) =>
if (allowRecursion) {
cte.failAnalysis(
errorClass = "RECURSIVE_CTE_IN_LEGACY_MODE",
messageParameters = Map.empty)
}
val resolvedCTERelations = resolveCTERelations(relations, isLegacy = true,
forceInline = false, Seq.empty, cteDefs, allowRecursion)
substituteCTE(child, alwaysInline = true, resolvedCTERelations, None)
}
}

@@ -202,14 +207,21 @@ object CTESubstitution extends Rule[LogicalPlan] {
var firstSubstituted: Option[LogicalPlan] = None
val newPlan = plan.resolveOperatorsDownWithPruning(
_.containsAnyPattern(UNRESOLVED_WITH, PLAN_EXPRESSION)) {
case UnresolvedWith(child: LogicalPlan, relations, _) =>
// The allowRecursion flag is set to `true` by the parser if the `RECURSIVE` keyword is used.
case cte @ UnresolvedWith(child: LogicalPlan, relations, allowRecursion) =>
if (allowRecursion && forceInline) {
cte.failAnalysis(
errorClass = "RECURSIVE_CTE_WHEN_INLINING_IS_FORCED",
messageParameters = Map.empty)
}
val resolvedCTERelations =
resolveCTERelations(relations, isLegacy = false, forceInline, outerCTEDefs, cteDefs) ++
outerCTEDefs
resolveCTERelations(relations, isLegacy = false, forceInline, outerCTEDefs, cteDefs,
allowRecursion) ++ outerCTEDefs
val substituted = substituteCTE(
traverseAndSubstituteCTE(child, forceInline, resolvedCTERelations, cteDefs)._1,
forceInline,
resolvedCTERelations)
resolvedCTERelations,
None)
if (firstSubstituted.isEmpty) {
firstSubstituted = Some(substituted)
}
@@ -228,7 +240,8 @@
isLegacy: Boolean,
forceInline: Boolean,
outerCTEDefs: Seq[(String, CTERelationDef)],
cteDefs: ArrayBuffer[CTERelationDef]): Seq[(String, CTERelationDef)] = {
cteDefs: ArrayBuffer[CTERelationDef],
allowRecursion: Boolean): Seq[(String, CTERelationDef)] = {
val alwaysInline = isLegacy || forceInline
var resolvedCTERelations = if (alwaysInline) {
Seq.empty
@@ -247,49 +260,116 @@
// NOTE: we must call `traverseAndSubstituteCTE` before `substituteCTE`, as the relations
// in the inner CTE have higher priority over the relations in the outer CTE when resolving
// inner CTE relations. For example:
// WITH t1 AS (SELECT 1)
// t2 AS (
// WITH t1 AS (SELECT 2)
// WITH t3 AS (SELECT * FROM t1)
// )
// t3 should resolve the t1 to `SELECT 2` instead of `SELECT 1`.
traverseAndSubstituteCTE(relation, forceInline, resolvedCTERelations, cteDefs)._1
// WITH
// t1 AS (SELECT 1),
// t2 AS (
// WITH
// t1 AS (SELECT 2),
// t3 AS (SELECT * FROM t1)
// SELECT * FROM t1
// )
// SELECT * FROM t2
// t3 should resolve the t1 to `SELECT 2` ("inner" t1) instead of `SELECT 1`.
//
// When recursion allowed (RECURSIVE keyword used):
// Consider following example:
// WITH
// t1 AS (SELECT 1),
// t2 AS (
// WITH RECURSIVE
// t1 AS (
// SELECT 1 AS level
// UNION (
// WITH t3 AS (SELECT level + 1 FROM t1 WHERE level < 10)
// SELECT * FROM t3
// )
// )
// SELECT * FROM t1
// )
// SELECT * FROM t2
// t1 reference within t3 would initially resolve to outer `t1` (SELECT 1), as the inner t1
// is not yet known. Therefore, we need to remove definitions that conflict with current
// relation `name` from the list of `outerCTEDefs` entering `traverseAndSubstituteCTE()`.
// NOTE: It will be recognized later in the code that this is actually a self-reference
// (reference to the inner t1).
val nonConflictingCTERelations = if (allowRecursion) {
resolvedCTERelations.filterNot {
case (cteName, cteDef) => cteDef.conf.resolver(cteName, name)
}
} else {
resolvedCTERelations
}
traverseAndSubstituteCTE(relation, forceInline, nonConflictingCTERelations, cteDefs)._1
}
// CTE definition can reference a previous one
val substituted = substituteCTE(innerCTEResolved, alwaysInline, resolvedCTERelations)

// If recursion is allowed (the RECURSIVE keyword was specified),
// the current relation takes priority over outer or previous relations.
// Therefore, we construct a `CTERelationDef` for the current relation.
// Later, when we encounter an unresolved relation and need to find which CTE definition it
// references, we first check whether it refers to this one; if so, we mark the
// reference as recursive.
val recursiveCTERelation = if (allowRecursion) {
Some(name -> CTERelationDef(relation))
} else {
None
}
// CTE definition can reference a previous one or itself if recursion allowed.
val substituted = substituteCTE(innerCTEResolved, alwaysInline,
resolvedCTERelations, recursiveCTERelation)
val cteRelation = CTERelationDef(substituted)
if (!alwaysInline) {
cteDefs += cteRelation
}

// Prepending new CTEs makes sure that those have higher priority over outer ones.
resolvedCTERelations +:= (name -> cteRelation)
}
resolvedCTERelations
}

/**
* This function is called from `substituteCTE` to actually substitute unresolved relations
* with CTE references.
*/
private def resolveWithCTERelations(
table: String,
alwaysInline: Boolean,
cteRelations: Seq[(String, CTERelationDef)],
recursiveCTERelation: Option[(String, CTERelationDef)],
unresolvedRelation: UnresolvedRelation): LogicalPlan = {
cteRelations
.find(r => conf.resolver(r._1, table))
.map {
if (recursiveCTERelation.isDefined && conf.resolver(recursiveCTERelation.get._1, table)) {
// self-reference is found
recursiveCTERelation.map {
case (_, d) =>
if (alwaysInline) {
d.child
} else {
// Add a `SubqueryAlias` for hint-resolving rules to match relation names.
SubqueryAlias(table, CTERelationRef(d.id, d.resolved, d.output, d.isStreaming))
}
}
.getOrElse(unresolvedRelation)
SubqueryAlias(table,
CTERelationRef(d.id, d.resolved, d.output, d.isStreaming, recursive = true))
}.get
} else {
cteRelations
.find(r => conf.resolver(r._1, table))
.map {
case (_, d) =>
if (alwaysInline) {
d.child
} else {
// Add a `SubqueryAlias` for hint-resolving rules to match relation names.
// This is a non-recursive reference; the `recursive` parameter defaults to false.
SubqueryAlias(table,
CTERelationRef(d.id, d.resolved, d.output, d.isStreaming))
}
}
.getOrElse(unresolvedRelation)
}
}

/**
* Substitute unresolved relations in the plan with CTE references (CTERelationRef).
*/
private def substituteCTE(
plan: LogicalPlan,
alwaysInline: Boolean,
cteRelations: Seq[(String, CTERelationDef)]): LogicalPlan = {
cteRelations: Seq[(String, CTERelationDef)],
recursiveCTERelation: Option[(String, CTERelationDef)]): LogicalPlan = {
plan.resolveOperatorsUpWithPruning(
_.containsAnyPattern(RELATION_TIME_TRAVEL, UNRESOLVED_RELATION, PLAN_EXPRESSION,
UNRESOLVED_IDENTIFIER)) {
@@ -298,7 +378,8 @@
throw QueryCompilationErrors.timeTravelUnsupportedError(toSQLId(table))

case u @ UnresolvedRelation(Seq(table), _, _) =>
resolveWithCTERelations(table, alwaysInline, cteRelations, u)
resolveWithCTERelations(table, alwaysInline, cteRelations,
recursiveCTERelation, u)

case p: PlanWithUnresolvedIdentifier =>
// We must look up CTE relations first when resolving `UnresolvedRelation`s,
@@ -308,7 +389,8 @@
p.copy(planBuilder = (nameParts, children) => {
p.planBuilder.apply(nameParts, children) match {
case u @ UnresolvedRelation(Seq(table), _, _) =>
resolveWithCTERelations(table, alwaysInline, cteRelations, u)
resolveWithCTERelations(table, alwaysInline, cteRelations,
recursiveCTERelation, u)
case other => other
}
})
@@ -317,7 +399,8 @@
// This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE.
other.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
case e: SubqueryExpression =>
e.withNewPlan(apply(substituteCTE(e.plan, alwaysInline, cteRelations)))
e.withNewPlan(
apply(substituteCTE(e.plan, alwaysInline, cteRelations, None)))
}
}
}
@@ -122,7 +122,7 @@ object PushdownPredicatesAndPruneColumnsForCTEDef extends Rule[LogicalPlan] {
private def pushdownPredicatesAndAttributes(
plan: LogicalPlan,
cteMap: CTEMap): LogicalPlan = plan.transformWithSubqueries {
case cteDef @ CTERelationDef(child, id, originalPlanWithPredicates, _, _, _) =>
case cteDef @ CTERelationDef(child, id, originalPlanWithPredicates, _, _) =>
val (_, _, newPreds, newAttrSet) = cteMap(id)
val originalPlan = originalPlanWithPredicates.map(_._1).getOrElse(child)
val preds = originalPlanWithPredicates.map(_._2).getOrElse(Seq.empty)
@@ -170,7 +170,7 @@ object PushdownPredicatesAndPruneColumnsForCTEDef extends Rule[LogicalPlan] {
object CleanUpTempCTEInfo extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan =
plan.transformWithPruning(_.containsPattern(CTE)) {
case cteDef @ CTERelationDef(_, _, Some(_), _, _, _) =>
case cteDef @ CTERelationDef(_, _, Some(_), _, _) =>
cteDef.copy(originalPlanWithPredicates = None)
}
}
@@ -551,7 +551,7 @@ class AstBuilder extends DataTypeAstBuilder
throw QueryParsingErrors.duplicateCteDefinitionNamesError(
duplicates.map(toSQLId).mkString(", "), ctx)
}
UnresolvedWith(plan, ctes.toSeq)
UnresolvedWith(plan, ctes.toSeq, ctx.RECURSIVE() != null)
}

/**
@@ -866,7 +866,6 @@ case class UnresolvedWith(
* pushdown to help ensure rule idempotency.
* @param underSubquery If true, it means we don't need to add a shuffle for this CTE relation as
* subquery reuse will be applied to reuse CTE relation output.
* @param recursive If true, then this CTE Definition is recursive - it contains a self-reference.
* @param recursionAnchor A helper plan node that temporary stores the anchor term of recursive
* definitions. In the beginning of recursive resolution the `ResolveWithCTE`
* rule updates this parameter and once it is resolved the same rule resolves
Expand All @@ -877,7 +876,6 @@ case class CTERelationDef(
id: Long = CTERelationDef.newId,
originalPlanWithPredicates: Option[(LogicalPlan, Seq[Expression])] = None,
underSubquery: Boolean = false,
recursive: Boolean = false,
recursionAnchor: Option[LogicalPlan] = None) extends UnaryNode {

final override val nodePatterns: Seq[TreePattern] = Seq(CTE)
@@ -886,6 +884,13 @@ case class CTERelationDef(
copy(child = newChild)

override def output: Seq[Attribute] = if (resolved) child.output else Nil

lazy val recursive: Boolean = child.exists {
// If a reference to this CTE definition is found inside the child and it is
// already marked as recursive, then this CTE definition is recursive.
case CTERelationRef(this.id, _, _, _, _, true) => true
case _ => false
}
}

object CTERelationDef {
@@ -4,7 +4,7 @@ CREATE TABLE cte_tbl USING csv AS WITH s AS (SELECT 42 AS col) SELECT * FROM s
-- !query analysis
CreateDataSourceTableAsSelectCommand `spark_catalog`.`default`.`cte_tbl`, ErrorIfExists, [col]
+- WithCTE
:- CTERelationDef xxxx, false, false
:- CTERelationDef xxxx, false
: +- SubqueryAlias s
: +- Project [42 AS col#x]
: +- OneRowRelation
@@ -26,7 +26,7 @@ CREATE TEMPORARY VIEW cte_view AS WITH s AS (SELECT 42 AS col) SELECT * FROM s
-- !query analysis
CreateViewCommand `cte_view`, WITH s AS (SELECT 42 AS col) SELECT * FROM s, false, false, LocalTempView, UNSUPPORTED, true
+- WithCTE
:- CTERelationDef xxxx, false, false
:- CTERelationDef xxxx, false
: +- SubqueryAlias s
: +- Project [42 AS col#x]
: +- OneRowRelation
@@ -43,7 +43,7 @@ Project [col#x]
+- View (`cte_view`, [col#x])
+- Project [cast(col#x as int) AS col#x]
+- WithCTE
:- CTERelationDef xxxx, false, false
:- CTERelationDef xxxx, false
: +- SubqueryAlias s
: +- Project [42 AS col#x]
: +- OneRowRelation
@@ -58,7 +58,7 @@ INSERT INTO cte_tbl SELECT * FROM S
-- !query analysis
InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/cte_tbl, false, CSV, [path=file:[not included in comparison]/{warehouse_dir}/cte_tbl], Append, `spark_catalog`.`default`.`cte_tbl`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/cte_tbl), [col]
+- WithCTE
:- CTERelationDef xxxx, false, false
:- CTERelationDef xxxx, false
: +- SubqueryAlias s
: +- Project [43 AS col#x]
: +- OneRowRelation
@@ -80,7 +80,7 @@ INSERT INTO cte_tbl WITH s AS (SELECT 44 AS col) SELECT * FROM s
-- !query analysis
InsertIntoHadoopFsRelationCommand file:[not included in comparison]/{warehouse_dir}/cte_tbl, false, CSV, [path=file:[not included in comparison]/{warehouse_dir}/cte_tbl], Append, `spark_catalog`.`default`.`cte_tbl`, org.apache.spark.sql.execution.datasources.InMemoryFileIndex(file:[not included in comparison]/{warehouse_dir}/cte_tbl), [col]
+- WithCTE
:- CTERelationDef xxxx, false, false
:- CTERelationDef xxxx, false
: +- SubqueryAlias s
: +- Project [44 AS col#x]
: +- OneRowRelation