PageRenderTime 45ms CodeModel.GetById 21ms RepoModel.GetById 0ms app.codeStats 0ms

/ql/src/java/org/apache/hadoop/hive/ql/udf/UDAFPercentile.java

http://github.com/apache/hive
Java | 323 lines | 205 code | 48 blank | 70 comment | 69 complexity | da88bbdd1a935c987348e93442d121ab MD5 | raw file
Possible License(s): Apache-2.0
  1. /*
  2. * Licensed to the Apache Software Foundation (ASF) under one
  3. * or more contributor license agreements. See the NOTICE file
  4. * distributed with this work for additional information
  5. * regarding copyright ownership. The ASF licenses this file
  6. * to you under the Apache License, Version 2.0 (the
  7. * "License"); you may not use this file except in compliance
  8. * with the License. You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. package org.apache.hadoop.hive.ql.udf;
  19. import java.util.ArrayList;
  20. import java.util.Collections;
  21. import java.util.Comparator;
  22. import java.util.HashMap;
  23. import java.util.List;
  24. import java.util.Map;
  25. import java.util.Set;
  26. import org.apache.hadoop.hive.ql.exec.Description;
  27. import org.apache.hadoop.hive.ql.exec.UDAF;
  28. import org.apache.hadoop.hive.ql.exec.UDAFEvaluator;
  29. import org.apache.hadoop.hive.serde2.io.DoubleWritable;
  30. import org.apache.hadoop.hive.shims.ShimLoader;
  31. import org.apache.hadoop.io.LongWritable;
  32. /**
  33. * UDAF for calculating the percentile values.
  34. * There are several definitions of percentile, and we take the method recommended by
  35. * NIST.
  36. * @see <a href="http://en.wikipedia.org/wiki/Percentile#Alternative_methods">
  37. * Percentile references</a>
  38. */
  39. @Description(name = "percentile",
  40. value = "_FUNC_(expr, pc) - Returns the percentile(s) of expr at pc (range: [0,1])."
  41. + "pc can be a double or double array")
  42. public class UDAFPercentile extends UDAF {
  43. private static final Comparator<LongWritable> COMPARATOR;
  44. static {
  45. COMPARATOR = ShimLoader.getHadoopShims().getLongComparator();
  46. }
  47. /**
  48. * A state class to store intermediate aggregation results.
  49. */
  50. public static class State {
  51. private Map<LongWritable, LongWritable> counts;
  52. private List<DoubleWritable> percentiles;
  53. }
  54. /**
  55. * A comparator to sort the entries in order.
  56. */
  57. public static class MyComparator implements Comparator<Map.Entry<LongWritable, LongWritable>> {
  58. @Override
  59. public int compare(Map.Entry<LongWritable, LongWritable> o1,
  60. Map.Entry<LongWritable, LongWritable> o2) {
  61. return COMPARATOR.compare(o1.getKey(), o2.getKey());
  62. }
  63. }
  64. /**
  65. * Increment the State object with o as the key, and i as the count.
  66. */
  67. private static void increment(State s, LongWritable o, long i) {
  68. if (s.counts == null) {
  69. s.counts = new HashMap<LongWritable, LongWritable>();
  70. }
  71. LongWritable count = s.counts.get(o);
  72. if (count == null) {
  73. // We have to create a new object, because the object o belongs
  74. // to the code that creates it and may get its value changed.
  75. LongWritable key = new LongWritable();
  76. key.set(o.get());
  77. s.counts.put(key, new LongWritable(i));
  78. } else {
  79. count.set(count.get() + i);
  80. }
  81. }
  82. /**
  83. * Get the percentile value.
  84. */
  85. private static double getPercentile(List<Map.Entry<LongWritable, LongWritable>> entriesList,
  86. double position) {
  87. // We may need to do linear interpolation to get the exact percentile
  88. long lower = (long)Math.floor(position);
  89. long higher = (long)Math.ceil(position);
  90. // Linear search since this won't take much time from the total execution anyway
  91. // lower has the range of [0 .. total-1]
  92. // The first entry with accumulated count (lower+1) corresponds to the lower position.
  93. int i = 0;
  94. while (entriesList.get(i).getValue().get() < lower + 1) {
  95. i++;
  96. }
  97. long lowerKey = entriesList.get(i).getKey().get();
  98. if (higher == lower) {
  99. // no interpolation needed because position does not have a fraction
  100. return lowerKey;
  101. }
  102. if (entriesList.get(i).getValue().get() < higher + 1) {
  103. i++;
  104. }
  105. long higherKey = entriesList.get(i).getKey().get();
  106. if (higherKey == lowerKey) {
  107. // no interpolation needed because lower position and higher position has the same key
  108. return lowerKey;
  109. }
  110. // Linear interpolation to get the exact percentile
  111. return (higher - position) * lowerKey + (position - lower) * higherKey;
  112. }
  113. /**
  114. * The evaluator for percentile computation based on long.
  115. */
  116. public static class PercentileLongEvaluator implements UDAFEvaluator {
  117. private final State state;
  118. public PercentileLongEvaluator() {
  119. state = new State();
  120. }
  121. public void init() {
  122. if (state.counts != null) {
  123. // We reuse the same hashmap to reduce new object allocation.
  124. // This means counts can be empty when there is no input data.
  125. state.counts.clear();
  126. }
  127. }
  128. /** Note that percentile can be null in a global aggregation with
  129. * 0 input rows: "select percentile(col, 0.5) from t where false"
  130. * In that case, iterate(null, null) will be called once.
  131. */
  132. public boolean iterate(LongWritable o, Double percentile) {
  133. if (o == null && percentile == null) {
  134. return false;
  135. }
  136. if (state.percentiles == null) {
  137. if (percentile < 0.0 || percentile > 1.0) {
  138. throw new RuntimeException("Percentile value must be within the range of 0 to 1.");
  139. }
  140. state.percentiles = Collections.singletonList(new DoubleWritable(percentile.doubleValue()));
  141. }
  142. if (o != null) {
  143. increment(state, o, 1);
  144. }
  145. return true;
  146. }
  147. public State terminatePartial() {
  148. return state;
  149. }
  150. public boolean merge(State other) {
  151. if (other == null || other.counts == null || other.percentiles == null) {
  152. return false;
  153. }
  154. if (state.percentiles == null) {
  155. state.percentiles = new ArrayList<>(other.percentiles);
  156. }
  157. for (Map.Entry<LongWritable, LongWritable> e: other.counts.entrySet()) {
  158. increment(state, e.getKey(), e.getValue().get());
  159. }
  160. return true;
  161. }
  162. private DoubleWritable result;
  163. public DoubleWritable terminate() {
  164. // No input data.
  165. if (state.counts == null || state.counts.size() == 0) {
  166. return null;
  167. }
  168. // Get all items into an array and sort them.
  169. Set<Map.Entry<LongWritable, LongWritable>> entries = state.counts.entrySet();
  170. List<Map.Entry<LongWritable, LongWritable>> entriesList =
  171. new ArrayList<Map.Entry<LongWritable, LongWritable>>(entries);
  172. Collections.sort(entriesList, new MyComparator());
  173. // Accumulate the counts.
  174. long total = 0;
  175. for (int i = 0; i < entriesList.size(); i++) {
  176. LongWritable count = entriesList.get(i).getValue();
  177. total += count.get();
  178. count.set(total);
  179. }
  180. // Initialize the result.
  181. if (result == null) {
  182. result = new DoubleWritable();
  183. }
  184. // maxPosition is the 1.0 percentile
  185. long maxPosition = total - 1;
  186. double position = maxPosition * state.percentiles.get(0).get();
  187. result.set(getPercentile(entriesList, position));
  188. return result;
  189. }
  190. }
  191. /**
  192. * The evaluator for percentile computation based on long for an array of percentiles.
  193. */
  194. public static class PercentileLongArrayEvaluator implements UDAFEvaluator {
  195. private final State state;
  196. public PercentileLongArrayEvaluator() {
  197. state = new State();
  198. }
  199. public void init() {
  200. if (state.counts != null) {
  201. // We reuse the same hashmap to reduce new object allocation.
  202. // This means counts can be empty when there is no input data.
  203. state.counts.clear();
  204. }
  205. }
  206. public boolean iterate(LongWritable o, List<DoubleWritable> percentiles) {
  207. if (state.percentiles == null) {
  208. if(percentiles != null) {
  209. for (int i = 0; i < percentiles.size(); i++) {
  210. if (percentiles.get(i).get() < 0.0 || percentiles.get(i).get() > 1.0) {
  211. throw new RuntimeException("Percentile value must be within the range of 0 to 1.");
  212. }
  213. }
  214. state.percentiles = new ArrayList<>(percentiles);
  215. }
  216. else {
  217. state.percentiles = Collections.emptyList();
  218. }
  219. }
  220. if (o != null) {
  221. increment(state, o, 1);
  222. }
  223. return true;
  224. }
  225. public State terminatePartial() {
  226. return state;
  227. }
  228. public boolean merge(State other) {
  229. if (other == null || other.counts == null || other.percentiles == null) {
  230. return true;
  231. }
  232. if (state.percentiles == null) {
  233. state.percentiles = new ArrayList<>(other.percentiles);
  234. }
  235. for (Map.Entry<LongWritable, LongWritable> e: other.counts.entrySet()) {
  236. increment(state, e.getKey(), e.getValue().get());
  237. }
  238. return true;
  239. }
  240. private List<DoubleWritable> results;
  241. public List<DoubleWritable> terminate() {
  242. // No input data
  243. if (state.counts == null || state.counts.size() == 0) {
  244. return null;
  245. }
  246. // Get all items into an array and sort them
  247. Set<Map.Entry<LongWritable, LongWritable>> entries = state.counts.entrySet();
  248. List<Map.Entry<LongWritable, LongWritable>> entriesList =
  249. new ArrayList<Map.Entry<LongWritable, LongWritable>>(entries);
  250. Collections.sort(entriesList, new MyComparator());
  251. // accumulate the counts
  252. long total = 0;
  253. for (int i = 0; i < entriesList.size(); i++) {
  254. LongWritable count = entriesList.get(i).getValue();
  255. total += count.get();
  256. count.set(total);
  257. }
  258. // maxPosition is the 1.0 percentile
  259. long maxPosition = total - 1;
  260. // Initialize the results
  261. if (results == null) {
  262. results = new ArrayList<DoubleWritable>();
  263. for (int i = 0; i < state.percentiles.size(); i++) {
  264. results.add(new DoubleWritable());
  265. }
  266. }
  267. // Set the results
  268. for (int i = 0; i < state.percentiles.size(); i++) {
  269. double position = maxPosition * state.percentiles.get(i).get();
  270. results.get(i).set(getPercentile(entriesList, position));
  271. }
  272. return results;
  273. }
  274. }
  275. }