PageRenderTime 33ms CodeModel.GetById 12ms app.highlight 18ms RepoModel.GetById 0ms app.codeStats 1ms

/tags/release-0.1-rc2/hive/external/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFCovariance.java

#
Java | 335 lines | 231 code | 41 blank | 63 comment | 37 complexity | f7aca57ae2ddf04bc10f998d770b9427 MD5 | raw file
  1/**
  2 * Licensed to the Apache Software Foundation (ASF) under one
  3 * or more contributor license agreements.  See the NOTICE file
  4 * distributed with this work for additional information
  5 * regarding copyright ownership.  The ASF licenses this file
  6 * to you under the Apache License, Version 2.0 (the
  7 * "License"); you may not use this file except in compliance
  8 * with the License.  You may obtain a copy of the License at
  9 *
 10 *     http://www.apache.org/licenses/LICENSE-2.0
 11 *
 12 * Unless required by applicable law or agreed to in writing, software
 13 * distributed under the License is distributed on an "AS IS" BASIS,
 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 15 * See the License for the specific language governing permissions and
 16 * limitations under the License.
 17 */
 18package org.apache.hadoop.hive.ql.udf.generic;
 19
 20import java.util.ArrayList;
 21
 22import org.apache.commons.logging.Log;
 23import org.apache.commons.logging.LogFactory;
 24import org.apache.hadoop.hive.ql.exec.Description;
 25import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
 26import org.apache.hadoop.hive.ql.metadata.HiveException;
 27import org.apache.hadoop.hive.ql.parse.SemanticException;
 28import org.apache.hadoop.hive.serde2.io.DoubleWritable;
 29import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 30import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
 31import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
 32import org.apache.hadoop.hive.serde2.objectinspector.StructField;
 33import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
 34import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
 35import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
 36import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
 37import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
 38import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
 39import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
 40import org.apache.hadoop.io.LongWritable;
 41import org.apache.hadoop.util.StringUtils;
 42
 43/**
 44 * Compute the covariance covar_pop(x, y), using the following one-pass method
 45 * (ref. "Formulas for Robust, One-Pass Parallel Computation of Covariances and
 46 *  Arbitrary-Order Statistical Moments", Philippe Pebay, Sandia Labs):
 47 *
 48 *  Incremental:
 49 *   n : <count>
 50 *   mx_n = mx_(n-1) + [x_n - mx_(n-1)]/n : <xavg>
 51 *   my_n = my_(n-1) + [y_n - my_(n-1)]/n : <yavg>
 52 *   c_n = c_(n-1) + (x_n - mx_(n-1))*(y_n - my_n) : <covariance * n>
 53 *
 54 *  Merge:
 55 *   c_X = c_A + c_B + (mx_A - mx_B)*(my_A - my_B)*n_A*n_B/n_X
 56 *
 57 */
 58@Description(name = "covariance,covar_pop",
 59    value = "_FUNC_(x,y) - Returns the population covariance of a set of number pairs",
 60    extended = "The function takes as arguments any pair of numeric types and returns a double.\n"
 61        + "Any pair with a NULL is ignored. If the function is applied to an empty set, NULL\n"
 62        + "will be returned. Otherwise, it computes the following:\n"
 63        + "   (SUM(x*y)-SUM(x)*SUM(y)/COUNT(x,y))/COUNT(x,y)\n"
 64        + "where neither x nor y is null.")
 65public class GenericUDAFCovariance extends AbstractGenericUDAFResolver {
 66
 67  static final Log LOG = LogFactory.getLog(GenericUDAFCovariance.class.getName());
 68
 69  @Override
 70  public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
 71    if (parameters.length != 2) {
 72      throw new UDFArgumentTypeException(parameters.length - 1,
 73          "Exactly two arguments are expected.");
 74    }
 75
 76    if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
 77      throw new UDFArgumentTypeException(0,
 78          "Only primitive type arguments are accepted but "
 79          + parameters[0].getTypeName() + " is passed.");
 80    }
 81
 82    if (parameters[1].getCategory() != ObjectInspector.Category.PRIMITIVE) {
 83        throw new UDFArgumentTypeException(1,
 84            "Only primitive type arguments are accepted but "
 85            + parameters[1].getTypeName() + " is passed.");
 86    }
 87
 88    switch (((PrimitiveTypeInfo) parameters[0]).getPrimitiveCategory()) {
 89    case BYTE:
 90    case SHORT:
 91    case INT:
 92    case LONG:
 93    case FLOAT:
 94    case DOUBLE:
 95      switch (((PrimitiveTypeInfo) parameters[1]).getPrimitiveCategory()) {
 96      case BYTE:
 97      case SHORT:
 98      case INT:
 99      case LONG:
100      case FLOAT:
101      case DOUBLE:
102        return new GenericUDAFCovarianceEvaluator();
103      case STRING:
104      case BOOLEAN:
105      default:
106        throw new UDFArgumentTypeException(1,
107            "Only numeric or string type arguments are accepted but "
108            + parameters[1].getTypeName() + " is passed.");
109      }
110    case STRING:
111    case BOOLEAN:
112    default:
113      throw new UDFArgumentTypeException(0,
114          "Only numeric or string type arguments are accepted but "
115          + parameters[0].getTypeName() + " is passed.");
116    }
117  }
118
119  /**
120   * Evaluate the variance using the algorithm described in
121   * http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance,
122   * presumably by  PĂŠbay, Philippe (2008), in "Formulas for Robust,
123   * One-Pass Parallel Computation of Covariances and Arbitrary-Order
124   * Statistical Moments", Technical Report SAND2008-6212,
125   * Sandia National Laboratories,
126   * http://infoserve.sandia.gov/sand_doc/2008/086212.pdf
127   *
128   *  Incremental:
129   *   n : <count>
130   *   mx_n = mx_(n-1) + [x_n - mx_(n-1)]/n : <xavg>
131   *   my_n = my_(n-1) + [y_n - my_(n-1)]/n : <yavg>
132   *   c_n = c_(n-1) + (x_n - mx_(n-1))*(y_n - my_n) : <covariance * n>
133   *
134   *  Merge:
135   *   c_X = c_A + c_B + (mx_A - mx_B)*(my_A - my_B)*n_A*n_B/n_X
136   *
137   *  This one-pass algorithm is stable.
138   *
139   */
140  public static class GenericUDAFCovarianceEvaluator extends GenericUDAFEvaluator {
141
142    // For PARTIAL1 and COMPLETE
143    private PrimitiveObjectInspector xInputOI;
144    private PrimitiveObjectInspector yInputOI;
145
146    // For PARTIAL2 and FINAL
147    private StructObjectInspector soi;
148    private StructField countField;
149    private StructField xavgField;
150    private StructField yavgField;
151    private StructField covarField;
152    private LongObjectInspector countFieldOI;
153    private DoubleObjectInspector xavgFieldOI;
154    private DoubleObjectInspector yavgFieldOI;
155    private DoubleObjectInspector covarFieldOI;
156
157    // For PARTIAL1 and PARTIAL2
158    private Object[] partialResult;
159
160    // For FINAL and COMPLETE
161    private DoubleWritable result;
162
163    @Override
164    public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
165      super.init(m, parameters);
166
167      // init input
168      if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {
169        assert (parameters.length == 2);
170        xInputOI = (PrimitiveObjectInspector) parameters[0];
171        yInputOI = (PrimitiveObjectInspector) parameters[1];
172      } else {
173        assert (parameters.length == 1);
174        soi = (StructObjectInspector) parameters[0];
175
176        countField = soi.getStructFieldRef("count");
177        xavgField = soi.getStructFieldRef("xavg");
178        yavgField = soi.getStructFieldRef("yavg");
179        covarField = soi.getStructFieldRef("covar");
180
181        countFieldOI =
182            (LongObjectInspector) countField.getFieldObjectInspector();
183        xavgFieldOI =
184            (DoubleObjectInspector) xavgField.getFieldObjectInspector();
185        yavgFieldOI =
186            (DoubleObjectInspector) yavgField.getFieldObjectInspector();
187        covarFieldOI =
188            (DoubleObjectInspector) covarField.getFieldObjectInspector();
189      }
190
191      // init output
192      if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {
193        // The output of a partial aggregation is a struct containing
194        // a long count, two double averages, and a double covariance.
195
196        ArrayList<ObjectInspector> foi = new ArrayList<ObjectInspector>();
197
198        foi.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
199        foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
200        foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
201        foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
202
203        ArrayList<String> fname = new ArrayList<String>();
204        fname.add("count");
205        fname.add("xavg");
206        fname.add("yavg");
207        fname.add("covar");
208
209        partialResult = new Object[4];
210        partialResult[0] = new LongWritable(0);
211        partialResult[1] = new DoubleWritable(0);
212        partialResult[2] = new DoubleWritable(0);
213        partialResult[3] = new DoubleWritable(0);
214
215        return ObjectInspectorFactory.getStandardStructObjectInspector(fname, foi);
216
217      } else {
218        setResult(new DoubleWritable(0));
219        return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
220      }
221    }
222
223    static class StdAgg implements AggregationBuffer {
224      long count; // number n of elements
225      double xavg; // average of x elements
226      double yavg; // average of y elements
227      double covar; // n times the covariance
228    };
229
230    @Override
231    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
232      StdAgg result = new StdAgg();
233      reset(result);
234      return result;
235    }
236
237    @Override
238    public void reset(AggregationBuffer agg) throws HiveException {
239      StdAgg myagg = (StdAgg) agg;
240      myagg.count = 0;
241      myagg.xavg = 0;
242      myagg.yavg = 0;
243      myagg.covar = 0;
244    }
245
246    private boolean warned = false;
247
248    @Override
249    public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
250      assert (parameters.length == 2);
251      Object px = parameters[0];
252      Object py = parameters[1];
253      if (px != null && py != null) {
254        StdAgg myagg = (StdAgg) agg;
255        double vx = PrimitiveObjectInspectorUtils.getDouble(px, xInputOI);
256        double vy = PrimitiveObjectInspectorUtils.getDouble(py, yInputOI);
257        myagg.count++;
258        myagg.yavg = myagg.yavg + (vy - myagg.yavg) / myagg.count;
259        if (myagg.count > 1) {
260            myagg.covar += (vx - myagg.xavg) * (vy - myagg.yavg);
261        }
262        myagg.xavg = myagg.xavg + (vx - myagg.xavg) / myagg.count;
263      }
264    }
265
266    @Override
267    public Object terminatePartial(AggregationBuffer agg) throws HiveException {
268      StdAgg myagg = (StdAgg) agg;
269      ((LongWritable) partialResult[0]).set(myagg.count);
270      ((DoubleWritable) partialResult[1]).set(myagg.xavg);
271      ((DoubleWritable) partialResult[2]).set(myagg.yavg);
272      ((DoubleWritable) partialResult[3]).set(myagg.covar);
273      return partialResult;
274    }
275
276    @Override
277    public void merge(AggregationBuffer agg, Object partial) throws HiveException {
278      if (partial != null) {
279        StdAgg myagg = (StdAgg) agg;
280
281        Object partialCount = soi.getStructFieldData(partial, countField);
282        Object partialXAvg = soi.getStructFieldData(partial, xavgField);
283        Object partialYAvg = soi.getStructFieldData(partial, yavgField);
284        Object partialCovar = soi.getStructFieldData(partial, covarField);
285
286        long nA = myagg.count;
287        long nB = countFieldOI.get(partialCount);
288
289        if (nA == 0) {
290            // Just copy the information since there is nothing so far
291            myagg.count = countFieldOI.get(partialCount);
292            myagg.xavg = xavgFieldOI.get(partialXAvg);
293            myagg.yavg = yavgFieldOI.get(partialYAvg);
294            myagg.covar = covarFieldOI.get(partialCovar);
295        }
296
297        if (nA != 0 && nB != 0) {
298          // Merge the two partials
299          double xavgA = myagg.xavg;
300          double yavgA = myagg.yavg;
301          double xavgB = xavgFieldOI.get(partialXAvg);
302          double yavgB = yavgFieldOI.get(partialYAvg);
303          double covarB = covarFieldOI.get(partialCovar);
304
305          myagg.count += nB;
306          myagg.xavg = (xavgA * nA + xavgB * nB) / myagg.count;
307          myagg.yavg = (yavgA * nA + yavgB * nB) / myagg.count;
308          myagg.covar +=
309              covarB + (xavgA - xavgB) * (yavgA - yavgB) * ((double) (nA * nB) / myagg.count);
310        }
311      }
312    }
313
314    @Override
315    public Object terminate(AggregationBuffer agg) throws HiveException {
316      StdAgg myagg = (StdAgg) agg;
317
318      if (myagg.count == 0) { // SQL standard - return null for zero elements
319          return null;
320      } else {
321          getResult().set(myagg.covar / (myagg.count));
322          return getResult();
323      }
324    }
325
326    public void setResult(DoubleWritable result) {
327      this.result = result;
328    }
329
330    public DoubleWritable getResult() {
331      return result;
332    }
333  }
334
335}