[SPARK-21485][FOLLOWUP][SQL][DOCS] Describes examples and arguments separately, and...
[spark.git] / sql / catalyst / src / main / scala / org / apache / spark / sql / catalyst / analysis / FunctionRegistry.scala
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one or more
3  * contributor license agreements.  See the NOTICE file distributed with
4  * this work for additional information regarding copyright ownership.
5  * The ASF licenses this file to You under the Apache License, Version 2.0
6  * (the "License"); you may not use this file except in compliance with
7  * the License.  You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 package org.apache.spark.sql.catalyst.analysis
19
20 import java.util.Locale
21 import javax.annotation.concurrent.GuardedBy
22
23 import scala.collection.mutable
24 import scala.language.existentials
25 import scala.reflect.ClassTag
26 import scala.util.{Failure, Success, Try}
27
28 import org.apache.spark.sql.AnalysisException
29 import org.apache.spark.sql.catalyst.FunctionIdentifier
30 import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
31 import org.apache.spark.sql.catalyst.expressions._
32 import org.apache.spark.sql.catalyst.expressions.aggregate._
33 import org.apache.spark.sql.catalyst.expressions.xml._
34 import org.apache.spark.sql.types._
35
36
/**
 * A catalog for looking up user defined functions, used by an [[Analyzer]].
 *
 * Note:
 *   1) The implementation should be thread-safe to allow concurrent access.
 *   2) The database name is always case-sensitive here; callers are responsible for
 *      formatting the database name according to the case-sensitivity config.
 */
trait FunctionRegistry {

  /**
   * Registers `builder` under `name`, synthesizing a minimal [[ExpressionInfo]] from the
   * builder's class name, the identifier's database (or null when absent) and the function name.
   */
  final def registerFunction(name: FunctionIdentifier, builder: FunctionBuilder): Unit = {
    val info = new ExpressionInfo(
      builder.getClass.getCanonicalName, name.database.orNull, name.funcName)
    registerFunction(name, info, builder)
  }

  /** Registers `builder` under `name`, together with its documentation `info`. */
  def registerFunction(
      name: FunctionIdentifier,
      info: ExpressionInfo,
      builder: FunctionBuilder): Unit

  /** Create or replace a temporary function. */
  final def createOrReplaceTempFunction(name: String, builder: FunctionBuilder): Unit = {
    registerFunction(
      FunctionIdentifier(name),
      builder)
  }

  /** Resolves the function `name` and applies it to the argument expressions `children`. */
  @throws[AnalysisException]("If function does not exist")
  def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression

  /** List all of the registered function names. */
  def listFunction(): Seq[FunctionIdentifier]

  /** Get the [[ExpressionInfo]] of the registered function by specified name. */
  def lookupFunction(name: FunctionIdentifier): Option[ExpressionInfo]

  /** Get the builder of the registered function by specified name. */
  def lookupFunctionBuilder(name: FunctionIdentifier): Option[FunctionBuilder]

  /** Drop a function and return whether the function existed. */
  def dropFunction(name: FunctionIdentifier): Boolean

  /** Checks if a function with a given name exists. */
  def functionExists(name: FunctionIdentifier): Boolean = lookupFunction(name).isDefined

  /** Clear all registered functions. */
  def clear(): Unit

  /**
   * Create a copy of this registry with identical functions as this registry.
   * The default implementation throws [[CloneNotSupportedException]].
   */
  override def clone(): FunctionRegistry = throw new CloneNotSupportedException()
}
89
/**
 * A [[FunctionRegistry]] backed by an in-memory mutable hash map. Every read or write of the map
 * happens while holding this instance's monitor. Function names are lower-cased (with
 * `Locale.ROOT`) before being used as keys; database names are stored exactly as given.
 */
class SimpleFunctionRegistry extends FunctionRegistry {

  /** Normalized function identifier -> (documentation, builder). */
  @GuardedBy("this")
  private val functionBuilders =
    new mutable.HashMap[FunctionIdentifier, (ExpressionInfo, FunctionBuilder)]

