/tags/release-0.1-rc2/hive/external/ql/src/java/org/apache/hadoop/hive/ql/udf/UDAFPercentile.java

# · Java · 311 lines · 196 code · 46 blank · 69 comment · 66 complexity · 9fcfd43445aaa588a1bd5edd08a62da9 MD5 · raw file

  1. /**
  2. * Licensed to the Apache Software Foundation (ASF) under one
  3. * or more contributor license agreements. See the NOTICE file
  4. * distributed with this work for additional information
  5. * regarding copyright ownership. The ASF licenses this file
  6. * to you under the Apache License, Version 2.0 (the
  7. * "License"); you may not use this file except in compliance
  8. * with the License. You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. package org.apache.hadoop.hive.ql.udf;
  19. import java.util.ArrayList;
  20. import java.util.Collections;
  21. import java.util.Comparator;
  22. import java.util.HashMap;
  23. import java.util.List;
  24. import java.util.Map;
  25. import java.util.Set;
  26. import org.apache.hadoop.hive.ql.exec.Description;
  27. import org.apache.hadoop.hive.ql.exec.UDAF;
  28. import org.apache.hadoop.hive.ql.exec.UDAFEvaluator;
  29. import org.apache.hadoop.hive.serde2.io.DoubleWritable;
  30. import org.apache.hadoop.io.LongWritable;
  31. /**
  32. * UDAF for calculating the percentile values.
  33. * There are several definitions of percentile, and we take the method recommended by
  34. * NIST.
  35. * @see http://en.wikipedia.org/wiki/Percentile#Alternative_methods
  36. */
  37. @Description(name = "percentile",
  38. value = "_FUNC_(expr, pc) - Returns the percentile(s) of expr at pc (range: [0,1])."
  39. + "pc can be a double or double array")
  40. public class UDAFPercentile extends UDAF {
  /**
   * A state class to store intermediate aggregation results.
   * Both fields have no initializer, so they are {@code null} until the
   * aggregation code in this file assigns them.
   */
  public static class State {
    // Distinct input value -> count for that value, as maintained by increment().
    private Map<LongWritable, LongWritable> counts;
    // The requested percentile(s); the evaluators check the [0, 1] range in iterate().
    private List<DoubleWritable> percentiles;
  }
  48. /**
  49. * A comparator to sort the entries in order.
  50. */
  51. public static class MyComparator implements Comparator<Map.Entry<LongWritable, LongWritable>> {
  52. @Override
  53. public int compare(Map.Entry<LongWritable, LongWritable> o1,
  54. Map.Entry<LongWritable, LongWritable> o2) {
  55. return o1.getKey().compareTo(o2.getKey());
  56. }
  57. }
  58. /**
  59. * Increment the State object with o as the key, and i as the count.
  60. */
  61. private static void increment(State s, LongWritable o, long i) {
  62. if (s.counts == null) {
  63. s.counts = new HashMap<LongWritable, LongWritable>();
  64. }
  65. LongWritable count = s.counts.get(o);
  66. if (count == null) {
  67. // We have to create a new object, because the object o belongs
  68. // to the code that creates it and may get its value changed.
  69. LongWritable key = new LongWritable();
  70. key.set(o.get());
  71. s.counts.put(key, new LongWritable(i));
  72. } else {
  73. count.set(count.get() + i);
  74. }
  75. }
  76. /**
  77. * Get the percentile value.
  78. */
  79. private static double getPercentile(List<Map.Entry<LongWritable, LongWritable>> entriesList,
  80. double position) {
  81. // We may need to do linear interpolation to get the exact percentile
  82. long lower = (long)Math.floor(position);
  83. long higher = (long)Math.ceil(position);
  84. // Linear search since this won't take much time from the total execution anyway
  85. // lower has the range of [0 .. total-1]
  86. // The first entry with accumulated count (lower+1) corresponds to the lower position.
  87. int i = 0;
  88. while (entriesList.get(i).getValue().get() < lower + 1) {
  89. i++;
  90. }
  91. long lowerKey = entriesList.get(i).getKey().get();
  92. if (higher == lower) {
  93. // no interpolation needed because position does not have a fraction
  94. return lowerKey;
  95. }
  96. if (entriesList.get(i).getValue().get() < higher + 1) {
  97. i++;
  98. }
  99. long higherKey = entriesList.get(i).getKey().get();
  100. if (higherKey == lowerKey) {
  101. // no interpolation needed because lower position and higher position has the same key
  102. return lowerKey;
  103. }
  104. // Linear interpolation to get the exact percentile
  105. return (higher - position) * lowerKey + (position - lower) * higherKey;
  106. }
  107. /**
  108. * The evaluator for percentile computation based on long.
  109. */
  110. public static class PercentileLongEvaluator implements UDAFEvaluator {
  111. private final State state;
  112. public PercentileLongEvaluator() {
  113. state = new State();
  114. }
  115. public void init() {
  116. if (state.counts != null) {
  117. // We reuse the same hashmap to reduce new object allocation.
  118. // This means counts can be empty when there is no input data.
  119. state.counts.clear();
  120. }
  121. }
  122. /** Note that percentile can be null in a global aggregation with
  123. * 0 input rows: "select percentile(col, 0.5) from t where false"
  124. * In that case, iterate(null, null) will be called once.
  125. */
  126. public boolean iterate(LongWritable o, Double percentile) {
  127. if (o == null && percentile == null) {
  128. return false;
  129. }
  130. if (state.percentiles == null) {
  131. if (percentile < 0.0 || percentile > 1.0) {
  132. throw new RuntimeException("Percentile value must be wihin the range of 0 to 1.");
  133. }
  134. state.percentiles = new ArrayList<DoubleWritable>(1);
  135. state.percentiles.add(new DoubleWritable(percentile.doubleValue()));
  136. }
  137. if (o != null) {
  138. increment(state, o, 1);
  139. }
  140. return true;
  141. }
  142. public State terminatePartial() {
  143. return state;
  144. }
  145. public boolean merge(State other) {
  146. if (other == null || other.counts == null || other.percentiles == null) {
  147. return false;
  148. }
  149. if (state.percentiles == null) {
  150. state.percentiles = new ArrayList<DoubleWritable>(other.percentiles);
  151. }
  152. for (Map.Entry<LongWritable, LongWritable> e: other.counts.entrySet()) {
  153. increment(state, e.getKey(), e.getValue().get());
  154. }
  155. return true;
  156. }
  157. private DoubleWritable result;
  158. public DoubleWritable terminate() {
  159. // No input data.
  160. if (state.counts == null || state.counts.size() == 0) {
  161. return null;
  162. }
  163. // Get all items into an array and sort them.
  164. Set<Map.Entry<LongWritable, LongWritable>> entries = state.counts.entrySet();
  165. List<Map.Entry<LongWritable, LongWritable>> entriesList =
  166. new ArrayList<Map.Entry<LongWritable, LongWritable>>(entries);
  167. Collections.sort(entriesList, new MyComparator());
  168. // Accumulate the counts.
  169. long total = 0;
  170. for (int i = 0; i < entriesList.size(); i++) {
  171. LongWritable count = entriesList.get(i).getValue();
  172. total += count.get();
  173. count.set(total);
  174. }
  175. // Initialize the result.
  176. if (result == null) {
  177. result = new DoubleWritable();
  178. }
  179. // maxPosition is the 1.0 percentile
  180. long maxPosition = total - 1;
  181. double position = maxPosition * state.percentiles.get(0).get();
  182. result.set(getPercentile(entriesList, position));
  183. return result;
  184. }
  185. }
  186. /**
  187. * The evaluator for percentile computation based on long for an array of percentiles.
  188. */
  189. public static class PercentileLongArrayEvaluator implements UDAFEvaluator {
  190. private final State state;
  191. public PercentileLongArrayEvaluator() {
  192. state = new State();
  193. }
  194. public void init() {
  195. if (state.counts != null) {
  196. // We reuse the same hashmap to reduce new object allocation.
  197. // This means counts can be empty when there is no input data.
  198. state.counts.clear();
  199. }
  200. }
  201. public boolean iterate(LongWritable o, List<DoubleWritable> percentiles) {
  202. if (state.percentiles == null) {
  203. for (int i = 0; i < percentiles.size(); i++) {
  204. if (percentiles.get(i).get() < 0.0 || percentiles.get(i).get() > 1.0) {
  205. throw new RuntimeException("Percentile value must be wihin the range of 0 to 1.");
  206. }
  207. }
  208. state.percentiles = new ArrayList<DoubleWritable>(percentiles);
  209. }
  210. if (o != null) {
  211. increment(state, o, 1);
  212. }
  213. return true;
  214. }
  215. public State terminatePartial() {
  216. return state;
  217. }
  218. public boolean merge(State other) {
  219. if (other == null || other.counts == null || other.percentiles == null) {
  220. return true;
  221. }
  222. if (state.percentiles == null) {
  223. state.percentiles = new ArrayList<DoubleWritable>(other.percentiles);
  224. }
  225. for (Map.Entry<LongWritable, LongWritable> e: other.counts.entrySet()) {
  226. increment(state, e.getKey(), e.getValue().get());
  227. }
  228. return true;
  229. }
  230. private List<DoubleWritable> results;
  231. public List<DoubleWritable> terminate() {
  232. // No input data
  233. if (state.counts == null || state.counts.size() == 0) {
  234. return null;
  235. }
  236. // Get all items into an array and sort them
  237. Set<Map.Entry<LongWritable, LongWritable>> entries = state.counts.entrySet();
  238. List<Map.Entry<LongWritable, LongWritable>> entriesList =
  239. new ArrayList<Map.Entry<LongWritable, LongWritable>>(entries);
  240. Collections.sort(entriesList, new MyComparator());
  241. // accumulate the counts
  242. long total = 0;
  243. for (int i = 0; i < entriesList.size(); i++) {
  244. LongWritable count = entriesList.get(i).getValue();
  245. total += count.get();
  246. count.set(total);
  247. }
  248. // maxPosition is the 1.0 percentile
  249. long maxPosition = total - 1;
  250. // Initialize the results
  251. if (results == null) {
  252. results = new ArrayList<DoubleWritable>();
  253. for (int i = 0; i < state.percentiles.size(); i++) {
  254. results.add(new DoubleWritable());
  255. }
  256. }
  257. // Set the results
  258. for (int i = 0; i < state.percentiles.size(); i++) {
  259. double position = maxPosition * state.percentiles.get(i).get();
  260. results.get(i).set(getPercentile(entriesList, position));
  261. }
  262. return results;
  263. }
  264. }
  265. }