PageRenderTime 40ms CodeModel.GetById 15ms app.highlight 21ms RepoModel.GetById 1ms app.codeStats 0ms

/tags/release-0.1-rc2/hive/external/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFContextNGrams.java

#
Java | 422 lines | 322 code | 37 blank | 63 comment | 82 complexity | db2650b1b1263b15aab93fee02d3be21 MD5 | raw file
  1/**
  2 * Licensed to the Apache Software Foundation (ASF) under one
  3 * or more contributor license agreements.  See the NOTICE file
  4 * distributed with this work for additional information
  5 * regarding copyright ownership.  The ASF licenses this file
  6 * to you under the Apache License, Version 2.0 (the
  7 * "License"); you may not use this file except in compliance
  8 * with the License.  You may obtain a copy of the License at
  9 *
 10 *     http://www.apache.org/licenses/LICENSE-2.0
 11 *
 12 * Unless required by applicable law or agreed to in writing, software
 13 * distributed under the License is distributed on an "AS IS" BASIS,
 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 15 * See the License for the specific language governing permissions and
 16 * limitations under the License.
 17 */
 18package org.apache.hadoop.hive.ql.udf.generic;
 19
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableDoubleObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.util.StringUtils;
 50
 51/**
 52 * Estimates the top-k contextual n-grams in arbitrary sequential data using a heuristic.
 53 */
 54@Description(name = "context_ngrams",
 55    value = "_FUNC_(expr, array<string1, string2, ...>, k, pf) estimates the top-k most " +
 56      "frequent n-grams that fit into the specified context. The second parameter specifies " +
 57      "a string of words that specify the positions of the n-gram elements, with a null value " +
 58      "standing in for a 'blank' that must be filled by an n-gram element.",
 59    extended = "The primary expression must be an array of strings, or an array of arrays of " +
 60      "strings, such as the return type of the sentences() UDF. The second parameter specifies " +
 61      "the context -- for example, array(\"i\", \"love\", null) -- which would estimate the top " +
 62      "'k' words that follow the phrase \"i love\" in the primary expression. The optional " +
 63      "fourth parameter 'pf' controls the memory used by the heuristic. Larger values will " +
 64      "yield better accuracy, but use more memory. Example usage:\n" +
 65      "  SELECT context_ngrams(sentences(lower(review)), array(\"i\", \"love\", null, null), 10)" +
 66      " FROM movies\n" +
 67      "would attempt to determine the 10 most common two-word phrases that follow \"i love\" " +
 68      "in a database of free-form natural language movie reviews.")
 69public class GenericUDAFContextNGrams implements GenericUDAFResolver {
 70  static final Log LOG = LogFactory.getLog(GenericUDAFContextNGrams.class.getName());
 71
 72  @Override
 73  public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
 74    if (parameters.length != 3 && parameters.length != 4) {
 75      throw new UDFArgumentTypeException(parameters.length-1,
 76          "Please specify either three or four arguments.");
 77    }
 78    
 79    // Validate the first parameter, which is the expression to compute over. This should be an
 80    // array of strings type, or an array of arrays of strings.
 81    PrimitiveTypeInfo pti;
 82    if (parameters[0].getCategory() != ObjectInspector.Category.LIST) {
 83      throw new UDFArgumentTypeException(0,
 84          "Only list type arguments are accepted but "
 85          + parameters[0].getTypeName() + " was passed as parameter 1.");
 86    }
 87    switch (((ListTypeInfo) parameters[0]).getListElementTypeInfo().getCategory()) {
 88    case PRIMITIVE:
 89      // Parameter 1 was an array of primitives, so make sure the primitives are strings.
 90      pti = (PrimitiveTypeInfo) ((ListTypeInfo) parameters[0]).getListElementTypeInfo();
 91      break;
 92
 93    case LIST:
 94      // Parameter 1 was an array of arrays, so make sure that the inner arrays contain
 95      // primitive strings.
 96      ListTypeInfo lti = (ListTypeInfo)
 97                         ((ListTypeInfo) parameters[0]).getListElementTypeInfo();
 98      pti = (PrimitiveTypeInfo) lti.getListElementTypeInfo();
 99      break;
100
101    default:
102      throw new UDFArgumentTypeException(0,
103          "Only arrays of strings or arrays of arrays of strings are accepted but "
104          + parameters[0].getTypeName() + " was passed as parameter 1.");
105    }
106    if(pti.getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.STRING) {
107      throw new UDFArgumentTypeException(0,
108          "Only array<string> or array<array<string>> is allowed, but " 
109          + parameters[0].getTypeName() + " was passed as parameter 1.");
110    }
111
112    // Validate the second parameter, which should be an array of strings
113    if(parameters[1].getCategory() != ObjectInspector.Category.LIST ||
114       ((ListTypeInfo) parameters[1]).getListElementTypeInfo().getCategory() !=
115         ObjectInspector.Category.PRIMITIVE) {
116      throw new UDFArgumentTypeException(1, "Only arrays of strings are accepted but "
117          + parameters[1].getTypeName() + " was passed as parameter 2.");
118    } 
119    if(((PrimitiveTypeInfo) ((ListTypeInfo)parameters[1]).getListElementTypeInfo()).
120        getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.STRING) {
121      throw new UDFArgumentTypeException(1, "Only arrays of strings are accepted but "
122          + parameters[1].getTypeName() + " was passed as parameter 2.");
123    }
124
125    // Validate the third parameter, which should be an integer to represent 'k'
126    if(parameters[2].getCategory() != ObjectInspector.Category.PRIMITIVE) {
127      throw new UDFArgumentTypeException(2, "Only integers are accepted but "
128            + parameters[2].getTypeName() + " was passed as parameter 3.");
129    } 
130    switch(((PrimitiveTypeInfo) parameters[2]).getPrimitiveCategory()) {
131    case BYTE:
132    case SHORT:
133    case INT:
134    case LONG:
135      break;
136
137    default:
138      throw new UDFArgumentTypeException(2, "Only integers are accepted but "
139            + parameters[2].getTypeName() + " was passed as parameter 3.");
140    }
141
142    // If the fourth parameter -- precision factor 'pf' -- has been specified, make sure it's
143    // an integer.
144    if(parameters.length == 4) {
145      if(parameters[3].getCategory() != ObjectInspector.Category.PRIMITIVE) {
146        throw new UDFArgumentTypeException(3, "Only integers are accepted but "
147            + parameters[3].getTypeName() + " was passed as parameter 4.");
148      } 
149      switch(((PrimitiveTypeInfo) parameters[3]).getPrimitiveCategory()) {
150      case BYTE:
151      case SHORT:
152      case INT:
153      case LONG:
154        break;
155
156      default:
157        throw new UDFArgumentTypeException(3, "Only integers are accepted but "
158            + parameters[3].getTypeName() + " was passed as parameter 4.");
159      }
160    }
161
162    return new GenericUDAFContextNGramEvaluator();
163  }
164
165  /**
166   * A constant-space heuristic to estimate the top-k contextual n-grams.
167   */
168  public static class GenericUDAFContextNGramEvaluator extends GenericUDAFEvaluator {
169    // For PARTIAL1 and COMPLETE: ObjectInspectors for original data
170    private StandardListObjectInspector outerInputOI;
171    private StandardListObjectInspector innerInputOI;
172    private StandardListObjectInspector contextListOI;
173    private PrimitiveObjectInspector contextOI;
174    private PrimitiveObjectInspector inputOI;
175    private PrimitiveObjectInspector kOI;
176    private PrimitiveObjectInspector pOI;
177
178    // For PARTIAL2 and FINAL: ObjectInspectors for partial aggregations 
179    private StandardListObjectInspector loi;
180
181    @Override
182    public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
183      super.init(m, parameters);
184
185      // Init input object inspectors
186      if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
187        outerInputOI = (StandardListObjectInspector) parameters[0];
188        if(outerInputOI.getListElementObjectInspector().getCategory() ==
189            ObjectInspector.Category.LIST) {
190          // We're dealing with input that is an array of arrays of strings
191          innerInputOI = (StandardListObjectInspector) outerInputOI.getListElementObjectInspector();
192          inputOI = (PrimitiveObjectInspector) innerInputOI.getListElementObjectInspector();
193        } else {
194          // We're dealing with input that is an array of strings
195          inputOI = (PrimitiveObjectInspector) outerInputOI.getListElementObjectInspector();
196          innerInputOI = null;
197        }
198        contextListOI = (StandardListObjectInspector) parameters[1];
199        contextOI = (PrimitiveObjectInspector) contextListOI.getListElementObjectInspector();
200        kOI = (PrimitiveObjectInspector) parameters[2];
201        if(parameters.length == 4) {
202          pOI = (PrimitiveObjectInspector) parameters[3];
203        } else {
204          pOI = null;
205        }
206      } else {
207          // Init the list object inspector for handling partial aggregations
208          loi = (StandardListObjectInspector) parameters[0];
209      }
210
211      // Init output object inspectors.
212      //
213      // The return type for a partial aggregation is still a list of strings.
214      // 
215      // The return type for FINAL and COMPLETE is a full aggregation result, which is 
216      // an array of structures containing the n-gram and its estimated frequency.
217      if (m == Mode.PARTIAL1 || m == Mode.PARTIAL2) {
218        return ObjectInspectorFactory.getStandardListObjectInspector(
219            PrimitiveObjectInspectorFactory.writableStringObjectInspector);
220      } else {
221        // Final return type that goes back to Hive: a list of structs with n-grams and their
222        // estimated frequencies.
223        ArrayList<ObjectInspector> foi = new ArrayList<ObjectInspector>();
224        foi.add(ObjectInspectorFactory.getStandardListObjectInspector(
225                  PrimitiveObjectInspectorFactory.writableStringObjectInspector));
226        foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
227        ArrayList<String> fname = new ArrayList<String>();
228        fname.add("ngram");
229        fname.add("estfrequency");               
230        return ObjectInspectorFactory.getStandardListObjectInspector(
231                 ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi) );
232      }
233    }
234
235    @Override
236    public void merge(AggregationBuffer agg, Object obj) throws HiveException {
237      if(obj == null) { 
238        return;
239      }
240      NGramAggBuf myagg = (NGramAggBuf) agg;
241      List<Text> partial = (List<Text>) loi.getList(obj);
242
243      // remove the context words from the end of the list
244      int contextSize = Integer.parseInt( ((Text)partial.get(partial.size()-1)).toString() );
245      partial.remove(partial.size()-1);
246      if(myagg.context.size() > 0)  {
247        if(contextSize != myagg.context.size()) {
248          throw new HiveException(getClass().getSimpleName() + ": found a mismatch in the" +
249              " context string lengths. This is usually caused by passing a non-constant" +
250              " expression for the context.");
251        }
252      } else {
253        for(int i = partial.size()-contextSize; i < partial.size(); i++) {
254          String word = partial.get(i).toString();
255          if(word.equals("")) {
256            myagg.context.add( null );
257          } else {
258            myagg.context.add( word );
259          } 
260        }
261        partial.subList(partial.size()-contextSize, partial.size()).clear();
262        myagg.nge.merge(partial);
263      }
264    }
265
266    @Override
267    public Object terminatePartial(AggregationBuffer agg) throws HiveException {
268      NGramAggBuf myagg = (NGramAggBuf) agg;
269      ArrayList<Text> result = myagg.nge.serialize();
270
271      // push the context on to the end of the serialized n-gram estimation
272      for(int i = 0; i < myagg.context.size(); i++) {
273        if(myagg.context.get(i) == null) {
274          result.add(new Text(""));
275        } else {
276          result.add(new Text(myagg.context.get(i)));
277        }
278      }
279      result.add(new Text(Integer.toString(myagg.context.size())));
280
281      return result;
282    }
283
284    // Finds all contextual n-grams in a sequence of words, and passes the n-grams to the
285    // n-gram estimator object
286    private void processNgrams(NGramAggBuf agg, ArrayList<String> seq) throws HiveException {
287      // generate n-grams wherever the context matches
288      assert(agg.context.size() > 0);
289      ArrayList<String> ng = new ArrayList<String>();
290      for(int i = seq.size() - agg.context.size(); i >= 0; i--) {
291        // check if the context matches
292        boolean contextMatches = true;
293        ng.clear();
294        for(int j = 0; j < agg.context.size(); j++) {
295          String contextWord = agg.context.get(j);
296          if(contextWord == null) {
297            ng.add(seq.get(i+j));
298          } else {
299            if(!contextWord.equals(seq.get(i+j))) {
300              contextMatches = false;
301              break;
302            }
303          }
304        }
305
306        // add to n-gram estimation only if the context matches
307        if(contextMatches) {
308          agg.nge.add(ng);
309          ng = new ArrayList<String>();
310        }
311      }
312    }
313
314    @Override
315    public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
316      assert (parameters.length == 3 || parameters.length == 4);
317      if(parameters[0] == null || parameters[1] == null || parameters[2] == null) {
318        return;
319      }
320      NGramAggBuf myagg = (NGramAggBuf) agg;
321    
322      // Parse out the context and 'k' if we haven't already done so, and while we're at it,
323      // also parse out the precision factor 'pf' if the user has supplied one.
324      if(!myagg.nge.isInitialized()) {
325        int k = PrimitiveObjectInspectorUtils.getInt(parameters[2], kOI);
326        int pf = 0;
327        if(k < 1) {
328          throw new HiveException(getClass().getSimpleName() + " needs 'k' to be at least 1, "
329                                  + "but you supplied " + k);
330        }
331        if(parameters.length == 4) {
332          pf = PrimitiveObjectInspectorUtils.getInt(parameters[3], pOI);
333          if(pf < 1) {
334            throw new HiveException(getClass().getSimpleName() + " needs 'pf' to be at least 1, "
335                + "but you supplied " + pf);
336          }
337        } else {
338          pf = 1; // placeholder; minimum pf value is enforced in NGramEstimator
339        }
340
341        // Parse out the context and make sure it isn't empty
342        myagg.context.clear();
343        List<Text> context = (List<Text>) contextListOI.getList(parameters[1]);
344        int contextNulls = 0;
345        for(int i = 0; i < context.size(); i++) {
346          String word = PrimitiveObjectInspectorUtils.getString(context.get(i), contextOI);
347          if(word == null) {
348            contextNulls++;
349          }
350          myagg.context.add(word);
351        }
352        if(context.size() == 0) {
353          throw new HiveException(getClass().getSimpleName() + " needs a context array " +
354            "with at least one element.");
355        }
356        if(contextNulls == 0) {
357          throw new HiveException(getClass().getSimpleName() + " the context array needs to " +
358            "contain at least one 'null' value to indicate what should be counted.");
359        }
360
361        // Set parameters in the n-gram estimator object
362        myagg.nge.initialize(k, pf, contextNulls);
363      }
364
365      // get the input expression
366      List<Text> outer = (List<Text>) outerInputOI.getList(parameters[0]);
367      if(innerInputOI != null) {
368        // we're dealing with an array of arrays of strings
369        for(int i = 0; i < outer.size(); i++) {
370          List<Text> inner = (List<Text>) innerInputOI.getList(outer.get(i));
371          ArrayList<String> words = new ArrayList<String>();
372          for(int j = 0; j < inner.size(); j++) {
373            String word = PrimitiveObjectInspectorUtils.getString(inner.get(j), inputOI);
374            words.add(word);
375          }
376
377          // parse out n-grams, update frequency counts
378          processNgrams(myagg, words);
379        } 
380      } else {
381        // we're dealing with an array of strings
382        ArrayList<String> words = new ArrayList<String>();
383        for(int i = 0; i < outer.size(); i++) {
384          String word = PrimitiveObjectInspectorUtils.getString(outer.get(i), inputOI);
385          words.add(word);
386        }
387
388        // parse out n-grams, update frequency counts
389        processNgrams(myagg, words);
390      }
391    }
392
393    @Override
394    public Object terminate(AggregationBuffer agg) throws HiveException {
395      NGramAggBuf myagg = (NGramAggBuf) agg;
396      return myagg.nge.getNGrams();
397    }
398
399
400    // Aggregation buffer methods. 
401    static class NGramAggBuf implements AggregationBuffer {
402      ArrayList<String> context;
403      NGramEstimator nge;
404    };
405
406    @Override
407    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
408      NGramAggBuf result = new NGramAggBuf();
409      result.nge = new NGramEstimator();
410      result.context = new ArrayList<String>();
411      reset(result);
412      return result;
413    }
414
415    @Override
416    public void reset(AggregationBuffer agg) throws HiveException {
417      NGramAggBuf result = (NGramAggBuf) agg;
418      result.context.clear();
419      result.nge.reset();
420    }
421  }
422}