HIVE-18359: Extend grouping set limits from int to long (Prasanth Jayachandran review...
[hive.git] / ql / src / java / org / apache / hadoop / hive / ql / optimizer / calcite / rules / HiveExpandDistinctAggregatesRule.java
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to you under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17 package org.apache.hadoop.hive.ql.optimizer.calcite.rules;
18
19 import java.math.BigDecimal;
20 import java.util.ArrayList;
21 import java.util.Collections;
22 import java.util.HashMap;
23 import java.util.HashSet;
24 import java.util.LinkedHashSet;
25 import java.util.List;
26 import java.util.Map;
27 import java.util.Set;
28
29 import org.apache.calcite.plan.RelOptCluster;
30 import org.apache.calcite.plan.RelOptRule;
31 import org.apache.calcite.plan.RelOptRuleCall;
32 import org.apache.calcite.rel.RelNode;
33 import org.apache.calcite.rel.core.Aggregate;
34 import org.apache.calcite.rel.core.AggregateCall;
35 import org.apache.calcite.rel.core.RelFactories;
36 import org.apache.calcite.rel.metadata.RelColumnOrigin;
37 import org.apache.calcite.rel.metadata.RelMetadataQuery;
38 import org.apache.calcite.rel.type.RelDataType;
39 import org.apache.calcite.rel.type.RelDataTypeField;
40 import org.apache.calcite.rex.RexBuilder;
41 import org.apache.calcite.rex.RexInputRef;
42 import org.apache.calcite.rex.RexNode;
43 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
44 import org.apache.calcite.sql.type.SqlTypeName;
45 import org.apache.calcite.util.ImmutableBitSet;
46 import org.apache.calcite.util.Pair;
47 import org.apache.calcite.util.Util;
48 import org.apache.hadoop.hive.ql.optimizer.calcite.CalciteSemanticException;
49 import org.apache.hadoop.hive.ql.optimizer.calcite.HiveCalciteUtil;
50 import org.apache.hadoop.hive.ql.optimizer.calcite.HiveRelFactories;
51 import org.apache.hadoop.hive.ql.optimizer.calcite.RelOptHiveTable;
52 import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveAggregate;
53 import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveGroupingID;
54 import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveProject;
55 import org.apache.hadoop.hive.ql.optimizer.calcite.reloperators.HiveRelNode;
56 import org.apache.hadoop.hive.ql.optimizer.calcite.translator.TypeConverter;
57 import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
58 import org.slf4j.Logger;
59 import org.slf4j.LoggerFactory;
60
61 import com.google.common.base.Function;
62 import com.google.common.collect.ImmutableList;
63 import com.google.common.collect.Lists;
64 import com.google.common.math.IntMath;
65
66 /**
67 * Planner rule that expands distinct aggregates
68 * (such as {@code COUNT(DISTINCT x)}) from a
69 * {@link org.apache.calcite.rel.core.Aggregate}.
70 *
71 * <p>How this is done depends upon the arguments to the function. If all
72 * functions have the same argument
73 * (e.g. {@code COUNT(DISTINCT x), SUM(DISTINCT x)} both have the argument
74 * {@code x}) then one extra {@link org.apache.calcite.rel.core.Aggregate} is
75 * sufficient.
76 *
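 * <p>If there are multiple arguments
 * (e.g. {@code COUNT(DISTINCT x), COUNT(DISTINCT y)}), every aggregate call
 * is a {@code COUNT(DISTINCT ...)} and there is no GROUP BY key, the rule
 * rewrites the query as a {@code COUNT} over an {@code Aggregate} with
 * grouping sets. Unlike the Calcite rule it is derived from, this
 * stripped-down version never produces a
 * {@link org.apache.calcite.rel.core.Join}.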
81 */
82
83 // Stripped down version of org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule
84 // This is adapted for Hive, but should eventually be deleted from Hive and make use of above.
85
86 public final class HiveExpandDistinctAggregatesRule extends RelOptRule {
//~ Static fields/initializers ---------------------------------------------

/**
 * Shared instance of the rule: matches {@link HiveAggregate} and is
 * constructed with {@link HiveRelFactories#HIVE_PROJECT_FACTORY}.
 */
public static final HiveExpandDistinctAggregatesRule INSTANCE =
    new HiveExpandDistinctAggregatesRule(
        HiveAggregate.class, HiveRelFactories.HIVE_PROJECT_FACTORY);

/**
 * Project factory read by the static helper createSelectDistinct. The field
 * is static but is overwritten by every constructor call, so the most
 * recently constructed rule decides its value.
 */
private static RelFactories.ProjectFactory projFactory;

protected static final Logger LOG =
    LoggerFactory.getLogger(HiveExpandDistinctAggregatesRule.class);

//~ Instance fields --------------------------------------------------------

// Both fields are assigned by onMatch right before the grouping-sets rewrite
// and read by the helpers it calls.
// NOTE(review): they are mutable state on a rule that is shared through
// INSTANCE - confirm that a rule instance is never fired concurrently.
RelOptCluster cluster;
RexBuilder rexBuilder;

//~ Constructors -----------------------------------------------------------

/**
 * Creates a rule that fires on any relational expression of the given
 * aggregate class.
 *
 * @param clazz aggregate class to match
 * @param projectFactory factory stored in the static {@code projFactory} field
 */
public HiveExpandDistinctAggregatesRule(
    Class<? extends Aggregate> clazz,
    RelFactories.ProjectFactory projectFactory) {
  super(operand(clazz, any()));
  projFactory = projectFactory;
}
108
109 //~ Methods ----------------------------------------------------------------
110
111 @Override
112 public void onMatch(RelOptRuleCall call) {
113 final Aggregate aggregate = call.rel(0);
114 int numCountDistinct = getNumCountDistinctCall(aggregate);
115 if (numCountDistinct == 0) {
116 return;
117 }
118
119 // Find all of the agg expressions. We use a List (for all count(distinct))
120 // as well as a Set (for all others) to ensure determinism.
121 int nonDistinctCount = 0;
122 List<List<Integer>> argListList = new ArrayList<List<Integer>>();
123 Set<List<Integer>> argListSets = new LinkedHashSet<List<Integer>>();
124 Set<Integer> positions = new HashSet<>();
125 for (AggregateCall aggCall : aggregate.getAggCallList()) {
126 if (!aggCall.isDistinct()) {
127 ++nonDistinctCount;
128 continue;
129 }
130 ArrayList<Integer> argList = new ArrayList<Integer>();
131 for (Integer arg : aggCall.getArgList()) {
132 argList.add(arg);
133 positions.add(arg);
134 }
135 // Aggr checks for sorted argList.
136 argListList.add(argList);
137 argListSets.add(argList);
138 }
139 Util.permAssert(argListSets.size() > 0, "containsDistinctCall lied");
140
141 if (numCountDistinct > 1 && numCountDistinct == aggregate.getAggCallList().size()
142 && aggregate.getGroupSet().isEmpty()) {
143 LOG.debug("Trigger countDistinct rewrite. numCountDistinct is " + numCountDistinct);
144 // now positions contains all the distinct positions, i.e., $5, $4, $6
145 // we need to first sort them as group by set
146 // and then get their position later, i.e., $4->1, $5->2, $6->3
147 cluster = aggregate.getCluster();
148 rexBuilder = cluster.getRexBuilder();
149 RelNode converted = null;
150 List<Integer> sourceOfForCountDistinct = new ArrayList<>();
151 sourceOfForCountDistinct.addAll(positions);
152 Collections.sort(sourceOfForCountDistinct);
153 try {
154 converted = convert(aggregate, argListList, sourceOfForCountDistinct);
155 } catch (CalciteSemanticException e) {
156 LOG.debug(e.toString());
157 throw new RuntimeException(e);
158 }
159 call.transformTo(converted);
160 return;
161 }
162
163 // If all of the agg expressions are distinct and have the same
164 // arguments then we can use a more efficient form.
165 final RelMetadataQuery mq = call.getMetadataQuery();
166 if ((nonDistinctCount == 0) && (argListSets.size() == 1)) {
167 for (Integer arg : argListSets.iterator().next()) {
168 Set<RelColumnOrigin> colOrigs = mq.getColumnOrigins(aggregate, arg);
169 if (null != colOrigs) {
170 for (RelColumnOrigin colOrig : colOrigs) {
171 RelOptHiveTable hiveTbl = (RelOptHiveTable)colOrig.getOriginTable();
172 if(hiveTbl.getPartColInfoMap().containsKey(colOrig.getOriginColumnOrdinal())) {
173 // Encountered partitioning column, this will be better handled by MetadataOnly optimizer.
174 return;
175 }
176 }
177 }
178 }
179 RelNode converted =
180 convertMonopole(
181 aggregate,
182 argListSets.iterator().next());
183 call.transformTo(converted);
184 return;
185 }
186 }
187
188 /**
189 * Converts an aggregate relational expression that contains only
190 * count(distinct) to grouping sets with count. For example select
191 * count(distinct department_id), count(distinct gender), count(distinct
192 * education_level) from employee; can be transformed to
193 * select
194 * count(case when i=1 and department_id is not null then 1 else null end) as c0,
195 * count(case when i=2 and gender is not null then 1 else null end) as c1,
196 * count(case when i=4 and education_level is not null then 1 else null end) as c2
197 * from (select
198 * grouping__id as i, department_id, gender, education_level from employee
199 * group by department_id, gender, education_level grouping sets
200 * (department_id, gender, education_level))subq;
201 * @throws CalciteSemanticException
202 */
203 private RelNode convert(Aggregate aggregate, List<List<Integer>> argList, List<Integer> sourceOfForCountDistinct) throws CalciteSemanticException {
204 // we use this map to map the position of argList to the position of grouping set
205 Map<Integer, Integer> map = new HashMap<>();
206 List<List<Integer>> cleanArgList = new ArrayList<>();
207 final Aggregate groupingSets = createGroupingSets(aggregate, argList, cleanArgList, map, sourceOfForCountDistinct);
208 return createCount(groupingSets, argList, cleanArgList, map, sourceOfForCountDistinct);
209 }
210
211 private int getGroupingIdValue(List<Integer> list, List<Integer> sourceOfForCountDistinct,
212 int groupCount) {
213 int ind = IntMath.pow(2, groupCount) - 1;
214 for (int i : list) {
215 ind &= ~(1 << groupCount - sourceOfForCountDistinct.indexOf(i) - 1);
216 }
217 return ind;
218 }
219
220 /**
221 * @param aggr: the original aggregate
222 * @param argList: the original argList in aggregate
223 * @param cleanArgList: the new argList without duplicates
224 * @param map: the mapping from the original argList to the new argList
225 * @param sourceOfForCountDistinct: the sorted positions of groupset
226 * @return
227 * @throws CalciteSemanticException
228 */
229 private RelNode createCount(Aggregate aggr, List<List<Integer>> argList,
230 List<List<Integer>> cleanArgList, Map<Integer, Integer> map,
231 List<Integer> sourceOfForCountDistinct) throws CalciteSemanticException {
232 List<RexNode> originalInputRefs = Lists.transform(aggr.getRowType().getFieldList(),
233 new Function<RelDataTypeField, RexNode>() {
234 @Override
235 public RexNode apply(RelDataTypeField input) {
236 return new RexInputRef(input.getIndex(), input.getType());
237 }
238 });
239 final List<RexNode> gbChildProjLst = Lists.newArrayList();
240 // for singular arg, count should not include null
241 // e.g., count(case when i=1 and department_id is not null then 1 else null end) as c0,
242 // for non-singular args, count can include null, i.e. (,) is counted as 1
243 for (List<Integer> list : cleanArgList) {
244 RexNode condition = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, originalInputRefs
245 .get(originalInputRefs.size() - 1), rexBuilder.makeExactLiteral(new BigDecimal(
246 getGroupingIdValue(list, sourceOfForCountDistinct, aggr.getGroupCount()))));
247 if (list.size() == 1) {
248 int pos = list.get(0);
249 RexNode notNull = rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL,
250 originalInputRefs.get(pos));
251 condition = rexBuilder.makeCall(SqlStdOperatorTable.AND, condition, notNull);
252 }
253 RexNode when = rexBuilder.makeCall(SqlStdOperatorTable.CASE, condition,
254 rexBuilder.makeExactLiteral(BigDecimal.ONE), rexBuilder.constantNull());
255 gbChildProjLst.add(when);
256 }
257
258 // create the project before GB
259 RelNode gbInputRel = HiveProject.create(aggr, gbChildProjLst, null);
260
261 // create the aggregate
262 List<AggregateCall> aggregateCalls = Lists.newArrayList();
263 RelDataType aggFnRetType = TypeConverter.convert(TypeInfoFactory.longTypeInfo,
264 cluster.getTypeFactory());
265 for (int i = 0; i < cleanArgList.size(); i++) {
266 AggregateCall aggregateCall = HiveCalciteUtil.createSingleArgAggCall("count", cluster,
267 TypeInfoFactory.longTypeInfo, i, aggFnRetType);
268 aggregateCalls.add(aggregateCall);
269 }
270 Aggregate aggregate = new HiveAggregate(cluster, cluster.traitSetOf(HiveRelNode.CONVENTION), gbInputRel,
271 ImmutableBitSet.of(), null, aggregateCalls);
272
273 // create the project after GB. For those repeated values, e.g., select
274 // count(distinct x, y), count(distinct y, x), we find the correct mapping.
275 if (map.isEmpty()) {
276 return aggregate;
277 } else {
278 List<RexNode> originalAggrRefs = Lists.transform(aggregate.getRowType().getFieldList(),
279 new Function<RelDataTypeField, RexNode>() {
280 @Override
281 public RexNode apply(RelDataTypeField input) {
282 return new RexInputRef(input.getIndex(), input.getType());
283 }
284 });
285 final List<RexNode> projLst = Lists.newArrayList();
286 int index = 0;
287 for (int i = 0; i < argList.size(); i++) {
288 if (map.containsKey(i)) {
289 projLst.add(originalAggrRefs.get(map.get(i)));
290 } else {
291 projLst.add(originalAggrRefs.get(index++));
292 }
293 }
294 return HiveProject.create(aggregate, projLst, null);
295 }
296 }
297
298 /**
299 * @param aggregate: the original aggregate
300 * @param argList: the original argList in aggregate
301 * @param cleanArgList: the new argList without duplicates
302 * @param map: the mapping from the original argList to the new argList
303 * @param sourceOfForCountDistinct: the sorted positions of groupset
304 * @return
305 */
306 private Aggregate createGroupingSets(Aggregate aggregate, List<List<Integer>> argList,
307 List<List<Integer>> cleanArgList, Map<Integer, Integer> map,
308 List<Integer> sourceOfForCountDistinct) {
309 final ImmutableBitSet groupSet = ImmutableBitSet.of(sourceOfForCountDistinct);
310 final List<ImmutableBitSet> origGroupSets = new ArrayList<>();
311
312 for (int i = 0; i < argList.size(); i++) {
313 List<Integer> list = argList.get(i);
314 ImmutableBitSet bitSet = ImmutableBitSet.of(list);
315 int prev = origGroupSets.indexOf(bitSet);
316 if (prev == -1) {
317 origGroupSets.add(bitSet);
318 cleanArgList.add(list);
319 } else {
320 map.put(i, prev);
321 }
322 }
323 // Calcite expects the grouping sets sorted and without duplicates
324 Collections.sort(origGroupSets, ImmutableBitSet.COMPARATOR);
325
326 List<AggregateCall> aggregateCalls = new ArrayList<AggregateCall>();
327 // Create GroupingID column
328 AggregateCall aggCall = AggregateCall.create(HiveGroupingID.INSTANCE, false,
329 new ImmutableList.Builder<Integer>().build(), -1, this.cluster.getTypeFactory()
330 .createSqlType(SqlTypeName.BIGINT), HiveGroupingID.INSTANCE.getName());
331 aggregateCalls.add(aggCall);
332 return new HiveAggregate(cluster, cluster.traitSetOf(HiveRelNode.CONVENTION),
333 aggregate.getInput(), groupSet, origGroupSets, aggregateCalls);
334 }
335
336 /**
337 * Returns the number of count DISTINCT
338 *
339 * @return the number of count DISTINCT
340 */
341 private int getNumCountDistinctCall(Aggregate hiveAggregate) {
342 int cnt = 0;
343 for (AggregateCall aggCall : hiveAggregate.getAggCallList()) {
344 if (aggCall.isDistinct() && (aggCall.getAggregation().getName().equalsIgnoreCase("count"))) {
345 cnt++;
346 }
347 }
348 return cnt;
349 }
350
351 /**
352 * Converts an aggregate relational expression that contains just one
353 * distinct aggregate function (or perhaps several over the same arguments)
354 * and no non-distinct aggregate functions.
355 */
356 private RelNode convertMonopole(
357 Aggregate aggregate,
358 List<Integer> argList) {
359 // For example,
360 // SELECT deptno, COUNT(DISTINCT sal), SUM(DISTINCT sal)
361 // FROM emp
362 // GROUP BY deptno
363 //
364 // becomes
365 //
366 // SELECT deptno, COUNT(distinct_sal), SUM(distinct_sal)
367 // FROM (
368 // SELECT DISTINCT deptno, sal AS distinct_sal
369 // FROM EMP GROUP BY deptno)
370 // GROUP BY deptno
371
372 // Project the columns of the GROUP BY plus the arguments
373 // to the agg function.
374 Map<Integer, Integer> sourceOf = new HashMap<Integer, Integer>();
375 final Aggregate distinct =
376 createSelectDistinct(aggregate, argList, sourceOf);
377
378 // Create an aggregate on top, with the new aggregate list.
379 final List<AggregateCall> newAggCalls =
380 Lists.newArrayList(aggregate.getAggCallList());
381 rewriteAggCalls(newAggCalls, argList, sourceOf);
382 final int cardinality = aggregate.getGroupSet().cardinality();
383 return aggregate.copy(aggregate.getTraitSet(), distinct,
384 aggregate.indicator, ImmutableBitSet.range(cardinality), null,
385 newAggCalls);
386 }
387
388 private static void rewriteAggCalls(
389 List<AggregateCall> newAggCalls,
390 List<Integer> argList,
391 Map<Integer, Integer> sourceOf) {
392 // Rewrite the agg calls. Each distinct agg becomes a non-distinct call
393 // to the corresponding field from the right; for example,
394 // "COUNT(DISTINCT e.sal)" becomes "COUNT(distinct_e.sal)".
395 for (int i = 0; i < newAggCalls.size(); i++) {
396 final AggregateCall aggCall = newAggCalls.get(i);
397
398 // Ignore agg calls which are not distinct or have the wrong set
399 // arguments. If we're rewriting aggs whose args are {sal}, we will
400 // rewrite COUNT(DISTINCT sal) and SUM(DISTINCT sal) but ignore
401 // COUNT(DISTINCT gender) or SUM(sal).
402 if (!aggCall.isDistinct()) {
403 continue;
404 }
405 if (!aggCall.getArgList().equals(argList)) {
406 continue;
407 }
408
409 // Re-map arguments.
410 final int argCount = aggCall.getArgList().size();
411 final List<Integer> newArgs = new ArrayList<Integer>(argCount);
412 for (int j = 0; j < argCount; j++) {
413 final Integer arg = aggCall.getArgList().get(j);
414 newArgs.add(sourceOf.get(arg));
415 }
416 final AggregateCall newAggCall =
417 new AggregateCall(
418 aggCall.getAggregation(),
419 false,
420 newArgs,
421 aggCall.getType(),
422 aggCall.getName());
423 newAggCalls.set(i, newAggCall);
424 }
425 }
426
427 /**
428 * Given an {@link org.apache.calcite.rel.logical.LogicalAggregate}
429 * and the ordinals of the arguments to a
430 * particular call to an aggregate function, creates a 'select distinct'
431 * relational expression which projects the group columns and those
432 * arguments but nothing else.
433 *
434 * <p>For example, given
435 *
436 * <blockquote>
437 * <pre>select f0, count(distinct f1), count(distinct f2)
438 * from t group by f0</pre>
439 * </blockquote>
440 *
441 * and the arglist
442 *
443 * <blockquote>{2}</blockquote>
444 *
445 * returns
446 *
447 * <blockquote>
448 * <pre>select distinct f0, f2 from t</pre>
449 * </blockquote>
450 *
451 * '
452 *
453 * <p>The <code>sourceOf</code> map is populated with the source of each
454 * column; in this case sourceOf.get(0) = 0, and sourceOf.get(1) = 2.</p>
455 *
456 * @param aggregate Aggregate relational expression
457 * @param argList Ordinals of columns to make distinct
458 * @param sourceOf Out parameter, is populated with a map of where each
459 * output field came from
460 * @return Aggregate relational expression which projects the required
461 * columns
462 */
463 private static Aggregate createSelectDistinct(
464 Aggregate aggregate,
465 List<Integer> argList,
466 Map<Integer, Integer> sourceOf) {
467 final List<Pair<RexNode, String>> projects =
468 new ArrayList<Pair<RexNode, String>>();
469 final RelNode child = aggregate.getInput();
470 final List<RelDataTypeField> childFields =
471 child.getRowType().getFieldList();
472 for (int i : aggregate.getGroupSet()) {
473 sourceOf.put(i, projects.size());
474 projects.add(RexInputRef.of2(i, childFields));
475 }
476 for (Integer arg : argList) {
477 if (sourceOf.get(arg) != null) {
478 continue;
479 }
480 sourceOf.put(arg, projects.size());
481 projects.add(RexInputRef.of2(arg, childFields));
482 }
483 final RelNode project =
484 projFactory.createProject(child, Pair.left(projects), Pair.right(projects));
485
486 // Get the distinct values of the GROUP BY fields and the arguments
487 // to the agg functions.
488 return aggregate.copy(aggregate.getTraitSet(), project, false,
489 ImmutableBitSet.range(projects.size()),
490 null, ImmutableList.<AggregateCall>of());
491 }
492 }
493
// End HiveExpandDistinctAggregatesRule.java