/ql/src/java/org/apache/hadoop/hive/ql/udf/UDAFPercentile.java
Java | 323 lines | 205 code | 48 blank | 70 comment | 69 complexity | da88bbdd1a935c987348e93442d121ab MD5 | raw file
Possible License(s): Apache-2.0
- /*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- package org.apache.hadoop.hive.ql.udf;
- import java.util.ArrayList;
- import java.util.Collections;
- import java.util.Comparator;
- import java.util.HashMap;
- import java.util.List;
- import java.util.Map;
- import java.util.Set;
- import org.apache.hadoop.hive.ql.exec.Description;
- import org.apache.hadoop.hive.ql.exec.UDAF;
- import org.apache.hadoop.hive.ql.exec.UDAFEvaluator;
- import org.apache.hadoop.hive.serde2.io.DoubleWritable;
- import org.apache.hadoop.hive.shims.ShimLoader;
- import org.apache.hadoop.io.LongWritable;
- /**
- * UDAF for calculating the percentile values.
- * There are several definitions of percentile, and we take the method recommended by
- * NIST.
- * @see <a href="http://en.wikipedia.org/wiki/Percentile#Alternative_methods">
- * Percentile references</a>
- */
- @Description(name = "percentile",
- value = "_FUNC_(expr, pc) - Returns the percentile(s) of expr at pc (range: [0,1])."
- + "pc can be a double or double array")
- public class UDAFPercentile extends UDAF {
- private static final Comparator<LongWritable> COMPARATOR;
- static {
- COMPARATOR = ShimLoader.getHadoopShims().getLongComparator();
- }
- /**
- * A state class to store intermediate aggregation results.
- */
- public static class State {
- private Map<LongWritable, LongWritable> counts;
- private List<DoubleWritable> percentiles;
- }
- /**
- * A comparator to sort the entries in order.
- */
- public static class MyComparator implements Comparator<Map.Entry<LongWritable, LongWritable>> {
- @Override
- public int compare(Map.Entry<LongWritable, LongWritable> o1,
- Map.Entry<LongWritable, LongWritable> o2) {
- return COMPARATOR.compare(o1.getKey(), o2.getKey());
- }
- }
- /**
- * Increment the State object with o as the key, and i as the count.
- */
- private static void increment(State s, LongWritable o, long i) {
- if (s.counts == null) {
- s.counts = new HashMap<LongWritable, LongWritable>();
- }
- LongWritable count = s.counts.get(o);
- if (count == null) {
- // We have to create a new object, because the object o belongs
- // to the code that creates it and may get its value changed.
- LongWritable key = new LongWritable();
- key.set(o.get());
- s.counts.put(key, new LongWritable(i));
- } else {
- count.set(count.get() + i);
- }
- }
- /**
- * Get the percentile value.
- */
- private static double getPercentile(List<Map.Entry<LongWritable, LongWritable>> entriesList,
- double position) {
- // We may need to do linear interpolation to get the exact percentile
- long lower = (long)Math.floor(position);
- long higher = (long)Math.ceil(position);
- // Linear search since this won't take much time from the total execution anyway
- // lower has the range of [0 .. total-1]
- // The first entry with accumulated count (lower+1) corresponds to the lower position.
- int i = 0;
- while (entriesList.get(i).getValue().get() < lower + 1) {
- i++;
- }
- long lowerKey = entriesList.get(i).getKey().get();
- if (higher == lower) {
- // no interpolation needed because position does not have a fraction
- return lowerKey;
- }
- if (entriesList.get(i).getValue().get() < higher + 1) {
- i++;
- }
- long higherKey = entriesList.get(i).getKey().get();
- if (higherKey == lowerKey) {
- // no interpolation needed because lower position and higher position has the same key
- return lowerKey;
- }
- // Linear interpolation to get the exact percentile
- return (higher - position) * lowerKey + (position - lower) * higherKey;
- }
- /**
- * The evaluator for percentile computation based on long.
- */
- public static class PercentileLongEvaluator implements UDAFEvaluator {
- private final State state;
- public PercentileLongEvaluator() {
- state = new State();
- }
- public void init() {
- if (state.counts != null) {
- // We reuse the same hashmap to reduce new object allocation.
- // This means counts can be empty when there is no input data.
- state.counts.clear();
- }
- }
- /** Note that percentile can be null in a global aggregation with
- * 0 input rows: "select percentile(col, 0.5) from t where false"
- * In that case, iterate(null, null) will be called once.
- */
- public boolean iterate(LongWritable o, Double percentile) {
- if (o == null && percentile == null) {
- return false;
- }
- if (state.percentiles == null) {
- if (percentile < 0.0 || percentile > 1.0) {
- throw new RuntimeException("Percentile value must be within the range of 0 to 1.");
- }
- state.percentiles = Collections.singletonList(new DoubleWritable(percentile.doubleValue()));
- }
- if (o != null) {
- increment(state, o, 1);
- }
- return true;
- }
- public State terminatePartial() {
- return state;
- }
- public boolean merge(State other) {
- if (other == null || other.counts == null || other.percentiles == null) {
- return false;
- }
- if (state.percentiles == null) {
- state.percentiles = new ArrayList<>(other.percentiles);
- }
- for (Map.Entry<LongWritable, LongWritable> e: other.counts.entrySet()) {
- increment(state, e.getKey(), e.getValue().get());
- }
- return true;
- }
- private DoubleWritable result;
- public DoubleWritable terminate() {
- // No input data.
- if (state.counts == null || state.counts.size() == 0) {
- return null;
- }
- // Get all items into an array and sort them.
- Set<Map.Entry<LongWritable, LongWritable>> entries = state.counts.entrySet();
- List<Map.Entry<LongWritable, LongWritable>> entriesList =
- new ArrayList<Map.Entry<LongWritable, LongWritable>>(entries);
- Collections.sort(entriesList, new MyComparator());
- // Accumulate the counts.
- long total = 0;
- for (int i = 0; i < entriesList.size(); i++) {
- LongWritable count = entriesList.get(i).getValue();
- total += count.get();
- count.set(total);
- }
- // Initialize the result.
- if (result == null) {
- result = new DoubleWritable();
- }
- // maxPosition is the 1.0 percentile
- long maxPosition = total - 1;
- double position = maxPosition * state.percentiles.get(0).get();
- result.set(getPercentile(entriesList, position));
- return result;
- }
- }
- /**
- * The evaluator for percentile computation based on long for an array of percentiles.
- */
- public static class PercentileLongArrayEvaluator implements UDAFEvaluator {
- private final State state;
- public PercentileLongArrayEvaluator() {
- state = new State();
- }
- public void init() {
- if (state.counts != null) {
- // We reuse the same hashmap to reduce new object allocation.
- // This means counts can be empty when there is no input data.
- state.counts.clear();
- }
- }
- public boolean iterate(LongWritable o, List<DoubleWritable> percentiles) {
- if (state.percentiles == null) {
- if(percentiles != null) {
- for (int i = 0; i < percentiles.size(); i++) {
- if (percentiles.get(i).get() < 0.0 || percentiles.get(i).get() > 1.0) {
- throw new RuntimeException("Percentile value must be within the range of 0 to 1.");
- }
- }
- state.percentiles = new ArrayList<>(percentiles);
- }
- else {
- state.percentiles = Collections.emptyList();
- }
- }
- if (o != null) {
- increment(state, o, 1);
- }
- return true;
- }
- public State terminatePartial() {
- return state;
- }
- public boolean merge(State other) {
- if (other == null || other.counts == null || other.percentiles == null) {
- return true;
- }
- if (state.percentiles == null) {
- state.percentiles = new ArrayList<>(other.percentiles);
- }
- for (Map.Entry<LongWritable, LongWritable> e: other.counts.entrySet()) {
- increment(state, e.getKey(), e.getValue().get());
- }
- return true;
- }
- private List<DoubleWritable> results;
- public List<DoubleWritable> terminate() {
- // No input data
- if (state.counts == null || state.counts.size() == 0) {
- return null;
- }
- // Get all items into an array and sort them
- Set<Map.Entry<LongWritable, LongWritable>> entries = state.counts.entrySet();
- List<Map.Entry<LongWritable, LongWritable>> entriesList =
- new ArrayList<Map.Entry<LongWritable, LongWritable>>(entries);
- Collections.sort(entriesList, new MyComparator());
- // accumulate the counts
- long total = 0;
- for (int i = 0; i < entriesList.size(); i++) {
- LongWritable count = entriesList.get(i).getValue();
- total += count.get();
- count.set(total);
- }
- // maxPosition is the 1.0 percentile
- long maxPosition = total - 1;
- // Initialize the results
- if (results == null) {
- results = new ArrayList<DoubleWritable>();
- for (int i = 0; i < state.percentiles.size(); i++) {
- results.add(new DoubleWritable());
- }
- }
- // Set the results
- for (int i = 0; i < state.percentiles.size(); i++) {
- double position = maxPosition * state.percentiles.get(i).get();
- results.get(i).set(getPercentile(entriesList, position));
- }
- return results;
- }
- }
- }