PageRenderTime 57ms CodeModel.GetById 21ms RepoModel.GetById 0ms app.codeStats 0ms

/tags/release-0.1-rc2/hive/external/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCovariance.java

#
Java | 335 lines | 231 code | 41 blank | 63 comment | 37 complexity | f7aca57ae2ddf04bc10f998d770b9427 MD5 | raw file
Possible License(s): Apache-2.0, BSD-3-Clause, JSON, CPL-1.0
  1. /**
  2. * Licensed to the Apache Software Foundation (ASF) under one
  3. * or more contributor license agreements. See the NOTICE file
  4. * distributed with this work for additional information
  5. * regarding copyright ownership. The ASF licenses this file
  6. * to you under the Apache License, Version 2.0 (the
  7. * "License"); you may not use this file except in compliance
  8. * with the License. You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. package org.apache.hadoop.hive.ql.udf.generic;
  19. import java.util.ArrayList;
  20. import org.apache.commons.logging.Log;
  21. import org.apache.commons.logging.LogFactory;
  22. import org.apache.hadoop.hive.ql.exec.Description;
  23. import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
  24. import org.apache.hadoop.hive.ql.metadata.HiveException;
  25. import org.apache.hadoop.hive.ql.parse.SemanticException;
  26. import org.apache.hadoop.hive.serde2.io.DoubleWritable;
  27. import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
  28. import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
  29. import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
  30. import org.apache.hadoop.hive.serde2.objectinspector.StructField;
  31. import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
  32. import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
  33. import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
  34. import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
  35. import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
  36. import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
  37. import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
  38. import org.apache.hadoop.io.LongWritable;
  39. import org.apache.hadoop.util.StringUtils;
  40. /**
  41. * Compute the covariance covar_pop(x, y), using the following one-pass method
  42. * (ref. "Formulas for Robust, One-Pass Parallel Computation of Covariances and
  43. * Arbitrary-Order Statistical Moments", Philippe Pebay, Sandia Labs):
  44. *
  45. * Incremental:
  46. * n : <count>
  47. * mx_n = mx_(n-1) + [x_n - mx_(n-1)]/n : <xavg>
  48. * my_n = my_(n-1) + [y_n - my_(n-1)]/n : <yavg>
  49. * c_n = c_(n-1) + (x_n - mx_(n-1))*(y_n - my_n) : <covariance * n>
  50. *
  51. * Merge:
  52. * c_X = c_A + c_B + (mx_A - mx_B)*(my_A - my_B)*n_A*n_B/n_X
  53. *
  54. */
  55. @Description(name = "covariance,covar_pop",
  56. value = "_FUNC_(x,y) - Returns the population covariance of a set of number pairs",
  57. extended = "The function takes as arguments any pair of numeric types and returns a double.\n"
  58. + "Any pair with a NULL is ignored. If the function is applied to an empty set, NULL\n"
  59. + "will be returned. Otherwise, it computes the following:\n"
  60. + " (SUM(x*y)-SUM(x)*SUM(y)/COUNT(x,y))/COUNT(x,y)\n"
  61. + "where neither x nor y is null.")
  62. public class GenericUDAFCovariance extends AbstractGenericUDAFResolver {
  63. static final Log LOG = LogFactory.getLog(GenericUDAFCovariance.class.getName());
  64. @Override
  65. public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
  66. if (parameters.length != 2) {
  67. throw new UDFArgumentTypeException(parameters.length - 1,
  68. "Exactly two arguments are expected.");
  69. }
  70. if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
  71. throw new UDFArgumentTypeException(0,
  72. "Only primitive type arguments are accepted but "
  73. + parameters[0].getTypeName() + " is passed.");
  74. }
  75. if (parameters[1].getCategory() != ObjectInspector.Category.PRIMITIVE) {
  76. throw new UDFArgumentTypeException(1,
  77. "Only primitive type arguments are accepted but "
  78. + parameters[1].getTypeName() + " is passed.");
  79. }
  80. switch (((PrimitiveTypeInfo) parameters[0]).getPrimitiveCategory()) {
  81. case BYTE:
  82. case SHORT:
  83. case INT:
  84. case LONG:
  85. case FLOAT:
  86. case DOUBLE:
  87. switch (((PrimitiveTypeInfo) parameters[1]).getPrimitiveCategory()) {
  88. case BYTE:
  89. case SHORT:
  90. case INT:
  91. case LONG:
  92. case FLOAT:
  93. case DOUBLE:
  94. return new GenericUDAFCovarianceEvaluator();
  95. case STRING:
  96. case BOOLEAN:
  97. default:
  98. throw new UDFArgumentTypeException(1,
  99. "Only numeric or string type arguments are accepted but "
  100. + parameters[1].getTypeName() + " is passed.");
  101. }
  102. case STRING:
  103. case BOOLEAN:
  104. default:
  105. throw new UDFArgumentTypeException(0,
  106. "Only numeric or string type arguments are accepted but "
  107. + parameters[0].getTypeName() + " is passed.");
  108. }
  109. }
  110. /**
  111. * Evaluate the variance using the algorithm described in
  112. * http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance,
  113. * presumably by PĂŠbay, Philippe (2008), in "Formulas for Robust,
  114. * One-Pass Parallel Computation of Covariances and Arbitrary-Order
  115. * Statistical Moments", Technical Report SAND2008-6212,
  116. * Sandia National Laboratories,
  117. * http://infoserve.sandia.gov/sand_doc/2008/086212.pdf
  118. *
  119. * Incremental:
  120. * n : <count>
  121. * mx_n = mx_(n-1) + [x_n - mx_(n-1)]/n : <xavg>
  122. * my_n = my_(n-1) + [y_n - my_(n-1)]/n : <yavg>
  123. * c_n = c_(n-1) + (x_n - mx_(n-1))*(y_n - my_n) : <covariance * n>
  124. *
  125. * Merge:
  126. * c_X = c_A + c_B + (mx_A - mx_B)*(my_A - my_B)*n_A*n_B/n_X
  127. *
  128. * This one-pass algorithm is stable.
  129. *
  130. */
  131. public static class GenericUDAFCovarianceEvaluator extends GenericUDAFEvaluator {
  132. // For PARTIAL1 and COMPLETE
  133. private PrimitiveObjectInspector xInputOI;
  134. private PrimitiveObjectInspector yInputOI;
  135. // For PARTIAL2 and FINAL
  136. private StructObjectInspector soi;
  137. private StructField countField;
  138. private StructField xavgField;
  139. private StructField yavgField;
  140. private StructField covarField;
  141. private LongObjectInspector countFieldOI;
  142. private DoubleObjectInspector xavgFieldOI;
  143. private DoubleObjectInspector yavgFieldOI;
  144. private DoubleObjectInspector covarFieldOI;
  145. // For PARTIAL1 and PARTIAL2
  146. private Object[] partialResult;
  147. // For FINAL and COMPLETE
  148. private DoubleWritable result;
  149. @Override
  150. public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
  151. super.init(m, parameters);
  152. // init input
  153. if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {
  154. assert (parameters.length == 2);
  155. xInputOI = (PrimitiveObjectInspector) parameters[0];
  156. yInputOI = (PrimitiveObjectInspector) parameters[1];
  157. } else {
  158. assert (parameters.length == 1);
  159. soi = (StructObjectInspector) parameters[0];
  160. countField = soi.getStructFieldRef("count");
  161. xavgField = soi.getStructFieldRef("xavg");
  162. yavgField = soi.getStructFieldRef("yavg");
  163. covarField = soi.getStructFieldRef("covar");
  164. countFieldOI =
  165. (LongObjectInspector) countField.getFieldObjectInspector();
  166. xavgFieldOI =
  167. (DoubleObjectInspector) xavgField.getFieldObjectInspector();
  168. yavgFieldOI =
  169. (DoubleObjectInspector) yavgField.getFieldObjectInspector();
  170. covarFieldOI =
  171. (DoubleObjectInspector) covarField.getFieldObjectInspector();
  172. }
  173. // init output
  174. if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {
  175. // The output of a partial aggregation is a struct containing
  176. // a long count, two double averages, and a double covariance.
  177. ArrayList<ObjectInspector> foi = new ArrayList<ObjectInspector>();
  178. foi.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
  179. foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
  180. foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
  181. foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
  182. ArrayList<String> fname = new ArrayList<String>();
  183. fname.add("count");
  184. fname.add("xavg");
  185. fname.add("yavg");
  186. fname.add("covar");
  187. partialResult = new Object[4];
  188. partialResult[0] = new LongWritable(0);
  189. partialResult[1] = new DoubleWritable(0);
  190. partialResult[2] = new DoubleWritable(0);
  191. partialResult[3] = new DoubleWritable(0);
  192. return ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi);
  193. } else {
  194. setResult(new DoubleWritable(0));
  195. return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
  196. }
  197. }
  198. static class StdAgg implements AggregationBuffer {
  199. long count; // number n of elements
  200. double xavg; // average of x elements
  201. double yavg; // average of y elements
  202. double covar; // n times the covariance
  203. };
  204. @Override
  205. public AggregationBuffer getNewAggregationBuffer() throws HiveException {
  206. StdAgg result = new StdAgg();
  207. reset(result);
  208. return result;
  209. }
  210. @Override
  211. public void reset(AggregationBuffer agg) throws HiveException {
  212. StdAgg myagg = (StdAgg) agg;
  213. myagg.count = 0;
  214. myagg.xavg = 0;
  215. myagg.yavg = 0;
  216. myagg.covar = 0;
  217. }
  218. private boolean warned = false;
  219. @Override
  220. public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
  221. assert (parameters.length == 2);
  222. Object px = parameters[0];
  223. Object py = parameters[1];
  224. if (px != null && py != null) {
  225. StdAgg myagg = (StdAgg) agg;
  226. double vx = PrimitiveObjectInspectorUtils.getDouble(px, xInputOI);
  227. double vy = PrimitiveObjectInspectorUtils.getDouble(py, yInputOI);
  228. myagg.count++;
  229. myagg.yavg = myagg.yavg + (vy - myagg.yavg) / myagg.count;
  230. if (myagg.count > 1) {
  231. myagg.covar += (vx - myagg.xavg) * (vy - myagg.yavg);
  232. }
  233. myagg.xavg = myagg.xavg + (vx - myagg.xavg) / myagg.count;
  234. }
  235. }
  236. @Override
  237. public Object terminatePartial(AggregationBuffer agg) throws HiveException {
  238. StdAgg myagg = (StdAgg) agg;
  239. ((LongWritable) partialResult[0]).set(myagg.count);
  240. ((DoubleWritable) partialResult[1]).set(myagg.xavg);
  241. ((DoubleWritable) partialResult[2]).set(myagg.yavg);
  242. ((DoubleWritable) partialResult[3]).set(myagg.covar);
  243. return partialResult;
  244. }
  245. @Override
  246. public void merge(AggregationBuffer agg, Object partial) throws HiveException {
  247. if (partial != null) {
  248. StdAgg myagg = (StdAgg) agg;
  249. Object partialCount = soi.getStructFieldData(partial, countField);
  250. Object partialXAvg = soi.getStructFieldData(partial, xavgField);
  251. Object partialYAvg = soi.getStructFieldData(partial, yavgField);
  252. Object partialCovar = soi.getStructFieldData(partial, covarField);
  253. long nA = myagg.count;
  254. long nB = countFieldOI.get(partialCount);
  255. if (nA == 0) {
  256. // Just copy the information since there is nothing so far
  257. myagg.count = countFieldOI.get(partialCount);
  258. myagg.xavg = xavgFieldOI.get(partialXAvg);
  259. myagg.yavg = yavgFieldOI.get(partialYAvg);
  260. myagg.covar = covarFieldOI.get(partialCovar);
  261. }
  262. if (nA != 0 && nB != 0) {
  263. // Merge the two partials
  264. double xavgA = myagg.xavg;
  265. double yavgA = myagg.yavg;
  266. double xavgB = xavgFieldOI.get(partialXAvg);
  267. double yavgB = yavgFieldOI.get(partialYAvg);
  268. double covarB = covarFieldOI.get(partialCovar);
  269. myagg.count += nB;
  270. myagg.xavg = (xavgA * nA + xavgB * nB) / myagg.count;
  271. myagg.yavg = (yavgA * nA + yavgB * nB) / myagg.count;
  272. myagg.covar +=
  273. covarB + (xavgA - xavgB) * (yavgA - yavgB) * ((double) (nA * nB) / myagg.count);
  274. }
  275. }
  276. }
  277. @Override
  278. public Object terminate(AggregationBuffer agg) throws HiveException {
  279. StdAgg myagg = (StdAgg) agg;
  280. if (myagg.count == 0) { // SQL standard - return null for zero elements
  281. return null;
  282. } else {
  283. getResult().set(myagg.covar / (myagg.count));
  284. return getResult();
  285. }
  286. }
  287. public void setResult(DoubleWritable result) {
  288. this.result = result;
  289. }
  290. public DoubleWritable getResult() {
  291. return result;
  292. }
  293. }
  294. }