PageRenderTime 33ms CodeModel.GetById 17ms app.highlight 12ms RepoModel.GetById 2ms app.codeStats 0ms

/tags/release-0.1-rc2/hive/external/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFVariance.java

#
Java | 289 lines | 200 code | 40 blank | 49 comment | 34 complexity | bcc6aabcf735745cc315f1b0ebb46d13 MD5 | raw file
  1/**
  2 * Licensed to the Apache Software Foundation (ASF) under one
  3 * or more contributor license agreements.  See the NOTICE file
  4 * distributed with this work for additional information
  5 * regarding copyright ownership.  The ASF licenses this file
  6 * to you under the Apache License, Version 2.0 (the
  7 * "License"); you may not use this file except in compliance
  8 * with the License.  You may obtain a copy of the License at
  9 *
 10 *     http://www.apache.org/licenses/LICENSE-2.0
 11 *
 12 * Unless required by applicable law or agreed to in writing, software
 13 * distributed under the License is distributed on an "AS IS" BASIS,
 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 15 * See the License for the specific language governing permissions and
 16 * limitations under the License.
 17 */
 18package org.apache.hadoop.hive.ql.udf.generic;
 19
 20import java.util.ArrayList;
 21
 22import org.apache.commons.logging.Log;
 23import org.apache.commons.logging.LogFactory;
 24import org.apache.hadoop.hive.ql.exec.Description;
 25import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
 26import org.apache.hadoop.hive.ql.metadata.HiveException;
 27import org.apache.hadoop.hive.ql.parse.SemanticException;
 28import org.apache.hadoop.hive.serde2.io.DoubleWritable;
 29import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 30import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
 31import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
 32import org.apache.hadoop.hive.serde2.objectinspector.StructField;
 33import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
 34import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
 35import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
 36import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
 37import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
 38import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
 39import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
 40import org.apache.hadoop.io.LongWritable;
 41import org.apache.hadoop.util.StringUtils;
 42
 43/**
 44 * Compute the variance. This class is extended by: GenericUDAFVarianceSample
 45 * GenericUDAFStd GenericUDAFStdSample
 46 * 
 47 */
 48@Description(name = "variance,var_pop",
 49    value = "_FUNC_(x) - Returns the variance of a set of numbers")
 50public class GenericUDAFVariance extends AbstractGenericUDAFResolver {
 51
 52  static final Log LOG = LogFactory.getLog(GenericUDAFVariance.class.getName());
 53
 54  @Override
 55  public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
 56    if (parameters.length != 1) {
 57      throw new UDFArgumentTypeException(parameters.length - 1,
 58          "Exactly one argument is expected.");
 59    }
 60
 61    if (parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
 62      throw new UDFArgumentTypeException(0,
 63          "Only primitive type arguments are accepted but "
 64          + parameters[0].getTypeName() + " is passed.");
 65    }
 66    switch (((PrimitiveTypeInfo) parameters[0]).getPrimitiveCategory()) {
 67    case BYTE:
 68    case SHORT:
 69    case INT:
 70    case LONG:
 71    case FLOAT:
 72    case DOUBLE:
 73    case STRING:
 74      return new GenericUDAFVarianceEvaluator();
 75    case BOOLEAN:
 76    default:
 77      throw new UDFArgumentTypeException(0,
 78          "Only numeric or string type arguments are accepted but "
 79          + parameters[0].getTypeName() + " is passed.");
 80    }
 81  }
 82
 83  /**
 84   * Evaluate the variance using the algorithm described by Chan, Golub, and LeVeque in
 85   * "Algorithms for computing the sample variance: analysis and recommendations"
 86   * The American Statistician, 37 (1983) pp. 242--247.
 87   * 
 88   * variance = variance1 + variance2 + n/(m*(m+n)) * pow(((m/n)*t1 - t2),2)
 89   * 
  90 * where: - variance is sum[(x-avg)^2] (this is actually n times the variance)
 91   * and is updated at every step. - n is the count of elements in chunk1 - m is
 92   * the count of elements in chunk2 - t1 = sum of elements in chunk1, t2 = 
 93   * sum of elements in chunk2.
 94   *
 95   * This algorithm was proven to be numerically stable by J.L. Barlow in
 96   * "Error analysis of a pairwise summation algorithm to compute sample variance"
 97   * Numer. Math, 58 (1991) pp. 583--590
 98   * 
 99   */