  // Function-name matching is case insensitive; how the database part is cased is the
  // caller's responsibility.
  private def normalizeFuncName(name: FunctionIdentifier): FunctionIdentifier = {
    val lowered = name.funcName.toLowerCase(Locale.ROOT)
    FunctionIdentifier(lowered, name.database)
  }

  override def registerFunction(
      name: FunctionIdentifier,
      info: ExpressionInfo,
      builder: FunctionBuilder): Unit = synchronized {
    val key = normalizeFuncName(name)
    functionBuilders(key) = (info, builder)
  }

  override def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = {
    // Only the map read happens under the lock; the builder itself is invoked unlocked.
    val resolved = synchronized {
      functionBuilders.get(normalizeFuncName(name)) match {
        case Some((_, found)) => found
        case None => throw new AnalysisException(s"undefined function $name")
      }
    }
    resolved(children)
  }

  override def listFunction(): Seq[FunctionIdentifier] = synchronized {
    functionBuilders.keysIterator.toList
  }

  override def lookupFunction(name: FunctionIdentifier): Option[ExpressionInfo] = synchronized {
    functionBuilders.get(normalizeFuncName(name)).map { case (info, _) => info }
  }

  override def lookupFunctionBuilder(
      name: FunctionIdentifier): Option[FunctionBuilder] = synchronized {
    functionBuilders.get(normalizeFuncName(name)).map { case (_, builder) => builder }
  }

  override def dropFunction(name: FunctionIdentifier): Boolean = synchronized {
    val removed = functionBuilders.remove(normalizeFuncName(name))
    removed.nonEmpty
  }

  override def clear(): Unit = synchronized {
    functionBuilders.clear()
  }

  /** Returns a new registry holding the same (info, builder) entries as this one. */
  override def clone(): SimpleFunctionRegistry = synchronized {
    val duplicate = new SimpleFunctionRegistry
    for ((id, (info, builder)) <- functionBuilders) {
      duplicate.registerFunction(id, info, builder)
    }
    duplicate
  }
}
147
/**
 * A trivial catalog that returns an error when a function is requested. Used for testing when all
 * functions are already filled in and the analyzer needs only to resolve attribute references.
 *
 * Every operation throws [[UnsupportedOperationException]] except `clone()`, which returns this
 * same singleton because there is no state to copy.
 */
object EmptyFunctionRegistry extends FunctionRegistry {

  /** Shared failure path for every operation this registry does not support. */
  private def unsupported: Nothing = throw new UnsupportedOperationException

  override def registerFunction(
      name: FunctionIdentifier,
      info: ExpressionInfo,
      builder: FunctionBuilder): Unit = unsupported

  override def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression =
    unsupported

  override def listFunction(): Seq[FunctionIdentifier] = unsupported

  override def lookupFunction(name: FunctionIdentifier): Option[ExpressionInfo] = unsupported

  override def lookupFunctionBuilder(name: FunctionIdentifier): Option[FunctionBuilder] =
    unsupported

  override def dropFunction(name: FunctionIdentifier): Boolean = unsupported

  override def clear(): Unit = unsupported

  override def clone(): FunctionRegistry = this
}
184
185
object FunctionRegistry {

  /** Turns the argument expressions of a function call into the [[Expression]] it denotes. */
  type FunctionBuilder = Seq[Expression] => Expression

