[SPARK-21485][FOLLOWUP][SQL][DOCS] Describes examples and arguments separately, and...
[spark.git] / sql / catalyst / src / main / scala / org / apache / spark / sql / catalyst / expressions / conditionalExpressions.scala
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one or more
3  * contributor license agreements.  See the NOTICE file distributed with
4  * this work for additional information regarding copyright ownership.
5  * The ASF licenses this file to You under the Apache License, Version 2.0
6  * (the "License"); you may not use this file except in compliance with
7  * the License.  You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 package org.apache.spark.sql.catalyst.expressions
19
20 import org.apache.spark.sql.catalyst.InternalRow
21 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
22 import org.apache.spark.sql.catalyst.expressions.codegen._
23 import org.apache.spark.sql.types._
24
25 // scalastyle:off line.size.limit
26 @ExpressionDescription(
27   usage = "_FUNC_(expr1, expr2, expr3) - If `expr1` evaluates to true, then returns `expr2`; otherwise returns `expr3`.",
28   examples = """
29     Examples:
30       > SELECT _FUNC_(1 < 2, 'a', 'b');
31        a
32   """)
33 // scalastyle:on line.size.limit
34 case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
35   extends Expression {
36
37   override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil
38   override def nullable: Boolean = trueValue.nullable || falseValue.nullable
39
40   override def checkInputDataTypes(): TypeCheckResult = {
41     if (predicate.dataType != BooleanType) {
42       TypeCheckResult.TypeCheckFailure(
43         s"type of predicate expression in If should be boolean, not ${predicate.dataType}")
44     } else if (!trueValue.dataType.sameType(falseValue.dataType)) {
45       TypeCheckResult.TypeCheckFailure(s"differing types in '$sql' " +
46         s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).")
47     } else {
48       TypeCheckResult.TypeCheckSuccess
49     }
50   }
51
52   override def dataType: DataType = trueValue.dataType
53
54   override def eval(input: InternalRow): Any = {
55     if (java.lang.Boolean.TRUE.equals(predicate.eval(input))) {
56       trueValue.eval(input)
57     } else {
58       falseValue.eval(input)
59     }
60   }
61
62   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
63     val condEval = predicate.genCode(ctx)
64     val trueEval = trueValue.genCode(ctx)
65     val falseEval = falseValue.genCode(ctx)
66
67     // place generated code of condition, true value and false value in separate methods if
68     // their code combined is large
69     val combinedLength = condEval.code.length + trueEval.code.length + falseEval.code.length
70     val generatedCode = if (combinedLength > 1024 &&
71       // Split these expressions only if they are created from a row object
72       (ctx.INPUT_ROW != null && ctx.currentVars == null)) {
73
74       val (condFuncName, condGlobalIsNull, condGlobalValue) =
75         createAndAddFunction(ctx, condEval, predicate.dataType, "evalIfCondExpr")
76       val (trueFuncName, trueGlobalIsNull, trueGlobalValue) =
77         createAndAddFunction(ctx, trueEval, trueValue.dataType, "evalIfTrueExpr")
78       val (falseFuncName, falseGlobalIsNull, falseGlobalValue) =
79         createAndAddFunction(ctx, falseEval, falseValue.dataType, "evalIfFalseExpr")
80       s"""
81         $condFuncName(${ctx.INPUT_ROW});
82         boolean ${ev.isNull} = false;
83         ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
84         if (!$condGlobalIsNull && $condGlobalValue) {
85           $trueFuncName(${ctx.INPUT_ROW});
86           ${ev.isNull} = $trueGlobalIsNull;
87           ${ev.value} = $trueGlobalValue;
88         } else {
89           $falseFuncName(${ctx.INPUT_ROW});
90           ${ev.isNull} = $falseGlobalIsNull;
91           ${ev.value} = $falseGlobalValue;
92         }
93       """
94     }
95     else {
96       s"""
97         ${condEval.code}
98         boolean ${ev.isNull} = false;
99         ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
100         if (!${condEval.isNull} && ${condEval.value}) {
101           ${trueEval.code}
102           ${ev.isNull} = ${trueEval.isNull};
103           ${ev.value} = ${trueEval.value};
104         } else {
105           ${falseEval.code}
106           ${ev.isNull} = ${falseEval.isNull};
107           ${ev.value} = ${falseEval.value};
108         }
109       """
110     }
111
112     ev.copy(code = generatedCode)
113   }
114
115   private def createAndAddFunction(
116       ctx: CodegenContext,
117       ev: ExprCode,
118       dataType: DataType,
119       baseFuncName: String): (String, String, String) = {
120     val globalIsNull = ctx.freshName("isNull")
121     ctx.addMutableState("boolean", globalIsNull, s"$globalIsNull = false;")
122     val globalValue = ctx.freshName("value")
123     ctx.addMutableState(ctx.javaType(dataType), globalValue,
124       s"$globalValue = ${ctx.defaultValue(dataType)};")
125     val funcName = ctx.freshName(baseFuncName)
126     val funcBody =
127       s"""
128          |private void $funcName(InternalRow ${ctx.INPUT_ROW}) {
129          |  ${ev.code.trim}
130          |  $globalIsNull = ${ev.isNull};
131          |  $globalValue = ${ev.value};
132          |}
133          """.stripMargin
134     val fullFuncName = ctx.addNewFunction(funcName, funcBody)
135     (fullFuncName, globalIsNull, globalValue)
136   }
137
138   override def toString: String = s"if ($predicate) $trueValue else $falseValue"
139
140   override def sql: String = s"(IF(${predicate.sql}, ${trueValue.sql}, ${falseValue.sql}))"
141 }
142
143 /**
144  * Abstract parent class for common logic in CaseWhen and CaseWhenCodegen.
145  *
146  * @param branches seq of (branch condition, branch value)
147  * @param elseValue optional value for the else branch
148  */
149 abstract class CaseWhenBase(
150     branches: Seq[(Expression, Expression)],
151     elseValue: Option[Expression])
152   extends Expression with Serializable {
153
154   override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue
155
156   // both then and else expressions should be considered.
157   def valueTypes: Seq[DataType] = branches.map(_._2.dataType) ++ elseValue.map(_.dataType)
158
159   def valueTypesEqual: Boolean = valueTypes.size <= 1 || valueTypes.sliding(2, 1).forall {
160     case Seq(dt1, dt2) => dt1.sameType(dt2)
161   }
162
163   override def dataType: DataType = branches.head._2.dataType
164
165   override def nullable: Boolean = {
166     // Result is nullable if any of the branch is nullable, or if the else value is nullable
167     branches.exists(_._2.nullable) || elseValue.map(_.nullable).getOrElse(true)
168   }
169
170   override def checkInputDataTypes(): TypeCheckResult = {
171     // Make sure all branch conditions are boolean types.
172     if (valueTypesEqual) {
173       if (branches.forall(_._1.dataType == BooleanType)) {
174         TypeCheckResult.TypeCheckSuccess
175       } else {
176         val index = branches.indexWhere(_._1.dataType != BooleanType)
177         TypeCheckResult.TypeCheckFailure(
178           s"WHEN expressions in CaseWhen should all be boolean type, " +
179             s"but the ${index + 1}th when expression's type is ${branches(index)._1}")
180       }
181     } else {
182       TypeCheckResult.TypeCheckFailure(
183         "THEN and ELSE expressions should all be same type or coercible to a common type")
184     }
185   }
186
187   override def eval(input: InternalRow): Any = {
188     var i = 0
189     val size = branches.size
190     while (i < size) {
191       if (java.lang.Boolean.TRUE.equals(branches(i)._1.eval(input))) {
192         return branches(i)._2.eval(input)
193       }
194       i += 1
195     }
196     if (elseValue.isDefined) {
197       return elseValue.get.eval(input)
198     } else {
199       return null
200     }
201   }
202
203   override def toString: String = {
204     val cases = branches.map { case (c, v) => s" WHEN $c THEN $v" }.mkString
205     val elseCase = elseValue.map(" ELSE " + _).getOrElse("")
206     "CASE" + cases + elseCase + " END"
207   }
208
209   override def sql: String = {
210     val cases = branches.map { case (c, v) => s" WHEN ${c.sql} THEN ${v.sql}" }.mkString
211     val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("")
212     "CASE" + cases + elseCase + " END"
213   }
214 }
215
216
217 /**
218  * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
219  * When a = true, returns b; when c = true, returns d; else returns e.
220  *
221  * @param branches seq of (branch condition, branch value)
222  * @param elseValue optional value for the else branch
223  */
224 // scalastyle:off line.size.limit
225 @ExpressionDescription(
226   usage = "CASE WHEN expr1 THEN expr2 [WHEN expr3 THEN expr4]* [ELSE expr5] END - When `expr1` = true, returns `expr2`; when `expr3` = true, return `expr4`; else return `expr5`.")
227 // scalastyle:on line.size.limit
228 case class CaseWhen(
229     val branches: Seq[(Expression, Expression)],
230     val elseValue: Option[Expression] = None)
231   extends CaseWhenBase(branches, elseValue) with CodegenFallback with Serializable {
232
233   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
234     super[CodegenFallback].doGenCode(ctx, ev)
235   }
236
237   def toCodegen(): CaseWhenCodegen = {
238     CaseWhenCodegen(branches, elseValue)
239   }
240 }
241
242 /**
243  * CaseWhen expression used when code generation condition is satisfied.
244  * OptimizeCodegen optimizer replaces CaseWhen into CaseWhenCodegen.
245  *
246  * @param branches seq of (branch condition, branch value)
247  * @param elseValue optional value for the else branch
248  */
249 case class CaseWhenCodegen(
250     val branches: Seq[(Expression, Expression)],
251     val elseValue: Option[Expression] = None)
252   extends CaseWhenBase(branches, elseValue) with Serializable {
253
254   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
255     // Generate code that looks like:
256     //
257     // condA = ...
258     // if (condA) {
259     //   valueA
260     // } else {
261     //   condB = ...
262     //   if (condB) {
263     //     valueB
264     //   } else {
265     //     condC = ...
266     //     if (condC) {
267     //       valueC
268     //     } else {
269     //       elseValue
270     //     }
271     //   }
272     // }
273     val cases = branches.map { case (condExpr, valueExpr) =>
274       val cond = condExpr.genCode(ctx)
275       val res = valueExpr.genCode(ctx)
276       s"""
277         ${cond.code}
278         if (!${cond.isNull} && ${cond.value}) {
279           ${res.code}
280           ${ev.isNull} = ${res.isNull};
281           ${ev.value} = ${res.value};
282         }
283       """
284     }
285
286     var generatedCode = cases.mkString("", "\nelse {\n", "\nelse {\n")
287
288     elseValue.foreach { elseExpr =>
289       val res = elseExpr.genCode(ctx)
290       generatedCode +=
291         s"""
292           ${res.code}
293           ${ev.isNull} = ${res.isNull};
294           ${ev.value} = ${res.value};
295         """
296     }
297
298     generatedCode += "}\n" * cases.size
299
300     ev.copy(code = s"""
301       boolean ${ev.isNull} = true;
302       ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
303       $generatedCode""")
304   }
305 }
306
307 /** Factory methods for CaseWhen. */
308 object CaseWhen {
309   def apply(branches: Seq[(Expression, Expression)], elseValue: Expression): CaseWhen = {
310     CaseWhen(branches, Option(elseValue))
311   }
312
313   /**
314    * A factory method to facilitate the creation of this expression when used in parsers.
315    *
316    * @param branches Expressions at even position are the branch conditions, and expressions at odd
317    *                 position are branch values.
318    */
319   def createFromParser(branches: Seq[Expression]): CaseWhen = {
320     val cases = branches.grouped(2).flatMap {
321       case cond :: value :: Nil => Some((cond, value))
322       case value :: Nil => None
323     }.toArray.toSeq  // force materialization to make the seq serializable
324     val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None
325     CaseWhen(cases, elseValue)
326   }
327 }
328
329 /**
330  * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END".
331  * When a = b, returns c; when a = d, returns e; else returns f.
332  */
333 object CaseKeyWhen {
334   def apply(key: Expression, branches: Seq[Expression]): CaseWhen = {
335     val cases = branches.grouped(2).flatMap {
336       case Seq(cond, value) => Some((EqualTo(key, cond), value))
337       case Seq(value) => None
338     }.toArray.toSeq  // force materialization to make the seq serializable
339     val elseValue = if (branches.size % 2 == 1) Some(branches.last) else None
340     CaseWhen(cases, elseValue)
341   }
342 }