[SPARK-21485][FOLLOWUP][SQL][DOCS] Describes examples and arguments separately, and...
[spark.git] / sql / catalyst / src / main / scala / org / apache / spark / sql / catalyst / expressions / mathExpressions.scala
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one or more
3  * contributor license agreements.  See the NOTICE file distributed with
4  * this work for additional information regarding copyright ownership.
5  * The ASF licenses this file to You under the Apache License, Version 2.0
6  * (the "License"); you may not use this file except in compliance with
7  * the License.  You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 package org.apache.spark.sql.catalyst.expressions
19
20 import java.{lang => jl}
21 import java.util.Locale
22
23 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
24 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
25 import org.apache.spark.sql.catalyst.expressions.codegen._
26 import org.apache.spark.sql.catalyst.InternalRow
27 import org.apache.spark.sql.catalyst.util.NumberConverter
28 import org.apache.spark.sql.types._
29 import org.apache.spark.unsafe.types.UTF8String
30
31 /**
32  * A leaf expression specifically for math constants. Math constants expect no input.
33  *
34  * There is no code generation because they should get constant folded by the optimizer.
35  *
36  * @param c The math constant.
37  * @param name The short name of the function
38  */
39 abstract class LeafMathExpression(c: Double, name: String)
40   extends LeafExpression with CodegenFallback with Serializable {
41
42   override def dataType: DataType = DoubleType
43   override def foldable: Boolean = true
44   override def nullable: Boolean = false
45   override def toString: String = s"$name()"
46   override def prettyName: String = name
47
48   override def eval(input: InternalRow): Any = c
49 }
50
51 /**
52  * A unary expression specifically for math functions. Math Functions expect a specific type of
53  * input format, therefore these functions extend `ExpectsInputTypes`.
54  *
55  * @param f The math function.
56  * @param name The short name of the function
57  */
58 abstract class UnaryMathExpression(val f: Double => Double, name: String)
59   extends UnaryExpression with Serializable with ImplicitCastInputTypes {
60
61   override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType)
62   override def dataType: DataType = DoubleType
63   override def nullable: Boolean = true
64   override def toString: String = s"$name($child)"
65   override def prettyName: String = name
66
67   protected override def nullSafeEval(input: Any): Any = {
68     f(input.asInstanceOf[Double])
69   }
70
71   // name of function in java.lang.Math
72   def funcName: String = name.toLowerCase(Locale.ROOT)
73
74   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
75     defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)")
76   }
77 }
78
79 abstract class UnaryLogExpression(f: Double => Double, name: String)
80     extends UnaryMathExpression(f, name) {
81
82   override def nullable: Boolean = true
83
84   // values less than or equal to yAsymptote eval to null in Hive, instead of NaN or -Infinity
85   protected val yAsymptote: Double = 0.0
86
87   protected override def nullSafeEval(input: Any): Any = {
88     val d = input.asInstanceOf[Double]
89     if (d <= yAsymptote) null else f(d)
90   }
91
92   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
93     nullSafeCodeGen(ctx, ev, c =>
94       s"""
95         if ($c <= $yAsymptote) {
96           ${ev.isNull} = true;
97         } else {
98           ${ev.value} = java.lang.Math.${funcName}($c);
99         }
100       """
101     )
102   }
103 }
104
105 /**
106  * A binary expression specifically for math functions that take two `Double`s as input and returns
107  * a `Double`.
108  *
109  * @param f The math function.
110  * @param name The short name of the function
111  */
112 abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
113   extends BinaryExpression with Serializable with ImplicitCastInputTypes {
114
115   override def inputTypes: Seq[DataType] = Seq(DoubleType, DoubleType)
116
117   override def toString: String = s"$name($left, $right)"
118
119   override def prettyName: String = name
120
121   override def dataType: DataType = DoubleType
122
123   protected override def nullSafeEval(input1: Any, input2: Any): Any = {
124     f(input1.asInstanceOf[Double], input2.asInstanceOf[Double])
125   }
126
127   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
128     defineCodeGen(ctx, ev, (c1, c2) =>
129       s"java.lang.Math.${name.toLowerCase(Locale.ROOT)}($c1, $c2)")
130   }
131 }
132
133 ////////////////////////////////////////////////////////////////////////////////////////////////////
134 ////////////////////////////////////////////////////////////////////////////////////////////////////
135 // Leaf math functions
136 ////////////////////////////////////////////////////////////////////////////////////////////////////
137 ////////////////////////////////////////////////////////////////////////////////////////////////////
138
139 /**
140  * Euler's number. Note that there is no code generation because this is only
141  * evaluated by the optimizer during constant folding.
142  */
143 @ExpressionDescription(
144   usage = "_FUNC_() - Returns Euler's number, e.",
145   examples = """
146     Examples:
147       > SELECT _FUNC_();
148        2.718281828459045
149   """)
150 case class EulerNumber() extends LeafMathExpression(math.E, "E")
151
152 /**
153  * Pi. Note that there is no code generation because this is only
154  * evaluated by the optimizer during constant folding.
155  */
156 @ExpressionDescription(
157   usage = "_FUNC_() - Returns pi.",
158   examples = """
159     Examples:
160       > SELECT _FUNC_();
161        3.141592653589793
162   """)
163 case class Pi() extends LeafMathExpression(math.Pi, "PI")
164
165 ////////////////////////////////////////////////////////////////////////////////////////////////////
166 ////////////////////////////////////////////////////////////////////////////////////////////////////
167 // Unary math functions
168 ////////////////////////////////////////////////////////////////////////////////////////////////////
169 ////////////////////////////////////////////////////////////////////////////////////////////////////
170
171 // scalastyle:off line.size.limit
172 @ExpressionDescription(
173   usage = "_FUNC_(expr) - Returns the inverse cosine (a.k.a. arccosine) of `expr` if -1<=`expr`<=1 or NaN otherwise.",
174   examples = """
175     Examples:
176       > SELECT _FUNC_(1);
177        0.0
178       > SELECT _FUNC_(2);
179        NaN
180   """)
181 // scalastyle:on line.size.limit
182 case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS")
183
184 // scalastyle:off line.size.limit
185 @ExpressionDescription(
186   usage = "_FUNC_(expr) - Returns the inverse sine (a.k.a. arcsine) the arc sin of `expr` if -1<=`expr`<=1 or NaN otherwise.",
187   examples = """
188     Examples:
189       > SELECT _FUNC_(0);
190        0.0
191       > SELECT _FUNC_(2);
192        NaN
193   """)
194 // scalastyle:on line.size.limit
195 case class Asin(child: Expression) extends UnaryMathExpression(math.asin, "ASIN")
196
197 // scalastyle:off line.size.limit
198 @ExpressionDescription(
199   usage = "_FUNC_(expr) - Returns the inverse tangent (a.k.a. arctangent).",
200   examples = """
201     Examples:
202       > SELECT _FUNC_(0);
203        0.0
204   """)
205 // scalastyle:on line.size.limit
206 case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN")
207
208 @ExpressionDescription(
209   usage = "_FUNC_(expr) - Returns the cube root of `expr`.",
210   examples = """
211     Examples:
212       > SELECT _FUNC_(27.0);
213        3.0
214   """)
215 case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT")
216
217 @ExpressionDescription(
218   usage = "_FUNC_(expr) - Returns the smallest integer not smaller than `expr`.",
219   examples = """
220     Examples:
221       > SELECT _FUNC_(-0.1);
222        0
223       > SELECT _FUNC_(5);
224        5
225   """)
226 case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL") {
227   override def dataType: DataType = child.dataType match {
228     case dt @ DecimalType.Fixed(_, 0) => dt
229     case DecimalType.Fixed(precision, scale) =>
230       DecimalType.bounded(precision - scale + 1, 0)
231     case _ => LongType
232   }
233
234   override def inputTypes: Seq[AbstractDataType] =
235     Seq(TypeCollection(DoubleType, DecimalType, LongType))
236
237   protected override def nullSafeEval(input: Any): Any = child.dataType match {
238     case LongType => input.asInstanceOf[Long]
239     case DoubleType => f(input.asInstanceOf[Double]).toLong
240     case DecimalType.Fixed(_, _) => input.asInstanceOf[Decimal].ceil
241   }
242
243   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
244     child.dataType match {
245       case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c")
246       case DecimalType.Fixed(_, _) =>
247         defineCodeGen(ctx, ev, c => s"$c.ceil()")
248       case LongType => defineCodeGen(ctx, ev, c => s"$c")
249       case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
250     }
251   }
252 }
253
254 @ExpressionDescription(
255   usage = "_FUNC_(expr) - Returns the cosine of `expr`.",
256   examples = """
257     Examples:
258       > SELECT _FUNC_(0);
259        1.0
260   """)
261 case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS")
262
263 @ExpressionDescription(
264   usage = "_FUNC_(expr) - Returns the hyperbolic cosine of `expr`.",
265   examples = """
266     Examples:
267       > SELECT _FUNC_(0);
268        1.0
269   """)
270 case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH")
271
272 /**
273  * Convert a num from one base to another
274  *
275  * @param numExpr the number to be converted
276  * @param fromBaseExpr from which base
277  * @param toBaseExpr to which base
278  */
279 @ExpressionDescription(
280   usage = "_FUNC_(num, from_base, to_base) - Convert `num` from `from_base` to `to_base`.",
281   examples = """
282     Examples:
283       > SELECT _FUNC_('100', 2, 10);
284        4
285       > SELECT _FUNC_(-10, 16, -10);
286        -16
287   """)
288 case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression)
289   extends TernaryExpression with ImplicitCastInputTypes {
290
291   override def children: Seq[Expression] = Seq(numExpr, fromBaseExpr, toBaseExpr)
292   override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType)
293   override def dataType: DataType = StringType
294   override def nullable: Boolean = true
295
296   override def nullSafeEval(num: Any, fromBase: Any, toBase: Any): Any = {
297     NumberConverter.convert(
298       num.asInstanceOf[UTF8String].getBytes,
299       fromBase.asInstanceOf[Int],
300       toBase.asInstanceOf[Int])
301   }
302
303   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
304     val numconv = NumberConverter.getClass.getName.stripSuffix("$")
305     nullSafeCodeGen(ctx, ev, (num, from, to) =>
306       s"""
307        ${ev.value} = $numconv.convert($num.getBytes(), $from, $to);
308        if (${ev.value} == null) {
309          ${ev.isNull} = true;
310        }
311        """
312     )
313   }
314 }
315
316 @ExpressionDescription(
317   usage = "_FUNC_(expr) - Returns e to the power of `expr`.",
318   examples = """
319     Examples:
320       > SELECT _FUNC_(0);
321        1.0
322   """)
323 case class Exp(child: Expression) extends UnaryMathExpression(math.exp, "EXP")
324
325 @ExpressionDescription(
326   usage = "_FUNC_(expr) - Returns exp(`expr`) - 1.",
327   examples = """
328     Examples:
329       > SELECT _FUNC_(0);
330        0.0
331   """)
332 case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXPM1")
333
334 @ExpressionDescription(
335   usage = "_FUNC_(expr) - Returns the largest integer not greater than `expr`.",
336   examples = """
337     Examples:
338       > SELECT _FUNC_(-0.1);
339        -1
340       > SELECT _FUNC_(5);
341        5
342   """)
343 case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR") {
344   override def dataType: DataType = child.dataType match {
345     case dt @ DecimalType.Fixed(_, 0) => dt
346     case DecimalType.Fixed(precision, scale) =>
347       DecimalType.bounded(precision - scale + 1, 0)
348     case _ => LongType
349   }
350
351   override def inputTypes: Seq[AbstractDataType] =
352     Seq(TypeCollection(DoubleType, DecimalType, LongType))
353
354   protected override def nullSafeEval(input: Any): Any = child.dataType match {
355     case LongType => input.asInstanceOf[Long]
356     case DoubleType => f(input.asInstanceOf[Double]).toLong
357     case DecimalType.Fixed(_, _) => input.asInstanceOf[Decimal].floor
358   }
359
360   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
361     child.dataType match {
362       case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c")
363       case DecimalType.Fixed(_, _) =>
364         defineCodeGen(ctx, ev, c => s"$c.floor()")
365       case LongType => defineCodeGen(ctx, ev, c => s"$c")
366       case _ => defineCodeGen(ctx, ev, c => s"(long)(java.lang.Math.${funcName}($c))")
367     }
368   }
369 }
370
371 object Factorial {
372
373   def factorial(n: Int): Long = {
374     if (n < factorials.length) factorials(n) else Long.MaxValue
375   }
376
377   private val factorials: Array[Long] = Array[Long](
378     1,
379     1,
380     2,
381     6,
382     24,
383     120,
384     720,
385     5040,
386     40320,
387     362880,
388     3628800,
389     39916800,
390     479001600,
391     6227020800L,
392     87178291200L,
393     1307674368000L,
394     20922789888000L,
395     355687428096000L,
396     6402373705728000L,
397     121645100408832000L,
398     2432902008176640000L
399   )
400 }
401
402 @ExpressionDescription(
403   usage = "_FUNC_(expr) - Returns the factorial of `expr`. `expr` is [0..20]. Otherwise, null.",
404   examples = """
405     Examples:
406       > SELECT _FUNC_(5);
407        120
408   """)
409 case class Factorial(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
410
411   override def inputTypes: Seq[DataType] = Seq(IntegerType)
412
413   override def dataType: DataType = LongType
414
415   // If the value not in the range of [0, 20], it still will be null, so set it to be true here.
416   override def nullable: Boolean = true
417
418   protected override def nullSafeEval(input: Any): Any = {
419     val value = input.asInstanceOf[jl.Integer]
420     if (value > 20 || value < 0) {
421       null
422     } else {
423       Factorial.factorial(value)
424     }
425   }
426
427   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
428     nullSafeCodeGen(ctx, ev, eval => {
429       s"""
430         if ($eval > 20 || $eval < 0) {
431           ${ev.isNull} = true;
432         } else {
433           ${ev.value} =
434             org.apache.spark.sql.catalyst.expressions.Factorial.factorial($eval);
435         }
436       """
437     })
438   }
439 }
440
441 @ExpressionDescription(
442   usage = "_FUNC_(expr) - Returns the natural logarithm (base e) of `expr`.",
443   examples = """
444     Examples:
445       > SELECT _FUNC_(1);
446        0.0
447   """)
448 case class Log(child: Expression) extends UnaryLogExpression(math.log, "LOG")
449
450 @ExpressionDescription(
451   usage = "_FUNC_(expr) - Returns the logarithm of `expr` with base 2.",
452   examples = """
453     Examples:
454       > SELECT _FUNC_(2);
455        1.0
456   """)
457 case class Log2(child: Expression)
458   extends UnaryLogExpression((x: Double) => math.log(x) / math.log(2), "LOG2") {
459   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
460     nullSafeCodeGen(ctx, ev, c =>
461       s"""
462         if ($c <= $yAsymptote) {
463           ${ev.isNull} = true;
464         } else {
465           ${ev.value} = java.lang.Math.log($c) / java.lang.Math.log(2);
466         }
467       """
468     )
469   }
470 }
471
472 @ExpressionDescription(
473   usage = "_FUNC_(expr) - Returns the logarithm of `expr` with base 10.",
474   examples = """
475     Examples:
476       > SELECT _FUNC_(10);
477        1.0
478   """)
479 case class Log10(child: Expression) extends UnaryLogExpression(math.log10, "LOG10")
480
481 @ExpressionDescription(
482   usage = "_FUNC_(expr) - Returns log(1 + `expr`).",
483   examples = """
484     Examples:
485       > SELECT _FUNC_(0);
486        0.0
487   """)
488 case class Log1p(child: Expression) extends UnaryLogExpression(math.log1p, "LOG1P") {
489   protected override val yAsymptote: Double = -1.0
490 }
491
492 // scalastyle:off line.size.limit
493 @ExpressionDescription(
494   usage = "_FUNC_(expr) - Returns the double value that is closest in value to the argument and is equal to a mathematical integer.",
495   examples = """
496     Examples:
497       > SELECT _FUNC_(12.3456);
498        12.0
499   """)
500 // scalastyle:on line.size.limit
501 case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") {
502   override def funcName: String = "rint"
503 }
504
505 @ExpressionDescription(
506   usage = "_FUNC_(expr) - Returns -1.0, 0.0 or 1.0 as `expr` is negative, 0 or positive.",
507   examples = """
508     Examples:
509       > SELECT _FUNC_(40);
510        1.0
511   """)
512 case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM")
513
514 @ExpressionDescription(
515   usage = "_FUNC_(expr) - Returns the sine of `expr`.",
516   examples = """
517     Examples:
518       > SELECT _FUNC_(0);
519        0.0
520   """)
521 case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN")
522
523 @ExpressionDescription(
524   usage = "_FUNC_(expr) - Returns the hyperbolic sine of `expr`.",
525   examples = """
526     Examples:
527       > SELECT _FUNC_(0);
528        0.0
529   """)
530 case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH")
531
532 @ExpressionDescription(
533   usage = "_FUNC_(expr) - Returns the square root of `expr`.",
534   examples = """
535     Examples:
536       > SELECT _FUNC_(4);
537        2.0
538   """)
539 case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT")
540
541 @ExpressionDescription(
542   usage = "_FUNC_(expr) - Returns the tangent of `expr`.",
543   examples = """
544     Examples:
545       > SELECT _FUNC_(0);
546        0.0
547   """)
548 case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN")
549
550 @ExpressionDescription(
551   usage = "_FUNC_(expr) - Returns the cotangent of `expr`.",
552   examples = """
553     Examples:
554       > SELECT _FUNC_(1);
555        0.6420926159343306
556   """)
557 case class Cot(child: Expression)
558   extends UnaryMathExpression((x: Double) => 1 / math.tan(x), "COT") {
559   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
560     defineCodeGen(ctx, ev, c => s"${ev.value} = 1 / java.lang.Math.tan($c);")
561   }
562 }
563
564 @ExpressionDescription(
565   usage = "_FUNC_(expr) - Returns the hyperbolic tangent of `expr`.",
566   examples = """
567     Examples:
568       > SELECT _FUNC_(0);
569        0.0
570   """)
571 case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH")
572
573 @ExpressionDescription(
574   usage = "_FUNC_(expr) - Converts radians to degrees.",
575   examples = """
576     Examples:
577       > SELECT _FUNC_(3.141592653589793);
578        180.0
579   """)
580 case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES") {
581   override def funcName: String = "toDegrees"
582 }
583
584 @ExpressionDescription(
585   usage = "_FUNC_(expr) - Converts degrees to radians.",
586   examples = """
587     Examples:
588       > SELECT _FUNC_(180);
589        3.141592653589793
590   """)
591 case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS") {
592   override def funcName: String = "toRadians"
593 }
594
595 // scalastyle:off line.size.limit
596 @ExpressionDescription(
597   usage = "_FUNC_(expr) - Returns the string representation of the long value `expr` represented in binary.",
598   examples = """
599     Examples:
600       > SELECT _FUNC_(13);
601        1101
602       > SELECT _FUNC_(-13);
603        1111111111111111111111111111111111111111111111111111111111110011
604       > SELECT _FUNC_(13.3);
605        1101
606   """)
607 // scalastyle:on line.size.limit
608 case class Bin(child: Expression)
609   extends UnaryExpression with Serializable with ImplicitCastInputTypes {
610
611   override def inputTypes: Seq[DataType] = Seq(LongType)
612   override def dataType: DataType = StringType
613
614   protected override def nullSafeEval(input: Any): Any =
615     UTF8String.fromString(jl.Long.toBinaryString(input.asInstanceOf[Long]))
616
617   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
618     defineCodeGen(ctx, ev, (c) =>
619       s"UTF8String.fromString(java.lang.Long.toBinaryString($c))")
620   }
621 }
622
623 object Hex {
624   val hexDigits = Array[Char](
625     '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F'
626   ).map(_.toByte)
627
628   // lookup table to translate '0' -> 0 ... 'F'/'f' -> 15
629   val unhexDigits = {
630     val array = Array.fill[Byte](128)(-1)
631     (0 to 9).foreach(i => array('0' + i) = i.toByte)
632     (0 to 5).foreach(i => array('A' + i) = (i + 10).toByte)
633     (0 to 5).foreach(i => array('a' + i) = (i + 10).toByte)
634     array
635   }
636
637   def hex(bytes: Array[Byte]): UTF8String = {
638     val length = bytes.length
639     val value = new Array[Byte](length * 2)
640     var i = 0
641     while (i < length) {
642       value(i * 2) = Hex.hexDigits((bytes(i) & 0xF0) >> 4)
643       value(i * 2 + 1) = Hex.hexDigits(bytes(i) & 0x0F)
644       i += 1
645     }
646     UTF8String.fromBytes(value)
647   }
648
649   def hex(num: Long): UTF8String = {
650     // Extract the hex digits of num into value[] from right to left
651     val value = new Array[Byte](16)
652     var numBuf = num
653     var len = 0
654     do {
655       len += 1
656       value(value.length - len) = Hex.hexDigits((numBuf & 0xF).toInt)
657       numBuf >>>= 4
658     } while (numBuf != 0)
659     UTF8String.fromBytes(java.util.Arrays.copyOfRange(value, value.length - len, value.length))
660   }
661
662   def unhex(bytes: Array[Byte]): Array[Byte] = {
663     val out = new Array[Byte]((bytes.length + 1) >> 1)
664     var i = 0
665     if ((bytes.length & 0x01) != 0) {
666       // padding with '0'
667       if (bytes(0) < 0) {
668         return null
669       }
670       val v = Hex.unhexDigits(bytes(0))
671       if (v == -1) {
672         return null
673       }
674       out(0) = v
675       i += 1
676     }
677     // two characters form the hex value.
678     while (i < bytes.length) {
679       if (bytes(i) < 0 || bytes(i + 1) < 0) {
680         return null
681       }
682       val first = Hex.unhexDigits(bytes(i))
683       val second = Hex.unhexDigits(bytes(i + 1))
684       if (first == -1 || second == -1) {
685         return null
686       }
687       out(i / 2) = (((first << 4) | second) & 0xFF).toByte
688       i += 2
689     }
690     out
691   }
692 }
693
694 /**
695  * If the argument is an INT or binary, hex returns the number as a STRING in hexadecimal format.
696  * Otherwise if the number is a STRING, it converts each character into its hex representation
697  * and returns the resulting STRING. Negative numbers would be treated as two's complement.
698  */
699 @ExpressionDescription(
700   usage = "_FUNC_(expr) - Converts `expr` to hexadecimal.",
701   examples = """
702     Examples:
703       > SELECT _FUNC_(17);
704        11
705       > SELECT _FUNC_('Spark SQL');
706        537061726B2053514C
707   """)
708 case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
709
710   override def inputTypes: Seq[AbstractDataType] =
711     Seq(TypeCollection(LongType, BinaryType, StringType))
712
713   override def dataType: DataType = StringType
714
715   protected override def nullSafeEval(num: Any): Any = child.dataType match {
716     case LongType => Hex.hex(num.asInstanceOf[Long])
717     case BinaryType => Hex.hex(num.asInstanceOf[Array[Byte]])
718     case StringType => Hex.hex(num.asInstanceOf[UTF8String].getBytes)
719   }
720
721   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
722     nullSafeCodeGen(ctx, ev, (c) => {
723       val hex = Hex.getClass.getName.stripSuffix("$")
724       s"${ev.value} = " + (child.dataType match {
725         case StringType => s"""$hex.hex($c.getBytes());"""
726         case _ => s"""$hex.hex($c);"""
727       })
728     })
729   }
730 }
731
732 /**
733  * Performs the inverse operation of HEX.
734  * Resulting characters are returned as a byte array.
735  */
736 @ExpressionDescription(
737   usage = "_FUNC_(expr) - Converts hexadecimal `expr` to binary.",
738   examples = """
739     Examples:
740       > SELECT decode(_FUNC_('537061726B2053514C'), 'UTF-8');
741        Spark SQL
742   """)
743 case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
744
745   override def inputTypes: Seq[AbstractDataType] = Seq(StringType)
746
747   override def nullable: Boolean = true
748   override def dataType: DataType = BinaryType
749
750   protected override def nullSafeEval(num: Any): Any =
751     Hex.unhex(num.asInstanceOf[UTF8String].getBytes)
752
753   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
754     nullSafeCodeGen(ctx, ev, (c) => {
755       val hex = Hex.getClass.getName.stripSuffix("$")
756       s"""
757         ${ev.value} = $hex.unhex($c.getBytes());
758         ${ev.isNull} = ${ev.value} == null;
759        """
760     })
761   }
762 }
763
764
765 ////////////////////////////////////////////////////////////////////////////////////////////////////
766 ////////////////////////////////////////////////////////////////////////////////////////////////////
767 // Binary math functions
768 ////////////////////////////////////////////////////////////////////////////////////////////////////
769 ////////////////////////////////////////////////////////////////////////////////////////////////////
770
771 // scalastyle:off line.size.limit
772 @ExpressionDescription(
773   usage = "_FUNC_(expr1, expr2) - Returns the angle in radians between the positive x-axis of a plane and the point given by the coordinates (`expr1`, `expr2`).",
774   examples = """
775     Examples:
776       > SELECT _FUNC_(0, 0);
777        0.0
778   """)
779 // scalastyle:on line.size.limit
780 case class Atan2(left: Expression, right: Expression)
781   extends BinaryMathExpression(math.atan2, "ATAN2") {
782
783   protected override def nullSafeEval(input1: Any, input2: Any): Any = {
784     // With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0
785     math.atan2(input1.asInstanceOf[Double] + 0.0, input2.asInstanceOf[Double] + 0.0)
786   }
787
788   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
789     defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)")
790   }
791 }
792
793 @ExpressionDescription(
794   usage = "_FUNC_(expr1, expr2) - Raises `expr1` to the power of `expr2`.",
795   examples = """
796     Examples:
797       > SELECT _FUNC_(2, 3);
798        8.0
799   """)
800 case class Pow(left: Expression, right: Expression)
801   extends BinaryMathExpression(math.pow, "POWER") {
802   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
803     defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)")
804   }
805 }
806
807
808 /**
809  * Bitwise left shift.
810  *
811  * @param left the base number to shift.
812  * @param right number of bits to left shift.
813  */
814 @ExpressionDescription(
815   usage = "_FUNC_(base, expr) - Bitwise left shift.",
816   examples = """
817     Examples:
818       > SELECT _FUNC_(2, 1);
819        4
820   """)
821 case class ShiftLeft(left: Expression, right: Expression)
822   extends BinaryExpression with ImplicitCastInputTypes {
823
824   override def inputTypes: Seq[AbstractDataType] =
825     Seq(TypeCollection(IntegerType, LongType), IntegerType)
826
827   override def dataType: DataType = left.dataType
828
829   protected override def nullSafeEval(input1: Any, input2: Any): Any = {
830     input1 match {
831       case l: jl.Long => l << input2.asInstanceOf[jl.Integer]
832       case i: jl.Integer => i << input2.asInstanceOf[jl.Integer]
833     }
834   }
835
836   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
837     defineCodeGen(ctx, ev, (left, right) => s"$left << $right")
838   }
839 }
840
841
842 /**
843  * Bitwise (signed) right shift.
844  *
845  * @param left the base number to shift.
846  * @param right number of bits to right shift.
847  */
848 @ExpressionDescription(
849   usage = "_FUNC_(base, expr) - Bitwise (signed) right shift.",
850   examples = """
851     Examples:
852       > SELECT _FUNC_(4, 1);
853        2
854   """)
855 case class ShiftRight(left: Expression, right: Expression)
856   extends BinaryExpression with ImplicitCastInputTypes {
857
858   override def inputTypes: Seq[AbstractDataType] =
859     Seq(TypeCollection(IntegerType, LongType), IntegerType)
860
861   override def dataType: DataType = left.dataType
862
863   protected override def nullSafeEval(input1: Any, input2: Any): Any = {
864     input1 match {
865       case l: jl.Long => l >> input2.asInstanceOf[jl.Integer]
866       case i: jl.Integer => i >> input2.asInstanceOf[jl.Integer]
867     }
868   }
869
870   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
871     defineCodeGen(ctx, ev, (left, right) => s"$left >> $right")
872   }
873 }
874
875
876 /**
877  * Bitwise unsigned right shift, for integer and long data type.
878  *
879  * @param left the base number.
880  * @param right the number of bits to right shift.
881  */
882 @ExpressionDescription(
883   usage = "_FUNC_(base, expr) - Bitwise unsigned right shift.",
884   examples = """
885     Examples:
886       > SELECT _FUNC_(4, 1);
887        2
888   """)
889 case class ShiftRightUnsigned(left: Expression, right: Expression)
890   extends BinaryExpression with ImplicitCastInputTypes {
891
892   override def inputTypes: Seq[AbstractDataType] =
893     Seq(TypeCollection(IntegerType, LongType), IntegerType)
894
895   override def dataType: DataType = left.dataType
896
897   protected override def nullSafeEval(input1: Any, input2: Any): Any = {
898     input1 match {
899       case l: jl.Long => l >>> input2.asInstanceOf[jl.Integer]
900       case i: jl.Integer => i >>> input2.asInstanceOf[jl.Integer]
901     }
902   }
903
904   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
905     defineCodeGen(ctx, ev, (left, right) => s"$left >>> $right")
906   }
907 }
908
909 @ExpressionDescription(
910   usage = "_FUNC_(expr1, expr2) - Returns sqrt(`expr1`**2 + `expr2`**2).",
911   examples = """
912     Examples:
913       > SELECT _FUNC_(3, 4);
914        5.0
915   """)
916 case class Hypot(left: Expression, right: Expression)
917   extends BinaryMathExpression(math.hypot, "HYPOT")
918
919
920 /**
921  * Computes the logarithm of a number.
922  *
923  * @param left the logarithm base, default to e.
924  * @param right the number to compute the logarithm of.
925  */
926 @ExpressionDescription(
927   usage = "_FUNC_(base, expr) - Returns the logarithm of `expr` with `base`.",
928   examples = """
929     Examples:
930       > SELECT _FUNC_(10, 100);
931        2.0
932   """)
933 case class Logarithm(left: Expression, right: Expression)
934   extends BinaryMathExpression((c1, c2) => math.log(c2) / math.log(c1), "LOG") {
935
936   /**
937    * Natural log, i.e. using e as the base.
938    */
939   def this(child: Expression) = {
940     this(EulerNumber(), child)
941   }
942
943   override def nullable: Boolean = true
944
945   protected override def nullSafeEval(input1: Any, input2: Any): Any = {
946     val dLeft = input1.asInstanceOf[Double]
947     val dRight = input2.asInstanceOf[Double]
948     // Unlike Hive, we support Log base in (0.0, 1.0]
949     if (dLeft <= 0.0 || dRight <= 0.0) null else math.log(dRight) / math.log(dLeft)
950   }
951
952   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
953     if (left.isInstanceOf[EulerNumber]) {
954       nullSafeCodeGen(ctx, ev, (c1, c2) =>
955         s"""
956           if ($c2 <= 0.0) {
957             ${ev.isNull} = true;
958           } else {
959             ${ev.value} = java.lang.Math.log($c2);
960           }
961         """)
962     } else {
963       nullSafeCodeGen(ctx, ev, (c1, c2) =>
964         s"""
965           if ($c1 <= 0.0 || $c2 <= 0.0) {
966             ${ev.isNull} = true;
967           } else {
968             ${ev.value} = java.lang.Math.log($c2) / java.lang.Math.log($c1);
969           }
970         """)
971     }
972   }
973 }
974
975 /**
976  * Round the `child`'s result to `scale` decimal place when `scale` >= 0
977  * or round at integral part when `scale` < 0.
978  *
979  * Child of IntegralType would round to itself when `scale` >= 0.
980  * Child of FractionalType whose value is NaN or Infinite would always round to itself.
981  *
982  * Round's dataType would always equal to `child`'s dataType except for DecimalType,
983  * which would lead scale decrease from the origin DecimalType.
984  *
985  * @param child expr to be round, all [[NumericType]] is allowed as Input
986  * @param scale new scale to be round to, this should be a constant int at runtime
987  * @param mode rounding mode (e.g. HALF_UP, HALF_EVEN)
988  * @param modeStr rounding mode string name (e.g. "ROUND_HALF_UP", "ROUND_HALF_EVEN")
989  */
990 abstract class RoundBase(child: Expression, scale: Expression,
991     mode: BigDecimal.RoundingMode.Value, modeStr: String)
992   extends BinaryExpression with Serializable with ImplicitCastInputTypes {
993
994   override def left: Expression = child
995   override def right: Expression = scale
996
997   // round of Decimal would eval to null if it fails to `changePrecision`
998   override def nullable: Boolean = true
999
1000   override def foldable: Boolean = child.foldable
1001
1002   override lazy val dataType: DataType = child.dataType match {
1003     // if the new scale is bigger which means we are scaling up,
1004     // keep the original scale as `Decimal` does
1005     case DecimalType.Fixed(p, s) => DecimalType(p, if (_scale > s) s else _scale)
1006     case t => t
1007   }
1008
1009   override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType)
1010
1011   override def checkInputDataTypes(): TypeCheckResult = {
1012     super.checkInputDataTypes() match {
1013       case TypeCheckSuccess =>
1014         if (scale.foldable) {
1015           TypeCheckSuccess
1016         } else {
1017           TypeCheckFailure("Only foldable Expression is allowed for scale arguments")
1018         }
1019       case f => f
1020     }
1021   }
1022
1023   // Avoid repeated evaluation since `scale` is a constant int,
1024   // avoid unnecessary `child` evaluation in both codegen and non-codegen eval
1025   // by checking if scaleV == null as well.
1026   private lazy val scaleV: Any = scale.eval(EmptyRow)
1027   private lazy val _scale: Int = scaleV.asInstanceOf[Int]
1028
1029   override def eval(input: InternalRow): Any = {
1030     if (scaleV == null) { // if scale is null, no need to eval its child at all
1031       null
1032     } else {
1033       val evalE = child.eval(input)
1034       if (evalE == null) {
1035         null
1036       } else {
1037         nullSafeEval(evalE)
1038       }
1039     }
1040   }
1041
1042   // not overriding since _scale is a constant int at runtime
1043   def nullSafeEval(input1: Any): Any = {
1044     dataType match {
1045       case DecimalType.Fixed(_, s) =>
1046         val decimal = input1.asInstanceOf[Decimal]
1047         decimal.toPrecision(decimal.precision, s, mode).orNull
1048       case ByteType =>
1049         BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte
1050       case ShortType =>
1051         BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, mode).toShort
1052       case IntegerType =>
1053         BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, mode).toInt
1054       case LongType =>
1055         BigDecimal(input1.asInstanceOf[Long]).setScale(_scale, mode).toLong
1056       case FloatType =>
1057         val f = input1.asInstanceOf[Float]
1058         if (f.isNaN || f.isInfinite) {
1059           f
1060         } else {
1061           BigDecimal(f.toDouble).setScale(_scale, mode).toFloat
1062         }
1063       case DoubleType =>
1064         val d = input1.asInstanceOf[Double]
1065         if (d.isNaN || d.isInfinite) {
1066           d
1067         } else {
1068           BigDecimal(d).setScale(_scale, mode).toDouble
1069         }
1070     }
1071   }
1072
1073   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
1074     val ce = child.genCode(ctx)
1075
1076     val evaluationCode = dataType match {
1077       case DecimalType.Fixed(_, s) =>
1078         s"""
1079         if (${ce.value}.changePrecision(${ce.value}.precision(), ${s},
1080             java.math.BigDecimal.${modeStr})) {
1081           ${ev.value} = ${ce.value};
1082         } else {
1083           ${ev.isNull} = true;
1084         }"""
1085       case ByteType =>
1086         if (_scale < 0) {
1087           s"""
1088           ${ev.value} = new java.math.BigDecimal(${ce.value}).
1089             setScale(${_scale}, java.math.BigDecimal.${modeStr}).byteValue();"""
1090         } else {
1091           s"${ev.value} = ${ce.value};"
1092         }
1093       case ShortType =>
1094         if (_scale < 0) {
1095           s"""
1096           ${ev.value} = new java.math.BigDecimal(${ce.value}).
1097             setScale(${_scale}, java.math.BigDecimal.${modeStr}).shortValue();"""
1098         } else {
1099           s"${ev.value} = ${ce.value};"
1100         }
1101       case IntegerType =>
1102         if (_scale < 0) {
1103           s"""
1104           ${ev.value} = new java.math.BigDecimal(${ce.value}).
1105             setScale(${_scale}, java.math.BigDecimal.${modeStr}).intValue();"""
1106         } else {
1107           s"${ev.value} = ${ce.value};"
1108         }
1109       case LongType =>
1110         if (_scale < 0) {
1111           s"""
1112           ${ev.value} = new java.math.BigDecimal(${ce.value}).
1113             setScale(${_scale}, java.math.BigDecimal.${modeStr}).longValue();"""
1114         } else {
1115           s"${ev.value} = ${ce.value};"
1116         }
1117       case FloatType => // if child eval to NaN or Infinity, just return it.
1118         s"""
1119           if (Float.isNaN(${ce.value}) || Float.isInfinite(${ce.value})) {
1120             ${ev.value} = ${ce.value};
1121           } else {
1122             ${ev.value} = java.math.BigDecimal.valueOf(${ce.value}).
1123               setScale(${_scale}, java.math.BigDecimal.${modeStr}).floatValue();
1124           }"""
1125       case DoubleType => // if child eval to NaN or Infinity, just return it.
1126         s"""
1127           if (Double.isNaN(${ce.value}) || Double.isInfinite(${ce.value})) {
1128             ${ev.value} = ${ce.value};
1129           } else {
1130             ${ev.value} = java.math.BigDecimal.valueOf(${ce.value}).
1131               setScale(${_scale}, java.math.BigDecimal.${modeStr}).doubleValue();
1132           }"""
1133     }
1134
1135     if (scaleV == null) { // if scale is null, no need to eval its child at all
1136       ev.copy(code = s"""
1137         boolean ${ev.isNull} = true;
1138         ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""")
1139     } else {
1140       ev.copy(code = s"""
1141         ${ce.code}
1142         boolean ${ev.isNull} = ${ce.isNull};
1143         ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
1144         if (!${ev.isNull}) {
1145           $evaluationCode
1146         }""")
1147     }
1148   }
1149 }
1150
1151 /**
1152  * Round an expression to d decimal places using HALF_UP rounding mode.
1153  * round(2.5) == 3.0, round(3.5) == 4.0.
1154  */
1155 // scalastyle:off line.size.limit
1156 @ExpressionDescription(
1157   usage = "_FUNC_(expr, d) - Returns `expr` rounded to `d` decimal places using HALF_UP rounding mode.",
1158   examples = """
1159     Examples:
1160       > SELECT _FUNC_(2.5, 0);
1161        3.0
1162   """)
1163 // scalastyle:on line.size.limit
1164 case class Round(child: Expression, scale: Expression)
1165   extends RoundBase(child, scale, BigDecimal.RoundingMode.HALF_UP, "ROUND_HALF_UP")
1166     with Serializable with ImplicitCastInputTypes {
1167   def this(child: Expression) = this(child, Literal(0))
1168 }
1169
1170 /**
1171  * Round an expression to d decimal places using HALF_EVEN rounding mode,
1172  * also known as Gaussian rounding or bankers' rounding.
1173  * round(2.5) = 2.0, round(3.5) = 4.0.
1174  */
1175 // scalastyle:off line.size.limit
1176 @ExpressionDescription(
1177   usage = "_FUNC_(expr, d) - Returns `expr` rounded to `d` decimal places using HALF_EVEN rounding mode.",
1178   examples = """
1179     Examples:
1180       > SELECT _FUNC_(2.5, 0);
1181        2.0
1182   """)
1183 // scalastyle:on line.size.limit
1184 case class BRound(child: Expression, scale: Expression)
1185   extends RoundBase(child, scale, BigDecimal.RoundingMode.HALF_EVEN, "ROUND_HALF_EVEN")
1186     with Serializable with ImplicitCastInputTypes {
1187   def this(child: Expression) = this(child, Literal(0))
1188 }