  /**
   * All built-in functions, keyed by the name they are registered under. Each value pairs the
   * function's [[ExpressionInfo]] with the builder used to instantiate it.
   */
  // Note: Whenever we add a new entry here, make sure we also update ExpressionToSQLSuite
  val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = Map(
    // misc non-aggregate functions
    expression[Abs]("abs"),
    expression[Coalesce]("coalesce"),
    expression[Explode]("explode"),
    expressionGeneratorOuter[Explode]("explode_outer"),
    expression[Greatest]("greatest"),
    expression[If]("if"),
    expression[Inline]("inline"),
    expressionGeneratorOuter[Inline]("inline_outer"),
    expression[IsNaN]("isnan"),
    expression[IfNull]("ifnull"),
    expression[IsNull]("isnull"),
    expression[IsNotNull]("isnotnull"),
    expression[Least]("least"),
    expression[NaNvl]("nanvl"),
    expression[NullIf]("nullif"),
    expression[Nvl]("nvl"),
    expression[Nvl2]("nvl2"),
    expression[PosExplode]("posexplode"),
    expressionGeneratorOuter[PosExplode]("posexplode_outer"),
    expression[Rand]("rand"),
    expression[Randn]("randn"),
    expression[Stack]("stack"),
    expression[CaseWhen]("when"),

    // math functions
    expression[Acos]("acos"),
    expression[Asin]("asin"),
    expression[Atan]("atan"),
    expression[Atan2]("atan2"),
    expression[Bin]("bin"),
    expression[BRound]("bround"),
    expression[Cbrt]("cbrt"),
    expression[Ceil]("ceil"),
    expression[Ceil]("ceiling"),
    expression[Cos]("cos"),
    expression[Cosh]("cosh"),
    expression[Conv]("conv"),
    expression[ToDegrees]("degrees"),
    expression[EulerNumber]("e"),
    expression[Exp]("exp"),
    expression[Expm1]("expm1"),
    expression[Floor]("floor"),
    expression[Factorial]("factorial"),
    expression[Hex]("hex"),
    expression[Hypot]("hypot"),
    expression[Logarithm]("log"),
    expression[Log10]("log10"),
    expression[Log1p]("log1p"),
    expression[Log2]("log2"),
    expression[Log]("ln"),
    expression[Remainder]("mod"),
    expression[UnaryMinus]("negative"),
    expression[Pi]("pi"),
    expression[Pmod]("pmod"),
    expression[UnaryPositive]("positive"),
    expression[Pow]("pow"),
    expression[Pow]("power"),
    expression[ToRadians]("radians"),
    expression[Rint]("rint"),
    expression[Round]("round"),
    expression[ShiftLeft]("shiftleft"),
    expression[ShiftRight]("shiftright"),
    expression[ShiftRightUnsigned]("shiftrightunsigned"),
    expression[Signum]("sign"),
    expression[Signum]("signum"),
    expression[Sin]("sin"),
    expression[Sinh]("sinh"),
    expression[StringToMap]("str_to_map"),
    expression[Sqrt]("sqrt"),
    expression[Tan]("tan"),
    expression[Cot]("cot"),
    expression[Tanh]("tanh"),

    expression[Add]("+"),
    expression[Subtract]("-"),
    expression[Multiply]("*"),
    expression[Divide]("/"),
    expression[Remainder]("%"),

    // aggregate functions
    expression[HyperLogLogPlusPlus]("approx_count_distinct"),
    expression[Average]("avg"),
    expression[Corr]("corr"),
    expression[Count]("count"),
    expression[CovPopulation]("covar_pop"),
    expression[CovSample]("covar_samp"),
    expression[First]("first"),
    expression[First]("first_value"),
    expression[Kurtosis]("kurtosis"),
    expression[Last]("last"),
    expression[Last]("last_value"),
    expression[Max]("max"),
    expression[Average]("mean"),
    expression[Min]("min"),
    expression[Percentile]("percentile"),
    expression[Skewness]("skewness"),
    expression[ApproximatePercentile]("percentile_approx"),
    expression[ApproximatePercentile]("approx_percentile"),
    expression[StddevSamp]("std"),
    expression[StddevSamp]("stddev"),
    expression[StddevPop]("stddev_pop"),
    expression[StddevSamp]("stddev_samp"),
    expression[Sum]("sum"),
    expression[VarianceSamp]("variance"),
    expression[VariancePop]("var_pop"),
    expression[VarianceSamp]("var_samp"),
    expression[CollectList]("collect_list"),
    expression[CollectSet]("collect_set"),
    expression[CountMinSketchAgg]("count_min_sketch"),

    // string functions
    expression[Ascii]("ascii"),
    expression[Chr]("char"),
    expression[Chr]("chr"),
    expression[Base64]("base64"),
    expression[BitLength]("bit_length"),
    expression[Length]("char_length"),
    expression[Length]("character_length"),
    expression[Concat]("concat"),
    expression[ConcatWs]("concat_ws"),
    expression[Decode]("decode"),
    expression[Elt]("elt"),
    expression[Encode]("encode"),
    expression[FindInSet]("find_in_set"),
    expression[FormatNumber]("format_number"),
    expression[FormatString]("format_string"),
    expression[GetJsonObject]("get_json_object"),
    expression[InitCap]("initcap"),
    expression[StringInstr]("instr"),
    expression[Lower]("lcase"),
    expression[Length]("length"),
    expression[Levenshtein]("levenshtein"),
    expression[Like]("like"),
    expression[Lower]("lower"),
    expression[OctetLength]("octet_length"),
    expression[StringLocate]("locate"),
    expression[StringLPad]("lpad"),
    expression[StringTrimLeft]("ltrim"),
    expression[JsonTuple]("json_tuple"),
    expression[ParseUrl]("parse_url"),
    expression[StringLocate]("position"),
    expression[FormatString]("printf"),
    expression[RegExpExtract]("regexp_extract"),
    expression[RegExpReplace]("regexp_replace"),
    expression[StringRepeat]("repeat"),
    expression[StringReplace]("replace"),
    expression[StringReverse]("reverse"),
    expression[RLike]("rlike"),
    expression[StringRPad]("rpad"),
    expression[StringTrimRight]("rtrim"),
    expression[Sentences]("sentences"),
    expression[SoundEx]("soundex"),
    expression[StringSpace]("space"),
    expression[StringSplit]("split"),
    expression[Substring]("substr"),
    expression[Substring]("substring"),
    expression[Left]("left"),
    expression[Right]("right"),
    expression[SubstringIndex]("substring_index"),
    expression[StringTranslate]("translate"),
    expression[StringTrim]("trim"),
    expression[Upper]("ucase"),
    expression[UnBase64]("unbase64"),
    expression[Unhex]("unhex"),
    expression[Upper]("upper"),
    expression[XPathList]("xpath"),
    expression[XPathBoolean]("xpath_boolean"),
    expression[XPathDouble]("xpath_double"),
    expression[XPathDouble]("xpath_number"),
    expression[XPathFloat]("xpath_float"),
    expression[XPathInt]("xpath_int"),
    expression[XPathLong]("xpath_long"),
    expression[XPathShort]("xpath_short"),
    expression[XPathString]("xpath_string"),

    // datetime functions
    expression[AddMonths]("add_months"),
    expression[CurrentDate]("current_date"),
    expression[CurrentTimestamp]("current_timestamp"),
    expression[DateDiff]("datediff"),
    expression[DateAdd]("date_add"),
    expression[DateFormatClass]("date_format"),
    expression[DateSub]("date_sub"),
    expression[DayOfMonth]("day"),
    expression[DayOfYear]("dayofyear"),
    expression[DayOfMonth]("dayofmonth"),
    expression[FromUnixTime]("from_unixtime"),
    expression[FromUTCTimestamp]("from_utc_timestamp"),
    expression[Hour]("hour"),
    expression[LastDay]("last_day"),
    expression[Minute]("minute"),
    expression[Month]("month"),
    expression[MonthsBetween]("months_between"),
    expression[NextDay]("next_day"),
    expression[CurrentTimestamp]("now"),
    expression[Quarter]("quarter"),
    expression[Second]("second"),
    expression[ParseToTimestamp]("to_timestamp"),
    expression[ParseToDate]("to_date"),
    expression[ToUnixTimestamp]("to_unix_timestamp"),
    expression[ToUTCTimestamp]("to_utc_timestamp"),
    expression[TruncDate]("trunc"),
    expression[UnixTimestamp]("unix_timestamp"),
    expression[DayOfWeek]("dayofweek"),
    expression[WeekOfYear]("weekofyear"),
    expression[Year]("year"),
    expression[TimeWindow]("window"),

    // collection functions
    expression[CreateArray]("array"),
    expression[ArrayContains]("array_contains"),
    expression[CreateMap]("map"),
    expression[CreateNamedStruct]("named_struct"),
    expression[MapKeys]("map_keys"),
    expression[MapValues]("map_values"),
    expression[Size]("size"),
    expression[SortArray]("sort_array"),
    CreateStruct.registryEntry,

    // misc functions
    expression[AssertTrue]("assert_true"),
    expression[Crc32]("crc32"),
    expression[Md5]("md5"),
    expression[Uuid]("uuid"),
    expression[Murmur3Hash]("hash"),
    expression[Sha1]("sha"),
    expression[Sha1]("sha1"),
    expression[Sha2]("sha2"),
    expression[SparkPartitionID]("spark_partition_id"),
    expression[InputFileName]("input_file_name"),
    expression[InputFileBlockStart]("input_file_block_start"),
    expression[InputFileBlockLength]("input_file_block_length"),
    expression[MonotonicallyIncreasingID]("monotonically_increasing_id"),
    expression[CurrentDatabase]("current_database"),
    expression[CallMethodViaReflection]("reflect"),
    expression[CallMethodViaReflection]("java_method"),

    // grouping sets
    expression[Cube]("cube"),
    expression[Rollup]("rollup"),
    expression[Grouping]("grouping"),
    expression[GroupingID]("grouping_id"),

    // window functions
    expression[Lead]("lead"),
    expression[Lag]("lag"),
    expression[RowNumber]("row_number"),
    expression[CumeDist]("cume_dist"),
    expression[NTile]("ntile"),
    expression[Rank]("rank"),
    expression[DenseRank]("dense_rank"),
    expression[PercentRank]("percent_rank"),

    // predicates
    expression[And]("and"),
    expression[In]("in"),
    expression[Not]("not"),
    expression[Or]("or"),

    // comparison operators
    expression[EqualNullSafe]("<=>"),
    expression[EqualTo]("="),
    expression[EqualTo]("=="),
    expression[GreaterThan](">"),
    expression[GreaterThanOrEqual](">="),
    expression[LessThan]("<"),
    expression[LessThanOrEqual]("<="),
    expression[Not]("!"),

    // bitwise
    expression[BitwiseAnd]("&"),
    expression[BitwiseNot]("~"),
    expression[BitwiseOr]("|"),
    expression[BitwiseXor]("^"),

    // json
    expression[StructsToJson]("to_json"),
    expression[JsonToStructs]("from_json"),

    // cast
    expression[Cast]("cast"),
    // Cast aliases (SPARK-16730)
    castAlias("boolean", BooleanType),
    castAlias("tinyint", ByteType),
    castAlias("smallint", ShortType),
    castAlias("int", IntegerType),
    castAlias("bigint", LongType),
    castAlias("float", FloatType),
    castAlias("double", DoubleType),
    castAlias("decimal", DecimalType.USER_DEFAULT),
    castAlias("date", DateType),
    castAlias("timestamp", TimestampType),
    castAlias("binary", BinaryType),
    castAlias("string", StringType)
  )

