PageRenderTime 77ms CodeModel.GetById 26ms RepoModel.GetById 0ms app.codeStats 0ms

/tags/release-0.1-rc2/hive/external/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFContextNGrams.java

#
Java | 422 lines | 322 code | 37 blank | 63 comment | 82 complexity | db2650b1b1263b15aab93fee02d3be21 MD5 | raw file
Possible License(s): Apache-2.0, BSD-3-Clause, JSON, CPL-1.0
  1. /**
  2. * Licensed to the Apache Software Foundation (ASF) under one
  3. * or more contributor license agreements. See the NOTICE file
  4. * distributed with this work for additional information
  5. * regarding copyright ownership. The ASF licenses this file
  6. * to you under the Apache License, Version 2.0 (the
  7. * "License"); you may not use this file except in compliance
  8. * with the License. You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. package org.apache.hadoop.hive.ql.udf.generic;
  19. import java.util.ArrayList;
  20. import java.util.List;
  21. import java.util.Iterator;
  22. import java.util.Set;
  23. import java.util.Map;
  24. import java.util.Collections;
  25. import org.apache.commons.logging.Log;
  26. import org.apache.commons.logging.LogFactory;
  27. import org.apache.hadoop.hive.ql.exec.Description;
  28. import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
  29. import org.apache.hadoop.hive.ql.metadata.HiveException;
  30. import org.apache.hadoop.hive.ql.parse.SemanticException;
  31. import org.apache.hadoop.hive.serde2.io.DoubleWritable;
  32. import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
  33. import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
  34. import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
  35. import org.apache.hadoop.hive.serde2.objectinspector.StandardMapObjectInspector;
  36. import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;
  37. import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
  38. import org.apache.hadoop.hive.serde2.objectinspector.StructField;
  39. import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
  40. import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableDoubleObjectInspector;
  41. import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
  42. import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
  43. import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
  44. import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
  45. import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
  46. import org.apache.hadoop.util.StringUtils;
  47. import org.apache.hadoop.io.Text;
  48. /**
  49. * Estimates the top-k contextual n-grams in arbitrary sequential data using a heuristic.
  50. */
  51. @Description(name = "context_ngrams",
  52. value = "_FUNC_(expr, array<string1, string2, ...>, k, pf) estimates the top-k most " +
  53. "frequent n-grams that fit into the specified context. The second parameter specifies " +
  54. "a string of words that specify the positions of the n-gram elements, with a null value " +
  55. "standing in for a 'blank' that must be filled by an n-gram element.",
  56. extended = "The primary expression must be an array of strings, or an array of arrays of " +
  57. "strings, such as the return type of the sentences() UDF. The second parameter specifies " +
  58. "the context -- for example, array(\"i\", \"love\", null) -- which would estimate the top " +
  59. "'k' words that follow the phrase \"i love\" in the primary expression. The optional " +
  60. "fourth parameter 'pf' controls the memory used by the heuristic. Larger values will " +
  61. "yield better accuracy, but use more memory. Example usage:\n" +
  62. " SELECT context_ngrams(sentences(lower(review)), array(\"i\", \"love\", null, null), 10)" +
  63. " FROM movies\n" +
  64. "would attempt to determine the 10 most common two-word phrases that follow \"i love\" " +
  65. "in a database of free-form natural language movie reviews.")
  66. public class GenericUDAFContextNGrams implements GenericUDAFResolver {
  67. static final Log LOG = LogFactory.getLog(GenericUDAFContextNGrams.class.getName());
  68. @Override
  69. public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
  70. if (parameters.length != 3 && parameters.length != 4) {
  71. throw new UDFArgumentTypeException(parameters.length-1,
  72. "Please specify either three or four arguments.");
  73. }
  74. // Validate the first parameter, which is the expression to compute over. This should be an
  75. // array of strings type, or an array of arrays of strings.
  76. PrimitiveTypeInfo pti;
  77. if (parameters[0].getCategory() != ObjectInspector.Category.LIST) {
  78. throw new UDFArgumentTypeException(0,
  79. "Only list type arguments are accepted but "
  80. + parameters[0].getTypeName() + " was passed as parameter 1.");
  81. }
  82. switch (((ListTypeInfo) parameters[0]).getListElementTypeInfo().getCategory()) {
  83. case PRIMITIVE:
  84. // Parameter 1 was an array of primitives, so make sure the primitives are strings.
  85. pti = (PrimitiveTypeInfo) ((ListTypeInfo) parameters[0]).getListElementTypeInfo();
  86. break;
  87. case LIST:
  88. // Parameter 1 was an array of arrays, so make sure that the inner arrays contain
  89. // primitive strings.
  90. ListTypeInfo lti = (ListTypeInfo)
  91. ((ListTypeInfo) parameters[0]).getListElementTypeInfo();
  92. pti = (PrimitiveTypeInfo) lti.getListElementTypeInfo();
  93. break;
  94. default:
  95. throw new UDFArgumentTypeException(0,
  96. "Only arrays of strings or arrays of arrays of strings are accepted but "
  97. + parameters[0].getTypeName() + " was passed as parameter 1.");
  98. }
  99. if(pti.getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.STRING) {
  100. throw new UDFArgumentTypeException(0,
  101. "Only array<string> or array<array<string>> is allowed, but "
  102. + parameters[0].getTypeName() + " was passed as parameter 1.");
  103. }
  104. // Validate the second parameter, which should be an array of strings
  105. if(parameters[1].getCategory() != ObjectInspector.Category.LIST ||
  106. ((ListTypeInfo) parameters[1]).getListElementTypeInfo().getCategory() !=
  107. ObjectInspector.Category.PRIMITIVE) {
  108. throw new UDFArgumentTypeException(1, "Only arrays of strings are accepted but "
  109. + parameters[1].getTypeName() + " was passed as parameter 2.");
  110. }
  111. if(((PrimitiveTypeInfo) ((ListTypeInfo)parameters[1]).getListElementTypeInfo()).
  112. getPrimitiveCategory() != PrimitiveObjectInspector.PrimitiveCategory.STRING) {
  113. throw new UDFArgumentTypeException(1, "Only arrays of strings are accepted but "
  114. + parameters[1].getTypeName() + " was passed as parameter 2.");
  115. }
  116. // Validate the third parameter, which should be an integer to represent 'k'
  117. if(parameters[2].getCategory() != ObjectInspector.Category.PRIMITIVE) {
  118. throw new UDFArgumentTypeException(2, "Only integers are accepted but "
  119. + parameters[2].getTypeName() + " was passed as parameter 3.");
  120. }
  121. switch(((PrimitiveTypeInfo) parameters[2]).getPrimitiveCategory()) {
  122. case BYTE:
  123. case SHORT:
  124. case INT:
  125. case LONG:
  126. break;
  127. default:
  128. throw new UDFArgumentTypeException(2, "Only integers are accepted but "
  129. + parameters[2].getTypeName() + " was passed as parameter 3.");
  130. }
  131. // If the fourth parameter -- precision factor 'pf' -- has been specified, make sure it's
  132. // an integer.
  133. if(parameters.length == 4) {
  134. if(parameters[3].getCategory() != ObjectInspector.Category.PRIMITIVE) {
  135. throw new UDFArgumentTypeException(3, "Only integers are accepted but "
  136. + parameters[3].getTypeName() + " was passed as parameter 4.");
  137. }
  138. switch(((PrimitiveTypeInfo) parameters[3]).getPrimitiveCategory()) {
  139. case BYTE:
  140. case SHORT:
  141. case INT:
  142. case LONG:
  143. break;
  144. default:
  145. throw new UDFArgumentTypeException(3, "Only integers are accepted but "
  146. + parameters[3].getTypeName() + " was passed as parameter 4.");
  147. }
  148. }
  149. return new GenericUDAFContextNGramEvaluator();
  150. }
  151. /**
  152. * A constant-space heuristic to estimate the top-k contextual n-grams.
  153. */
  154. public static class GenericUDAFContextNGramEvaluator extends GenericUDAFEvaluator {
  155. // For PARTIAL1 and COMPLETE: ObjectInspectors for original data
  156. private StandardListObjectInspector outerInputOI;
  157. private StandardListObjectInspector innerInputOI;
  158. private StandardListObjectInspector contextListOI;
  159. private PrimitiveObjectInspector contextOI;
  160. private PrimitiveObjectInspector inputOI;
  161. private PrimitiveObjectInspector kOI;
  162. private PrimitiveObjectInspector pOI;
  163. // For PARTIAL2 and FINAL: ObjectInspectors for partial aggregations
  164. private StandardListObjectInspector loi;
  165. @Override
  166. public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
  167. super.init(m, parameters);
  168. // Init input object inspectors
  169. if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
  170. outerInputOI = (StandardListObjectInspector) parameters[0];
  171. if(outerInputOI.getListElementObjectInspector().getCategory() ==
  172. ObjectInspector.Category.LIST) {
  173. // We're dealing with input that is an array of arrays of strings
  174. innerInputOI = (StandardListObjectInspector) outerInputOI.getListElementObjectInspector();
  175. inputOI = (PrimitiveObjectInspector) innerInputOI.getListElementObjectInspector();
  176. } else {
  177. // We're dealing with input that is an array of strings
  178. inputOI = (PrimitiveObjectInspector) outerInputOI.getListElementObjectInspector();
  179. innerInputOI = null;
  180. }
  181. contextListOI = (StandardListObjectInspector) parameters[1];
  182. contextOI = (PrimitiveObjectInspector) contextListOI.getListElementObjectInspector();
  183. kOI = (PrimitiveObjectInspector) parameters[2];
  184. if(parameters.length == 4) {
  185. pOI = (PrimitiveObjectInspector) parameters[3];
  186. } else {
  187. pOI = null;
  188. }
  189. } else {
  190. // Init the list object inspector for handling partial aggregations
  191. loi = (StandardListObjectInspector) parameters[0];
  192. }
  193. // Init output object inspectors.
  194. //
  195. // The return type for a partial aggregation is still a list of strings.
  196. //
  197. // The return type for FINAL and COMPLETE is a full aggregation result, which is
  198. // an array of structures containing the n-gram and its estimated frequency.
  199. if (m == Mode.PARTIAL1 || m == Mode.PARTIAL2) {
  200. return ObjectInspectorFactory.getStandardListObjectInspector(
  201. PrimitiveObjectInspectorFactory.writableStringObjectInspector);
  202. } else {
  203. // Final return type that goes back to Hive: a list of structs with n-grams and their
  204. // estimated frequencies.
  205. ArrayList<ObjectInspector> foi = new ArrayList<ObjectInspector>();
  206. foi.add(ObjectInspectorFactory.getStandardListObjectInspector(
  207. PrimitiveObjectInspectorFactory.writableStringObjectInspector));
  208. foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
  209. ArrayList<String> fname = new ArrayList<String>();
  210. fname.add("ngram");
  211. fname.add("estfrequency");
  212. return ObjectInspectorFactory.getStandardListObjectInspector(
  213. ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi) );
  214. }
  215. }
  216. @Override
  217. public void merge(AggregationBuffer agg, Object obj) throws HiveException {
  218. if(obj == null) {
  219. return;
  220. }
  221. NGramAggBuf myagg = (NGramAggBuf) agg;
  222. List<Text> partial = (List<Text>) loi.getList(obj);
  223. // remove the context words from the end of the list
  224. int contextSize = Integer.parseInt( ((Text)partial.get(partial.size()-1)).toString() );
  225. partial.remove(partial.size()-1);
  226. if(myagg.context.size() > 0) {
  227. if(contextSize != myagg.context.size()) {
  228. throw new HiveException(getClass().getSimpleName() + ": found a mismatch in the" +
  229. " context string lengths. This is usually caused by passing a non-constant" +
  230. " expression for the context.");
  231. }
  232. } else {
  233. for(int i = partial.size()-contextSize; i < partial.size(); i++) {
  234. String word = partial.get(i).toString();
  235. if(word.equals("")) {
  236. myagg.context.add( null );
  237. } else {
  238. myagg.context.add( word );
  239. }
  240. }
  241. partial.subList(partial.size()-contextSize, partial.size()).clear();
  242. myagg.nge.merge(partial);
  243. }
  244. }
  245. @Override
  246. public Object terminatePartial(AggregationBuffer agg) throws HiveException {
  247. NGramAggBuf myagg = (NGramAggBuf) agg;
  248. ArrayList<Text> result = myagg.nge.serialize();
  249. // push the context on to the end of the serialized n-gram estimation
  250. for(int i = 0; i < myagg.context.size(); i++) {
  251. if(myagg.context.get(i) == null) {
  252. result.add(new Text(""));
  253. } else {
  254. result.add(new Text(myagg.context.get(i)));
  255. }
  256. }
  257. result.add(new Text(Integer.toString(myagg.context.size())));
  258. return result;
  259. }
  260. // Finds all contextual n-grams in a sequence of words, and passes the n-grams to the
  261. // n-gram estimator object
  262. private void processNgrams(NGramAggBuf agg, ArrayList<String> seq) throws HiveException {
  263. // generate n-grams wherever the context matches
  264. assert(agg.context.size() > 0);
  265. ArrayList<String> ng = new ArrayList<String>();
  266. for(int i = seq.size() - agg.context.size(); i >= 0; i--) {
  267. // check if the context matches
  268. boolean contextMatches = true;
  269. ng.clear();
  270. for(int j = 0; j < agg.context.size(); j++) {
  271. String contextWord = agg.context.get(j);
  272. if(contextWord == null) {
  273. ng.add(seq.get(i+j));
  274. } else {
  275. if(!contextWord.equals(seq.get(i+j))) {
  276. contextMatches = false;
  277. break;
  278. }
  279. }
  280. }
  281. // add to n-gram estimation only if the context matches
  282. if(contextMatches) {
  283. agg.nge.add(ng);
  284. ng = new ArrayList<String>();
  285. }
  286. }
  287. }
  288. @Override
  289. public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
  290. assert (parameters.length == 3 || parameters.length == 4);
  291. if(parameters[0] == null || parameters[1] == null || parameters[2] == null) {
  292. return;
  293. }
  294. NGramAggBuf myagg = (NGramAggBuf) agg;
  295. // Parse out the context and 'k' if we haven't already done so, and while we're at it,
  296. // also parse out the precision factor 'pf' if the user has supplied one.
  297. if(!myagg.nge.isInitialized()) {
  298. int k = PrimitiveObjectInspectorUtils.getInt(parameters[2], kOI);
  299. int pf = 0;
  300. if(k < 1) {
  301. throw new HiveException(getClass().getSimpleName() + " needs 'k' to be at least 1, "
  302. + "but you supplied " + k);
  303. }
  304. if(parameters.length == 4) {
  305. pf = PrimitiveObjectInspectorUtils.getInt(parameters[3], pOI);
  306. if(pf < 1) {
  307. throw new HiveException(getClass().getSimpleName() + " needs 'pf' to be at least 1, "
  308. + "but you supplied " + pf);
  309. }
  310. } else {
  311. pf = 1; // placeholder; minimum pf value is enforced in NGramEstimator
  312. }
  313. // Parse out the context and make sure it isn't empty
  314. myagg.context.clear();
  315. List<Text> context = (List<Text>) contextListOI.getList(parameters[1]);
  316. int contextNulls = 0;
  317. for(int i = 0; i < context.size(); i++) {
  318. String word = PrimitiveObjectInspectorUtils.getString(context.get(i), contextOI);
  319. if(word == null) {
  320. contextNulls++;
  321. }
  322. myagg.context.add(word);
  323. }
  324. if(context.size() == 0) {
  325. throw new HiveException(getClass().getSimpleName() + " needs a context array " +
  326. "with at least one element.");
  327. }
  328. if(contextNulls == 0) {
  329. throw new HiveException(getClass().getSimpleName() + " the context array needs to " +
  330. "contain at least one 'null' value to indicate what should be counted.");
  331. }
  332. // Set parameters in the n-gram estimator object
  333. myagg.nge.initialize(k, pf, contextNulls);
  334. }
  335. // get the input expression
  336. List<Text> outer = (List<Text>) outerInputOI.getList(parameters[0]);
  337. if(innerInputOI != null) {
  338. // we're dealing with an array of arrays of strings
  339. for(int i = 0; i < outer.size(); i++) {
  340. List<Text> inner = (List<Text>) innerInputOI.getList(outer.get(i));
  341. ArrayList<String> words = new ArrayList<String>();
  342. for(int j = 0; j < inner.size(); j++) {
  343. String word = PrimitiveObjectInspectorUtils.getString(inner.get(j), inputOI);
  344. words.add(word);
  345. }
  346. // parse out n-grams, update frequency counts
  347. processNgrams(myagg, words);
  348. }
  349. } else {
  350. // we're dealing with an array of strings
  351. ArrayList<String> words = new ArrayList<String>();
  352. for(int i = 0; i < outer.size(); i++) {
  353. String word = PrimitiveObjectInspectorUtils.getString(outer.get(i), inputOI);
  354. words.add(word);
  355. }
  356. // parse out n-grams, update frequency counts
  357. processNgrams(myagg, words);
  358. }
  359. }
  360. @Override
  361. public Object terminate(AggregationBuffer agg) throws HiveException {
  362. NGramAggBuf myagg = (NGramAggBuf) agg;
  363. return myagg.nge.getNGrams();
  364. }
  365. // Aggregation buffer methods.
  366. static class NGramAggBuf implements AggregationBuffer {
  367. ArrayList<String> context;
  368. NGramEstimator nge;
  369. };
  370. @Override
  371. public AggregationBuffer getNewAggregationBuffer() throws HiveException {
  372. NGramAggBuf result = new NGramAggBuf();
  373. result.nge = new NGramEstimator();
  374. result.context = new ArrayList<String>();
  375. reset(result);
  376. return result;
  377. }
  378. @Override
  379. public void reset(AggregationBuffer agg) throws HiveException {
  380. NGramAggBuf result = (NGramAggBuf) agg;
  381. result.context.clear();
  382. result.nge.reset();
  383. }
  384. }
  385. }