PageRenderTime 42ms CodeModel.GetById 21ms RepoModel.GetById 0ms app.codeStats 0ms

/tags/release-0.1-rc2/hive/external/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/NGramEstimator.java

#
Java | 276 lines | 164 code | 25 blank | 87 comment | 39 complexity | 3dde56bd9d2e612b18882869edbbd97e MD5 | raw file
Possible License(s): Apache-2.0, BSD-3-Clause, JSON, CPL-1.0
  1. /**
  2. * Licensed to the Apache Software Foundation (ASF) under one
  3. * or more contributor license agreements. See the NOTICE file
  4. * distributed with this work for additional information
  5. * regarding copyright ownership. The ASF licenses this file
  6. * to you under the Apache License, Version 2.0 (the
  7. * "License"); you may not use this file except in compliance
  8. * with the License. You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. package org.apache.hadoop.hive.ql.udf.generic;
  19. import java.util.List;
  20. import java.util.ArrayList;
  21. import java.util.HashMap;
  22. import java.util.Map;
  23. import java.util.Collections;
  24. import java.util.Iterator;
  25. import java.util.Comparator;
  26. import org.apache.hadoop.hive.serde2.io.DoubleWritable;
  27. import org.apache.hadoop.io.Text;
  28. import org.apache.hadoop.hive.ql.metadata.HiveException;
  29. import org.apache.commons.logging.Log;
  30. import org.apache.commons.logging.LogFactory;
  31. /**
  32. * A generic, re-usable n-gram estimation class that supports partial aggregations.
  33. * The algorithm is based on the heuristic from the following paper:
  34. * Yael Ben-Haim and Elad Tom-Tov, "A streaming parallel decision tree algorithm",
  35. * J. Machine Learning Research 11 (2010), pp. 849--872.
  36. *
  37. * In particular, it is guaranteed that frequencies will be under-counted. With large
  38. * data and a reasonable precision factor, this undercounting appears to be on the order
  39. * of 5%.
  40. */
  41. public class NGramEstimator {
  42. /* Class private variables */
  43. private int k;
  44. private int pf;
  45. private int n;
  46. private HashMap<ArrayList<String>, Double> ngrams;
  47. /**
  48. * Creates a new n-gram estimator object. The 'n' for n-grams is computed dynamically
  49. * when data is fed to the object.
  50. */
  51. public NGramEstimator() {
  52. k = 0;
  53. pf = 0;
  54. n = 0;
  55. ngrams = new HashMap<ArrayList<String>, Double>();
  56. }
  57. /**
  58. * Returns true if the 'k' and 'pf' parameters have been set.
  59. */
  60. public boolean isInitialized() {
  61. return (k != 0);
  62. }
  63. /**
  64. * Sets the 'k' and 'pf' parameters.
  65. */
  66. public void initialize(int pk, int ppf, int pn) throws HiveException {
  67. assert(pk > 0 && ppf > 0 && pn > 0);
  68. k = pk;
  69. pf = ppf;
  70. n = pn;
  71. // enforce a minimum precision factor
  72. if(k * pf < 1000) {
  73. pf = 1000 / k;
  74. }
  75. }
  76. /**
  77. * Resets an n-gram estimator object to its initial state.
  78. */
  79. public void reset() {
  80. ngrams.clear();
  81. n = pf = k = 0;
  82. }
  83. /**
  84. * Returns the final top-k n-grams in a format suitable for returning to Hive.
  85. */
  86. public ArrayList<Object[]> getNGrams() throws HiveException {
  87. trim(true);
  88. if(ngrams.size() < 1) { // SQL standard - return null for zero elements
  89. return null;
  90. }
  91. // Sort the n-gram list by frequencies in descending order
  92. ArrayList<Object[]> result = new ArrayList<Object[]>();
  93. ArrayList<Map.Entry<ArrayList<String>, Double>> list = new ArrayList(ngrams.entrySet());
  94. Collections.sort(list, new Comparator<Map.Entry<ArrayList<String>, Double>>() {
  95. public int compare(Map.Entry<ArrayList<String>, Double> o1,
  96. Map.Entry<ArrayList<String>, Double> o2) {
  97. return o2.getValue().compareTo(o1.getValue());
  98. }
  99. });
  100. // Convert the n-gram list to a format suitable for Hive
  101. for(int i = 0; i < list.size(); i++) {
  102. ArrayList<String> key = list.get(i).getKey();
  103. Double val = list.get(i).getValue();
  104. Object[] curGram = new Object[2];
  105. ArrayList<Text> ng = new ArrayList<Text>();
  106. for(int j = 0; j < key.size(); j++) {
  107. ng.add(new Text(key.get(j)));
  108. }
  109. curGram[0] = ng;
  110. curGram[1] = new DoubleWritable(val.doubleValue());
  111. result.add(curGram);
  112. }
  113. return result;
  114. }
  115. /**
  116. * Returns the number of n-grams in our buffer.
  117. */
  118. public int size() {
  119. return ngrams.size();
  120. }
  121. /**
  122. * Adds a new n-gram to the estimation.
  123. *
  124. * @param ng The n-gram to add to the estimation
  125. */
  126. public void add(ArrayList<String> ng) throws HiveException {
  127. assert(ng != null && ng.size() > 0 && ng.get(0) != null);
  128. Double curFreq = ngrams.get(ng);
  129. if(curFreq == null) {
  130. // new n-gram
  131. curFreq = new Double(1.0);
  132. } else {
  133. // existing n-gram, just increment count
  134. curFreq++;
  135. }
  136. ngrams.put(ng, curFreq);
  137. // set 'n' if we haven't done so before
  138. if(n == 0) {
  139. n = ng.size();
  140. } else {
  141. if(n != ng.size()) {
  142. throw new HiveException(getClass().getSimpleName() + ": mismatch in value for 'n'"
  143. + ", which usually is caused by a non-constant expression. Found '"+n+"' and '"
  144. + ng.size() + "'.");
  145. }
  146. }
  147. // Trim down the total number of n-grams if we've exceeded the maximum amount of memory allowed
  148. //
  149. // NOTE: Although 'k'*'pf' specifies the size of the estimation buffer, we don't want to keep
  150. // performing N.log(N) trim operations each time the maximum hashmap size is exceeded.
  151. // To handle this, we *actually* maintain an estimation buffer of size 2*'k'*'pf', and
  152. // trim down to 'k'*'pf' whenever the hashmap size exceeds 2*'k'*'pf'. This really has
  153. // a significant effect when 'k'*'pf' is very high.
  154. if(ngrams.size() > k * pf * 2) {
  155. trim(false);
  156. }
  157. }
  158. /**
  159. * Trims an n-gram estimation down to either 'pf' * 'k' n-grams, or 'k' n-grams if
  160. * finalTrim is true.
  161. */
  162. private void trim(boolean finalTrim) throws HiveException {
  163. ArrayList<Map.Entry<ArrayList<String>,Double>> list = new ArrayList(ngrams.entrySet());
  164. Collections.sort(list, new Comparator<Map.Entry<ArrayList<String>,Double>>() {
  165. public int compare(Map.Entry<ArrayList<String>,Double> o1,
  166. Map.Entry<ArrayList<String>,Double> o2) {
  167. return o1.getValue().compareTo(o2.getValue());
  168. }
  169. });
  170. for(int i = 0; i < list.size() - (finalTrim ? k : pf*k); i++) {
  171. ngrams.remove( list.get(i).getKey() );
  172. }
  173. }
  174. /**
  175. * Takes a serialized n-gram estimator object created by the serialize() method and merges
  176. * it with the current n-gram object.
  177. *
  178. * @param other A serialized n-gram object created by the serialize() method
  179. * @see merge
  180. */
  181. public void merge(List<Text> other) throws HiveException {
  182. if(other == null) {
  183. return;
  184. }
  185. // Get estimation parameters
  186. int otherK = Integer.parseInt(other.get(0).toString());
  187. int otherN = Integer.parseInt(other.get(1).toString());
  188. int otherPF = Integer.parseInt(other.get(2).toString());
  189. if(k > 0 && k != otherK) {
  190. throw new HiveException(getClass().getSimpleName() + ": mismatch in value for 'k'"
  191. + ", which usually is caused by a non-constant expression. Found '"+k+"' and '"
  192. + otherK + "'.");
  193. }
  194. if(n > 0 && otherN != n) {
  195. throw new HiveException(getClass().getSimpleName() + ": mismatch in value for 'n'"
  196. + ", which usually is caused by a non-constant expression. Found '"+n+"' and '"
  197. + otherN + "'.");
  198. }
  199. if(pf > 0 && otherPF != pf) {
  200. throw new HiveException(getClass().getSimpleName() + ": mismatch in value for 'pf'"
  201. + ", which usually is caused by a non-constant expression. Found '"+pf+"' and '"
  202. + otherPF + "'.");
  203. }
  204. k = otherK;
  205. pf = otherPF;
  206. n = otherN;
  207. // Merge the other estimation into the current one
  208. for(int i = 3; i < other.size(); i++) {
  209. ArrayList<String> key = new ArrayList<String>();
  210. for(int j = 0; j < n; j++) {
  211. Text word = other.get(i+j);
  212. key.add(word.toString());
  213. }
  214. i += n;
  215. double val = Double.parseDouble( other.get(i).toString() );
  216. Double myval = ngrams.get(key);
  217. if(myval == null) {
  218. myval = new Double(val);
  219. } else {
  220. myval += val;
  221. }
  222. ngrams.put(key, myval);
  223. }
  224. trim(false);
  225. }
  226. /**
  227. * In preparation for a Hive merge() call, serializes the current n-gram estimator object into an
  228. * ArrayList of Text objects. This list is deserialized and merged by the
  229. * merge method.
  230. *
  231. * @return An ArrayList of Hadoop Text objects that represents the current
  232. * n-gram estimation.
  233. * @see merge(ArrayList<Text>)
  234. */
  235. public ArrayList<Text> serialize() throws HiveException {
  236. ArrayList<Text> result = new ArrayList<Text>();
  237. result.add(new Text(Integer.toString(k)));
  238. result.add(new Text(Integer.toString(n)));
  239. result.add(new Text(Integer.toString(pf)));
  240. for(Iterator<ArrayList<String> > it = ngrams.keySet().iterator(); it.hasNext(); ) {
  241. ArrayList<String> mykey = it.next();
  242. assert(mykey.size() > 0);
  243. for(int i = 0; i < mykey.size(); i++) {
  244. result.add(new Text(mykey.get(i)));
  245. }
  246. Double myval = ngrams.get(mykey);
  247. result.add(new Text(myval.toString()));
  248. }
  249. return result;
  250. }
  251. }