  /** A [[SimpleFunctionRegistry]] pre-populated with every entry of [[expressions]]. */
  val builtin: SimpleFunctionRegistry = {
    val fr = new SimpleFunctionRegistry
    expressions.foreach {
      case (name, (info, builder)) => fr.registerFunction(FunctionIdentifier(name), info, builder)
    }
    fr
  }

  /** Identifiers of all built-in functions. */
  val functionSet: Set[FunctionIdentifier] = builtin.listFunction().toSet

  /**
   * Creates a registry entry for expression class `T` under `name` (see usage above).
   *
   * The builder instantiates `T` reflectively: it uses a public constructor taking a single
   * `Seq` if one exists; otherwise it looks for a constructor taking exactly as many
   * `Expression` parameters as there are arguments.
   *
   * @throws AnalysisException (from the builder) when no constructor matches the argument count,
   *                           or when construction fails.
   */
  private def expression[T <: Expression](name: String)
      (implicit tag: ClassTag[T]): (String, (ExpressionInfo, FunctionBuilder)) = {

    // For `RuntimeReplaceable`, skip the constructor with most arguments, which is the main
    // constructor and contains non-parameter `child` and should not be used as function builder.
    val constructors = if (classOf[RuntimeReplaceable].isAssignableFrom(tag.runtimeClass)) {
      val all = tag.runtimeClass.getConstructors
      val maxNumArgs = all.map(_.getParameterCount).max
      all.filterNot(_.getParameterCount == maxNumArgs)
    } else {
      tag.runtimeClass.getConstructors
    }
    // See if we can find a constructor that accepts Seq[Expression]
    val varargCtor = constructors.find(_.getParameterTypes.toSeq == Seq(classOf[Seq[_]]))
    val builder = (args: Seq[Expression]) => varargCtor match {
      case Some(ctor) =>
        // If there is a constructor that accepts Seq[Expression], pass all arguments as one Seq.
        instantiateExpression(ctor.newInstance(args))
      case None =>
        // Otherwise, find a constructor that matches the number of arguments, and use that.
        val params = Seq.fill(args.size)(classOf[Expression])
        val f = constructors.find(_.getParameterTypes.toSeq == params).getOrElse {
          throw new AnalysisException(s"Invalid number of arguments for function $name")
        }
        instantiateExpression(f.newInstance(args : _*))
    }

    (name, (expressionInfo[T](name), builder))
  }