100  public static class GenericUDAFVarianceEvaluator extends GenericUDAFEvaluator {
101
102    // For PARTIAL1 and COMPLETE
103    private PrimitiveObjectInspector inputOI;
104
105    // For PARTIAL2 and FINAL
106    private StructObjectInspector soi;
107    private StructField countField;
108    private StructField sumField;
109    private StructField varianceField;
110    private LongObjectInspector countFieldOI;
111    private DoubleObjectInspector sumFieldOI;
112    private DoubleObjectInspector varianceFieldOI;
113
114    // For PARTIAL1 and PARTIAL2
115    private Object[] partialResult;
116
117    // For FINAL and COMPLETE
118    private DoubleWritable result;
119
120    @Override
121    public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
122      assert (parameters.length == 1);
123      super.init(m, parameters);
124
125      // init input
126      if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {
127        inputOI = (PrimitiveObjectInspector) parameters[0];
128      } else {
129        soi = (StructObjectInspector) parameters[0];
130
131        countField = soi.getStructFieldRef("count");
132        sumField = soi.getStructFieldRef("sum");
133        varianceField = soi.getStructFieldRef("variance");
134
135        countFieldOI = (LongObjectInspector) countField
136            .getFieldObjectInspector();
137        sumFieldOI = (DoubleObjectInspector) sumField.getFieldObjectInspector();
138        varianceFieldOI = (DoubleObjectInspector) varianceField
139            .getFieldObjectInspector();
140      }
141
142      // init output
143      if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {
144        // The output of a partial aggregation is a struct containing
145        // a long count and doubles sum and variance.
146
147        ArrayList<ObjectInspector> foi = new ArrayList<ObjectInspector>();
148
149        foi.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
150        foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
151        foi.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
152
153        ArrayList<String> fname = new ArrayList<String>();
154        fname.add("count");
155        fname.add("sum");
156        fname.add("variance");
157
158        partialResult = new Object[3];
159        partialResult[0] = new LongWritable(0);
160        partialResult[1] = new DoubleWritable(0);
161        partialResult[2] = new DoubleWritable(0);
162
163        return ObjectInspectorFactory.getStandardStructObjectInspector(fname,
164            foi);
165
166      } else {
167        setResult(new DoubleWritable(0));
168        return PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
169      }
170    }
171
    /**
     * Aggregation state shared by variance/std evaluators: running count,
     * running sum, and the accumulated sum of squared deviations.
     */
    static class StdAgg implements AggregationBuffer {
      long count; // number of elements
      double sum; // sum of elements
      double variance; // sum[(x-avg)^2] (this is actually n times the variance)
    };
177
178    @Override
179    public AggregationBuffer getNewAggregationBuffer() throws HiveException {
180      StdAgg result = new StdAgg();
181      reset(result);
182      return result;
183    }
184
185    @Override
186    public void reset(AggregationBuffer agg) throws HiveException {
187      StdAgg myagg = (StdAgg) agg;
188      myagg.count = 0;
189      myagg.sum = 0;
190      myagg.variance = 0;
191    }
192
193    private boolean warned = false;
194
195    @Override
196    public void iterate(AggregationBuffer agg, Object[] parameters)
197        throws HiveException {
198      assert (parameters.length == 1);
199      Object p = parameters[0];
200      if (p != null) {
201        StdAgg myagg = (StdAgg) agg;
202        try {
203          double v = PrimitiveObjectInspectorUtils.getDouble(p, inputOI);
204          myagg.count++;
205          myagg.sum += v;
206          if(myagg.count > 1) {
207            double t = myagg.count*v - myagg.sum;
208            myagg.variance += (t*t) / ((double)myagg.count*(myagg.count-1));
209          }
210        } catch (NumberFormatException e) {
211          if (!warned) {
212            warned = true;
213            LOG.warn(getClass().getSimpleName() + " "
214                + StringUtils.stringifyException(e));
215            LOG.warn(getClass().getSimpleName()
216                + " ignoring similar exceptions.");
217          }
218        }
219      }
220    }
221
222    @Override
223    public Object terminatePartial(AggregationBuffer agg) throws HiveException {
224      StdAgg myagg = (StdAgg) agg;
225      ((LongWritable) partialResult[0]).set(myagg.count);
226      ((DoubleWritable) partialResult[1]).set(myagg.sum);
227      ((DoubleWritable) partialResult[2]).set(myagg.variance);
228      return partialResult;
229    }
230
231    @Override
232    public void merge(AggregationBuffer agg, Object partial) throws HiveException {
233      if (partial != null) {
234        StdAgg myagg = (StdAgg) agg;
235
236        Object partialCount = soi.getStructFieldData(partial, countField);
237        Object partialSum = soi.getStructFieldData(partial, sumField);
238        Object partialVariance = soi.getStructFieldData(partial, varianceField);
239
240        long n = myagg.count;
241        long m = countFieldOI.get(partialCount);
242
243        if (n == 0) {
244          // Just copy the information since there is nothing so far
245          myagg.variance = sumFieldOI.get(partialVariance);
246          myagg.count = countFieldOI.get(partialCount);
247          myagg.sum = sumFieldOI.get(partialSum);
248        }
249
250        if (m != 0 && n != 0) {
251          // Merge the two partials
252
253          double a = myagg.sum;
254          double b = sumFieldOI.get(partialSum);
255
256          myagg.count += m;
257          myagg.sum += b;
258          double t = (m/(double)n)*a - b;
259          myagg.variance += sumFieldOI.get(partialVariance) + ((n/(double)m)/((double)n+m)) * t * t;
260        }
261      }
262    }
263
264    @Override
265    public Object terminate(AggregationBuffer agg) throws HiveException {
266      StdAgg myagg = (StdAgg) agg;
267
268      if (myagg.count == 0) { // SQL standard - return null for zero elements
269        return null;
270      } else {
271        if (myagg.count > 1) {
272          getResult().set(myagg.variance / (myagg.count));
273        } else { // for one element the variance is always 0
274          getResult().set(0);
275        }
276        return getResult();
277      }
278    }
279
    // Sets the reusable DoubleWritable that terminate() fills in for
    // FINAL/COMPLETE modes (installed from init()).
    public void setResult(DoubleWritable result) {
      this.result = result;
    }
283
    // Returns the reusable result holder; null until init() has run in a
    // FINAL/COMPLETE mode.
    public DoubleWritable getResult() {
      return result;
    }
287  }
288
289}