PageRenderTime 65ms CodeModel.GetById 30ms app.highlight 28ms RepoModel.GetById 2ms app.codeStats 0ms

/tags/release-0.1-rc2/hive/external/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/NGramEstimator.java

#
Java | 276 lines | 164 code | 25 blank | 87 comment | 39 complexity | 3dde56bd9d2e612b18882869edbbd97e MD5 | raw file
  1/**
  2 * Licensed to the Apache Software Foundation (ASF) under one
  3 * or more contributor license agreements.  See the NOTICE file
  4 * distributed with this work for additional information
  5 * regarding copyright ownership.  The ASF licenses this file
  6 * to you under the Apache License, Version 2.0 (the
  7 * "License"); you may not use this file except in compliance
  8 * with the License.  You may obtain a copy of the License at
  9 *
 10 *     http://www.apache.org/licenses/LICENSE-2.0
 11 *
 12 * Unless required by applicable law or agreed to in writing, software
 13 * distributed under the License is distributed on an "AS IS" BASIS,
 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 15 * See the License for the specific language governing permissions and
 16 * limitations under the License.
 17 */
 18package org.apache.hadoop.hive.ql.udf.generic;
 19
 20import java.util.List;
 21import java.util.ArrayList;
 22import java.util.HashMap;
 23import java.util.Map;
 24import java.util.Collections;
 25import java.util.Iterator;
 26import java.util.Comparator;
 27import org.apache.hadoop.hive.serde2.io.DoubleWritable;
 28import org.apache.hadoop.io.Text;
 29import org.apache.hadoop.hive.ql.metadata.HiveException;
 30import org.apache.commons.logging.Log;
 31import org.apache.commons.logging.LogFactory;
 32
 33/**
 34 * A generic, re-usable n-gram estimation class that supports partial aggregations.
 35 * The algorithm is based on the heuristic from the following paper:
 36 * Yael Ben-Haim and Elad Tom-Tov, "A streaming parallel decision tree algorithm",
 37 * J. Machine Learning Research 11 (2010), pp. 849--872. 
 38 *
 39 * In particular, it is guaranteed that frequencies will be under-counted. With large
 40 * data and a reasonable precision factor, this undercounting appears to be on the order
 41 * of 5%.
 42 */
 43public class NGramEstimator {
 44  /* Class private variables */
 45  private int k;
 46  private int pf;
 47  private int n;
 48  private HashMap<ArrayList<String>, Double> ngrams;
 49  
 50
 51  /**
 52   * Creates a new n-gram estimator object. The 'n' for n-grams is computed dynamically
 53   * when data is fed to the object. 
 54   */
 55  public NGramEstimator() {
 56    k  = 0;
 57    pf = 0;
 58    n  = 0;
 59    ngrams = new HashMap<ArrayList<String>, Double>();
 60  }
 61
 62  /**
 63   * Returns true if the 'k' and 'pf' parameters have been set.
 64   */
 65  public boolean isInitialized() {
 66    return (k != 0);
 67  }
 68
 69  /**
 70   * Sets the 'k' and 'pf' parameters.
 71   */
 72  public void initialize(int pk, int ppf, int pn) throws HiveException {
 73    assert(pk > 0 && ppf > 0 && pn > 0);
 74    k = pk;
 75    pf = ppf;
 76    n = pn;
 77
 78    // enforce a minimum precision factor
 79    if(k * pf < 1000) {
 80      pf = 1000 / k;
 81    }
 82  }
 83
 84  /**
 85   * Resets an n-gram estimator object to its initial state. 
 86   */
 87  public void reset() {
 88    ngrams.clear();
 89    n = pf = k = 0;
 90  }
 91
 92  /**
 93   * Returns the final top-k n-grams in a format suitable for returning to Hive.
 94   */
 95  public ArrayList<Object[]> getNGrams() throws HiveException {
 96    trim(true);
 97    if(ngrams.size() < 1) { // SQL standard - return null for zero elements
 98      return null;
 99    } 
100
101    // Sort the n-gram list by frequencies in descending order
102    ArrayList<Object[]> result = new ArrayList<Object[]>();
103    ArrayList<Map.Entry<ArrayList<String>, Double>> list = new ArrayList(ngrams.entrySet());
104    Collections.sort(list, new Comparator<Map.Entry<ArrayList<String>, Double>>() {
105      public int compare(Map.Entry<ArrayList<String>, Double> o1, 
106                         Map.Entry<ArrayList<String>, Double> o2) {
107        return o2.getValue().compareTo(o1.getValue());
108      }
109    });
110
111    // Convert the n-gram list to a format suitable for Hive
112    for(int i = 0; i < list.size(); i++) {
113      ArrayList<String> key = list.get(i).getKey();
114      Double val = list.get(i).getValue();
115
116      Object[] curGram = new Object[2];
117      ArrayList<Text> ng = new ArrayList<Text>();
118      for(int j = 0; j < key.size(); j++) {
119        ng.add(new Text(key.get(j)));
120      }
121      curGram[0] = ng;
122      curGram[1] = new DoubleWritable(val.doubleValue());
123      result.add(curGram);
124    }
125
126    return result;    
127  }
128
129  /**
130   * Returns the number of n-grams in our buffer.
131   */
132  public int size() {
133    return ngrams.size();
134  }
135
136  /**
137   * Adds a new n-gram to the estimation.
138   *
139   * @param ng The n-gram to add to the estimation
140   */
141  public void add(ArrayList<String> ng) throws HiveException {
142    assert(ng != null && ng.size() > 0 && ng.get(0) != null);
143    Double curFreq = ngrams.get(ng);
144    if(curFreq == null) {
145      // new n-gram
146      curFreq = new Double(1.0);
147    } else {
148      // existing n-gram, just increment count
149      curFreq++;
150    }
151    ngrams.put(ng, curFreq);
152
153    // set 'n' if we haven't done so before
154    if(n == 0) {
155      n = ng.size();
156    } else {
157      if(n != ng.size()) {
158        throw new HiveException(getClass().getSimpleName() + ": mismatch in value for 'n'" 
159            + ", which usually is caused by a non-constant expression. Found '"+n+"' and '"
160            + ng.size() + "'.");
161      }
162    }
163
164    // Trim down the total number of n-grams if we've exceeded the maximum amount of memory allowed
165    // 
166    // NOTE: Although 'k'*'pf' specifies the size of the estimation buffer, we don't want to keep
167    //       performing N.log(N) trim operations each time the maximum hashmap size is exceeded.
168    //       To handle this, we *actually* maintain an estimation buffer of size 2*'k'*'pf', and
169    //       trim down to 'k'*'pf' whenever the hashmap size exceeds 2*'k'*'pf'. This really has
170    //       a significant effect when 'k'*'pf' is very high.
171    if(ngrams.size() > k * pf * 2) {
172      trim(false);
173    }
174  }
175
176  /**
177   * Trims an n-gram estimation down to either 'pf' * 'k' n-grams, or 'k' n-grams if 
178   * finalTrim is true.
179   */
180  private void trim(boolean finalTrim) throws HiveException {
181    ArrayList<Map.Entry<ArrayList<String>,Double>> list = new ArrayList(ngrams.entrySet());
182    Collections.sort(list, new Comparator<Map.Entry<ArrayList<String>,Double>>() {
183      public int compare(Map.Entry<ArrayList<String>,Double> o1, 
184                         Map.Entry<ArrayList<String>,Double> o2) {
185        return o1.getValue().compareTo(o2.getValue());
186      }
187    });
188    for(int i = 0; i < list.size() - (finalTrim ? k : pf*k); i++) {
189      ngrams.remove( list.get(i).getKey() );
190    }
191  }
192
193  /**
194   * Takes a serialized n-gram estimator object created by the serialize() method and merges
195   * it with the current n-gram object.
196   *
197   * @param other A serialized n-gram object created by the serialize() method
198   * @see merge
199   */
200  public void merge(List<Text> other) throws HiveException {
201    if(other == null) {
202      return;
203    }
204
205    // Get estimation parameters
206    int otherK = Integer.parseInt(other.get(0).toString());
207    int otherN = Integer.parseInt(other.get(1).toString());
208    int otherPF = Integer.parseInt(other.get(2).toString());
209    if(k > 0 && k != otherK) {
210      throw new HiveException(getClass().getSimpleName() + ": mismatch in value for 'k'" 
211          + ", which usually is caused by a non-constant expression. Found '"+k+"' and '"
212          + otherK + "'.");
213    }
214    if(n > 0 && otherN != n) {
215      throw new HiveException(getClass().getSimpleName() + ": mismatch in value for 'n'" 
216          + ", which usually is caused by a non-constant expression. Found '"+n+"' and '"
217          + otherN + "'.");
218    }
219    if(pf > 0 && otherPF != pf) {
220      throw new HiveException(getClass().getSimpleName() + ": mismatch in value for 'pf'" 
221          + ", which usually is caused by a non-constant expression. Found '"+pf+"' and '"
222          + otherPF + "'.");
223    }
224    k = otherK;
225    pf = otherPF;
226    n = otherN;
227
228    // Merge the other estimation into the current one
229    for(int i = 3; i < other.size(); i++) {
230      ArrayList<String> key = new ArrayList<String>();
231      for(int j = 0; j < n; j++) {
232        Text word = other.get(i+j);
233        key.add(word.toString());
234      }
235      i += n;
236      double val = Double.parseDouble( other.get(i).toString() );
237      Double myval = ngrams.get(key);
238      if(myval == null) {
239        myval = new Double(val);
240      } else {
241        myval += val;
242      }
243      ngrams.put(key, myval);      
244    }
245
246    trim(false);
247  }
248
249
250  /**
251   * In preparation for a Hive merge() call, serializes the current n-gram estimator object into an
252   * ArrayList of Text objects. This list is deserialized and merged by the 
253   * merge method.
254   *
255   * @return An ArrayList of Hadoop Text objects that represents the current
256   * n-gram estimation.
257   * @see merge(ArrayList<Text>)
258   */
259  public ArrayList<Text> serialize() throws HiveException {
260    ArrayList<Text> result = new ArrayList<Text>();    
261    result.add(new Text(Integer.toString(k)));
262    result.add(new Text(Integer.toString(n)));
263    result.add(new Text(Integer.toString(pf)));
264    for(Iterator<ArrayList<String> > it = ngrams.keySet().iterator(); it.hasNext(); ) {
265      ArrayList<String> mykey = it.next();
266      assert(mykey.size() > 0);
267      for(int i = 0; i < mykey.size(); i++) {
268        result.add(new Text(mykey.get(i)));
269      }
270      Double myval = ngrams.get(mykey);
271      result.add(new Text(myval.toString()));
272    }
273
274    return result;
275  }
276}