  /**
   * Evaluates `create` (a reflective constructor call) and returns the result as an
   * [[Expression]], converting any non-fatal failure into an [[AnalysisException]].
   */
  private def instantiateExpression(create: => Any): Expression = {
    Try(create.asInstanceOf[Expression]) match {
      case Success(e) => e
      case Failure(e) =>
        // Reflective construction wraps the real error in an invocation exception, so the
        // meaningful message is on the cause. Failures raised outside the constructor (e.g. a
        // ClassCastException or IllegalArgumentException) have no cause; fall back to the
        // exception itself rather than throwing a NullPointerException.
        val reason = Option(e.getCause).getOrElse(e)
        throw new AnalysisException(reason.getMessage)
    }
  }

  /**
   * Creates a function registry lookup entry for cast aliases (SPARK-16730).
   * For example, if name is "int", and dataType is IntegerType, this means int(x) would become
   * an alias for cast(x as IntegerType).
   * See usage above.
   */
  private def castAlias(
      name: String,
      dataType: DataType): (String, (ExpressionInfo, FunctionBuilder)) = {
    val builder = (args: Seq[Expression]) => {
      if (args.size != 1) {
        throw new AnalysisException(s"Function $name accepts only one argument")
      }
      Cast(args.head, dataType)
    }
    val clazz = scala.reflect.classTag[Cast].runtimeClass
    val usage = "_FUNC_(expr) - Casts the value `expr` to the target data type `_FUNC_`."
    val expressionInfo =
      new ExpressionInfo(clazz.getCanonicalName, null, name, usage, "", "", "", "")
    (name, (expressionInfo, builder))
  }

