[SPARK-50734][SQL] Add catalog API for creating and registering SQL UDFs

### What changes were proposed in this pull request?

This PR adds catalog APIs to support the creation and registration of SQL UDFs. It uses the Hive Metastore to persist a SQL UDF by serializing the function information into a FunctionResource and storing it in Hive (toCatalogFunction). During resolution, it retrieves the catalog function and deserializes it back into a SQLFunction (fromCatalogFunction).

This PR only adds the catalog API, and a subsequent PR will add the analyzer logic to resolve SQL UDFs.
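The heart of the new API is the `toCatalogFunction` / `fromCatalogFunction` pair on `SQLFunction` (see the diff below). As a minimal sketch of the round trip, assuming the constructor parameters visible in this diff and defaults for the remaining fields (the function definition itself is made up):

```scala
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, SQLFunction}
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.types.{IntegerType, StructType}

// An illustrative scalar SQL UDF: area(width INT, height INT) RETURNS INT.
val fn = SQLFunction(
  name = FunctionIdentifier("area"),
  inputParam = Some(StructType.fromDDL("width INT, height INT")),
  returnType = Left(IntegerType),        // scalar functions use Left(dataType)
  exprText = Some("width * height"),     // the function body, kept as SQL text
  queryText = None,
  comment = None,
  deterministic = Some(true),
  containsSQL = Some(true),
  isTableFunc = false,
  properties = Map.empty)

// Persist: the function fields are serialized into JSON properties and stored as
// FunctionResources of a CatalogFunction, which the Hive Metastore can persist.
val catalogFn: CatalogFunction = fn.toCatalogFunction

// Resolve: the catalog function is read back and deserialized into a SQLFunction,
// re-parsing the stored input parameters and return type with the SQL parser.
val restored: SQLFunction = SQLFunction.fromCatalogFunction(catalogFn, CatalystSqlParser)
```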

### Why are the changes needed?

To support SQL UDFs in Spark.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests. End-to-end tests will be added in the next PR, once SQL UDF resolution is supported.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #49389 from allisonwang-db/spark-50734-sql-udf-catalog-api.

Authored-by: Allison Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
allisonwang-db authored and cloud-fan committed Jan 7, 2025
1 parent 204c672 commit bba8cf4
Showing 11 changed files with 624 additions and 90 deletions.
11 changes: 11 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -758,6 +758,12 @@
],
"sqlState" : "22018"
},
"CORRUPTED_CATALOG_FUNCTION" : {
"message" : [
"Cannot convert the catalog function '<identifier>' into a SQL function due to corrupted function information in catalog. If the function is not a SQL function, please make sure the class name '<className>' is loadable."
],
"sqlState" : "0A000"
},
"CREATE_PERMANENT_VIEW_WITHOUT_ALIAS" : {
"message" : [
"Not allowed to create the permanent view <name> without explicitly assigning an alias for the expression <attr>."
@@ -5892,6 +5898,11 @@
"The number of columns produced by the RETURN clause (num: `<outputSize>`) does not match the number of column names specified by the RETURNS clause (num: `<returnParamSize>`) of <name>."
]
},
"ROUTINE_PROPERTY_TOO_LARGE" : {
"message" : [
"Cannot convert user defined routine <name> to catalog function: routine properties are too large."
]
},
"SQL_TABLE_UDF_BODY_MUST_BE_A_QUERY" : {
"message" : [
"SQL table function <name> body must be a query."
@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.catalog.SQLFunction
import org.apache.spark.sql.catalyst.expressions.{Expression, Unevaluable}
import org.apache.spark.sql.catalyst.trees.TreePattern.{SQL_FUNCTION_EXPRESSION, TreePattern}
import org.apache.spark.sql.types.DataType

/**
* Represent a SQL function expression resolved from the catalog SQL function builder.
*/
case class SQLFunctionExpression(
name: String,
function: SQLFunction,
inputs: Seq[Expression],
returnType: Option[DataType]) extends Expression with Unevaluable {
override def children: Seq[Expression] = inputs
override def dataType: DataType = returnType.get
override def nullable: Boolean = true
override def prettyName: String = name
override def toString: String = s"$name(${children.mkString(", ")})"
override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): SQLFunctionExpression = copy(inputs = newChildren)
final override val nodePatterns: Seq[TreePattern] = Seq(SQL_FUNCTION_EXPRESSION)
}
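
For illustration only, a hypothetical construction of this placeholder expression; nothing in this PR builds it yet (the analyzer rule lands in the follow-up PR), and the function definition below is made up:

```scala
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.SQLFunctionExpression
import org.apache.spark.sql.catalyst.catalog.SQLFunction
import org.apache.spark.sql.types.IntegerType

// A zero-argument scalar SQL UDF definition (illustrative values only).
val answerFn = SQLFunction(
  name = FunctionIdentifier("answer"),
  inputParam = None,
  returnType = Left(IntegerType),
  exprText = Some("42"),
  queryText = None,
  comment = None,
  deterministic = Some(true),
  containsSQL = Some(true),
  isTableFunc = false,
  properties = Map.empty)

// Unevaluable placeholder for the call answer(); the analyzer will later replace it
// with the parsed function body. dataType comes from the declared return type.
val call = SQLFunctionExpression("answer", answerFn, inputs = Nil, returnType = Some(IntegerType))
```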
@@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.catalog.SQLFunction
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, UnaryNode}
import org.apache.spark.sql.catalyst.trees.TreePattern.{FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION, SQL_TABLE_FUNCTION, TreePattern}
import org.apache.spark.sql.errors.DataTypeErrors.toSQLId
import org.apache.spark.sql.errors.QueryCompilationErrors

/**
* A container for holding a SQL function query plan and its function identifier.
*
* @param function: the SQL function that this node represents.
* @param child: the SQL function body.
*/
case class SQLFunctionNode(
function: SQLFunction,
child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
override def stringArgs: Iterator[Any] = Iterator(function.name, child)
override protected def withNewChildInternal(newChild: LogicalPlan): SQLFunctionNode =
copy(child = newChild)

// Throw a reasonable error message when trying to call a SQL UDF with TABLE argument(s).
if (child.containsPattern(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION)) {
throw QueryCompilationErrors
.tableValuedArgumentsNotYetImplementedForSqlFunctions("call", toSQLId(function.name.funcName))
}
}

/**
* Represent a SQL table function plan resolved from the catalog SQL table function builder.
*/
case class SQLTableFunction(
name: String,
function: SQLFunction,
inputs: Seq[Expression],
override val output: Seq[Attribute]) extends LeafNode {
final override val nodePatterns: Seq[TreePattern] = Seq(SQL_TABLE_FUNCTION)

// Throw a reasonable error message when trying to call a SQL UDF with TABLE argument(s) because
// this functionality is not implemented yet.
if (inputs.exists(_.containsPattern(FUNCTION_TABLE_RELATION_ARGUMENT_EXPRESSION))) {
throw QueryCompilationErrors
.tableValuedArgumentsNotYetImplementedForSqlFunctions("call", toSQLId(name))
}
}
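
Likewise, a hypothetical instance of the table-function plan node above, with a made-up function definition and output attributes (the analyzer added in the follow-up PR will produce the real ones):

```scala
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.SQLTableFunction
import org.apache.spark.sql.catalyst.catalog.SQLFunction
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal}
import org.apache.spark.sql.types.{IntegerType, StructType}

// A table-valued SQL UDF definition: top_ids(n INT) RETURNS TABLE (id INT).
val topIdsFn = SQLFunction(
  name = FunctionIdentifier("top_ids"),
  inputParam = Some(StructType.fromDDL("n INT")),
  returnType = Right(StructType.fromDDL("id INT")),  // table functions use Right(columns)
  exprText = None,
  queryText = Some("SELECT id FROM range(n)"),       // table function body is a query
  comment = None,
  deterministic = Some(true),
  containsSQL = Some(true),
  isTableFunc = true,
  properties = Map.empty)

// Leaf placeholder for the call top_ids(10); resolution of its body comes later.
val plan = SQLTableFunction(
  name = "top_ids",
  function = topIdsFn,
  inputs = Seq(Literal(10)),
  output = Seq(AttributeReference("id", IntegerType)()))
```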
@@ -22,9 +22,11 @@ import scala.collection.mutable
import org.json4s.JsonAST.{JArray, JString}
import org.json4s.jackson.JsonMethods.{compact, render}

import org.apache.spark.SparkException
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.catalog.UserDefinedFunction._
import org.apache.spark.sql.catalyst.expressions.{Expression, ScalarSubquery}
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, ScalarSubquery}
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation, Project}
import org.apache.spark.sql.types.{DataType, StructType}
@@ -62,6 +64,8 @@ case class SQLFunction(
assert(exprText.nonEmpty || queryText.nonEmpty)
assert((isTableFunc && returnType.isRight) || (!isTableFunc && returnType.isLeft))

import SQLFunction._

override val language: RoutineLanguage = LanguageSQL

/**
@@ -88,19 +92,141 @@
(parsedExpression, parsedQuery)
}
}

/** Get scalar function return data type. */
def getScalarFuncReturnType: DataType = returnType match {
case Left(dataType) => dataType
case Right(_) =>
throw SparkException.internalError(
"This function is a table function, not a scalar function.")
}

/** Get table function return columns. */
def getTableFuncReturnCols: StructType = returnType match {
case Left(_) =>
throw SparkException.internalError(
"This function is a scalar function, not a table function.")
case Right(columns) => columns
}

/**
* Convert the SQL function to a [[CatalogFunction]].
*/
def toCatalogFunction: CatalogFunction = {
val props = sqlFunctionToProps ++ properties
CatalogFunction(
identifier = name,
className = SQL_FUNCTION_PREFIX,
resources = propertiesToFunctionResources(props, name))
}

/**
* Convert the SQL function to an [[ExpressionInfo]].
*/
def toExpressionInfo: ExpressionInfo = {
val props = sqlFunctionToProps ++ functionMetadataToProps ++ properties
val usage = mapper.writeValueAsString(props)
new ExpressionInfo(
SQL_FUNCTION_PREFIX,
name.database.orNull,
name.funcName,
usage,
"",
"",
"",
"",
"",
"",
"sql_udf")
}

/**
* Convert the SQL function fields into properties.
*/
private def sqlFunctionToProps: Map[String, String] = {
val props = new mutable.HashMap[String, String]
val inputParamText = inputParam.map(_.fields.map(_.toDDL).mkString(", "))
inputParamText.foreach(props.put(INPUT_PARAM, _))
val returnTypeText = returnType match {
case Left(dataType) => dataType.sql
case Right(columns) => columns.toDDL
}
props.put(RETURN_TYPE, returnTypeText)
exprText.foreach(props.put(EXPRESSION, _))
queryText.foreach(props.put(QUERY, _))
comment.foreach(props.put(COMMENT, _))
deterministic.foreach(d => props.put(DETERMINISTIC, d.toString))
containsSQL.foreach(x => props.put(CONTAINS_SQL, x.toString))
props.put(IS_TABLE_FUNC, isTableFunc.toString)
props.toMap
}

private def functionMetadataToProps: Map[String, String] = {
val props = new mutable.HashMap[String, String]
owner.foreach(props.put(OWNER, _))
props.put(CREATE_TIME, createTimeMs.toString)
props.toMap
}
}

object SQLFunction {

private val SQL_FUNCTION_PREFIX = "sqlFunction."

private val INPUT_PARAM: String = SQL_FUNCTION_PREFIX + "inputParam"
private val RETURN_TYPE: String = SQL_FUNCTION_PREFIX + "returnType"
private val EXPRESSION: String = SQL_FUNCTION_PREFIX + "expression"
private val QUERY: String = SQL_FUNCTION_PREFIX + "query"
private val COMMENT: String = SQL_FUNCTION_PREFIX + "comment"
private val DETERMINISTIC: String = SQL_FUNCTION_PREFIX + "deterministic"
private val CONTAINS_SQL: String = SQL_FUNCTION_PREFIX + "containsSQL"
private val IS_TABLE_FUNC: String = SQL_FUNCTION_PREFIX + "isTableFunc"
private val OWNER: String = SQL_FUNCTION_PREFIX + "owner"
private val CREATE_TIME: String = SQL_FUNCTION_PREFIX + "createTime"

private val FUNCTION_CATALOG_AND_NAMESPACE = "catalogAndNamespace.numParts"
private val FUNCTION_CATALOG_AND_NAMESPACE_PART_PREFIX = "catalogAndNamespace.part."

private val FUNCTION_REFERRED_TEMP_VIEW_NAMES = "referredTempViewNames"
private val FUNCTION_REFERRED_TEMP_FUNCTION_NAMES = "referredTempFunctionsNames"
private val FUNCTION_REFERRED_TEMP_VARIABLE_NAMES = "referredTempVariableNames"

/**
* Convert a [[CatalogFunction]] into a SQL function.
*/
def fromCatalogFunction(function: CatalogFunction, parser: ParserInterface): SQLFunction = {
try {
val parts = function.resources.collect { case FunctionResource(FileResource, uri) =>
val index = uri.substring(0, INDEX_LENGTH).toInt
val body = uri.substring(INDEX_LENGTH)
index -> body
}
val blob = parts.sortBy(_._1).map(_._2).mkString
val props = mapper.readValue(blob, classOf[Map[String, String]])
val isTableFunc = props(IS_TABLE_FUNC).toBoolean
val returnType = parseReturnTypeText(props(RETURN_TYPE), isTableFunc, parser)
SQLFunction(
name = function.identifier,
inputParam = props.get(INPUT_PARAM).map(parseTableSchema(_, parser)),
returnType = returnType.get,
exprText = props.get(EXPRESSION),
queryText = props.get(QUERY),
comment = props.get(COMMENT),
deterministic = props.get(DETERMINISTIC).map(_.toBoolean),
containsSQL = props.get(CONTAINS_SQL).map(_.toBoolean),
isTableFunc = isTableFunc,
props.filterNot(_._1.startsWith(SQL_FUNCTION_PREFIX)))
} catch {
case e: Exception =>
throw new AnalysisException(
errorClass = "CORRUPTED_CATALOG_FUNCTION",
messageParameters = Map(
"identifier" -> s"${function.identifier}",
"className" -> s"${function.className}"), cause = Some(e)
)
}
}

def parseDefault(text: String, parser: ParserInterface): Expression = {
parser.parseExpression(text)
}