[SPARK-21485][FOLLOWUP][SQL][DOCS] Describes examples and arguments separately, and...
[spark.git] / sql / catalyst / src / main / scala / org / apache / spark / sql / catalyst / expressions / randomExpressions.scala
1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one or more
3  * contributor license agreements.  See the NOTICE file distributed with
4  * this work for additional information regarding copyright ownership.
5  * The ASF licenses this file to You under the Apache License, Version 2.0
6  * (the "License"); you may not use this file except in compliance with
7  * the License.  You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 package org.apache.spark.sql.catalyst.expressions
19
20 import org.apache.spark.sql.AnalysisException
21 import org.apache.spark.sql.catalyst.InternalRow
22 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
23 import org.apache.spark.sql.types._
24 import org.apache.spark.util.Utils
25 import org.apache.spark.util.random.XORShiftRandom
26
27 /**
28  * A Random distribution generating expression.
29  * TODO: This can be made generic to generate any type of random distribution, or any type of
30  * StructType.
31  *
32  * Since this expression is stateful, it cannot be a case object.
33  */
34 abstract class RDG extends UnaryExpression with ExpectsInputTypes with Nondeterministic {
35   /**
36    * Record ID within each partition. By being transient, the Random Number Generator is
37    * reset every time we serialize and deserialize and initialize it.
38    */
39   @transient protected var rng: XORShiftRandom = _
40
41   override protected def initializeInternal(partitionIndex: Int): Unit = {
42     rng = new XORShiftRandom(seed + partitionIndex)
43   }
44
45   @transient protected lazy val seed: Long = child match {
46     case Literal(s, IntegerType) => s.asInstanceOf[Int]
47     case Literal(s, LongType) => s.asInstanceOf[Long]
48     case _ => throw new AnalysisException(
49       s"Input argument to $prettyName must be an integer, long or null literal.")
50   }
51
52   override def nullable: Boolean = false
53
54   override def dataType: DataType = DoubleType
55
56   override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType))
57 }
58
59 /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
60 // scalastyle:off line.size.limit
61 @ExpressionDescription(
62   usage = "_FUNC_([seed]) - Returns a random value with independent and identically distributed (i.i.d.) uniformly distributed values in [0, 1).",
63   examples = """
64     Examples:
65       > SELECT _FUNC_();
66        0.9629742951434543
67       > SELECT _FUNC_(0);
68        0.8446490682263027
69       > SELECT _FUNC_(null);
70        0.8446490682263027
71   """)
72 // scalastyle:on line.size.limit
73 case class Rand(child: Expression) extends RDG {
74
75   def this() = this(Literal(Utils.random.nextLong(), LongType))
76
77   override protected def evalInternal(input: InternalRow): Double = rng.nextDouble()
78
79   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
80     val rngTerm = ctx.freshName("rng")
81     val className = classOf[XORShiftRandom].getName
82     ctx.addMutableState(className, rngTerm, "")
83     ctx.addPartitionInitializationStatement(
84       s"$rngTerm = new $className(${seed}L + partitionIndex);")
85     ev.copy(code = s"""
86       final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = "false")
87   }
88 }
89
90 object Rand {
91   def apply(seed: Long): Rand = Rand(Literal(seed, LongType))
92 }
93
94 /** Generate a random column with i.i.d. values drawn from the standard normal distribution. */
95 // scalastyle:off line.size.limit
96 @ExpressionDescription(
97   usage = "_FUNC_([seed]) - Returns a random value with independent and identically distributed (i.i.d.) values drawn from the standard normal distribution.",
98   examples = """
99     Examples:
100       > SELECT _FUNC_();
101        -0.3254147983080288
102       > SELECT _FUNC_(0);
103        1.1164209726833079
104       > SELECT _FUNC_(null);
105        1.1164209726833079
106   """)
107 // scalastyle:on line.size.limit
108 case class Randn(child: Expression) extends RDG {
109
110   def this() = this(Literal(Utils.random.nextLong(), LongType))
111
112   override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()
113
114   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
115     val rngTerm = ctx.freshName("rng")
116     val className = classOf[XORShiftRandom].getName
117     ctx.addMutableState(className, rngTerm, "")
118     ctx.addPartitionInitializationStatement(
119       s"$rngTerm = new $className(${seed}L + partitionIndex);")
120     ev.copy(code = s"""
121       final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false")
122   }
123 }
124
125 object Randn {
126   def apply(seed: Long): Randn = Randn(Literal(seed, LongType))
127 }