  /**
   * Creates an [[ExpressionInfo]] for the function as defined by expression T using the given name.
   * Uses the class's [[ExpressionDescription]] annotation when present, and falls back to an info
   * carrying only the class name and function name otherwise.
   */
  private def expressionInfo[T <: Expression : ClassTag](name: String): ExpressionInfo = {
    val clazz = scala.reflect.classTag[T].runtimeClass
    val df = clazz.getAnnotation(classOf[ExpressionDescription])
    if (df != null) {
      if (df.extended().isEmpty) {
        new ExpressionInfo(
          clazz.getCanonicalName,
          null,
          name,
          df.usage(),
          df.arguments(),
          df.examples(),
          df.note(),
          df.since())
      } else {
        // This exists for the backward compatibility with old `ExpressionDescription`s defining
        // the extended description in `extended()`.
        new ExpressionInfo(clazz.getCanonicalName, null, name, df.usage(), df.extended())
      }
    } else {
      new ExpressionInfo(clazz.getCanonicalName, name)
    }
  }

  /**
   * Creates a registry entry for the "outer" variant of generator `T`: it reuses the info and
   * builder of `expression[T]` and wraps the built generator in [[GeneratorOuter]].
   */
  private def expressionGeneratorOuter[T <: Generator : ClassTag](name: String)
    : (String, (ExpressionInfo, FunctionBuilder)) = {
    val (_, (info, generatorBuilder)) = expression[T](name)
    val outerBuilder = (args: Seq[Expression]) => {
      GeneratorOuter(generatorBuilder(args).asInstanceOf[Generator])
    }
    (name, (info, outerBuilder))
  }
}