
/tags/release-0.0.0-rc0/hive/external/ql/src/java/org/apache/hadoop/hive/ql/optimizer/BucketMapJoinOptimizer.java

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.hadoop.hive.ql.optimizer;

import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Stack;
import java.util.Map.Entry;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.fs.FileStatus;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hive.ql.exec.FunctionRegistry;
import org.apache.hadoop.hive.ql.exec.MapJoinOperator;
import org.apache.hadoop.hive.ql.exec.Operator;
import org.apache.hadoop.hive.ql.exec.TableScanOperator;
import org.apache.hadoop.hive.ql.lib.DefaultGraphWalker;
import org.apache.hadoop.hive.ql.lib.DefaultRuleDispatcher;
import org.apache.hadoop.hive.ql.lib.Dispatcher;
import org.apache.hadoop.hive.ql.lib.GraphWalker;
import org.apache.hadoop.hive.ql.lib.Node;
import org.apache.hadoop.hive.ql.lib.NodeProcessor;
import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
import org.apache.hadoop.hive.ql.lib.Rule;
import org.apache.hadoop.hive.ql.lib.RuleRegExp;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.metadata.Partition;
import org.apache.hadoop.hive.ql.metadata.Table;
import org.apache.hadoop.hive.ql.optimizer.ppr.PartitionPruner;
import org.apache.hadoop.hive.ql.parse.ParseContext;
import org.apache.hadoop.hive.ql.parse.PrunedPartitionList;
import org.apache.hadoop.hive.ql.parse.QBJoinTree;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.plan.ExprNodeColumnDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeDesc;
import org.apache.hadoop.hive.ql.plan.ExprNodeGenericFuncDesc;
import org.apache.hadoop.hive.ql.plan.MapJoinDesc;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;

/**
 * This transformation implements the bucket map join optimization: when the
 * tables of a map join are bucketed on the join columns with compatible bucket
 * counts, it records the bucket file name mapping needed to execute the join
 * as a bucket map join.
 */
public class BucketMapJoinOptimizer implements Transform {

  private static final Log LOG = LogFactory.getLog(BucketMapJoinOptimizer.class
      .getName());

  public BucketMapJoinOptimizer() {
  }

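  /**
   * Registers the bucket map join rules, walks the operator tree starting from
   * the top operators, and returns the (possibly annotated) parse context.
   */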
  @Override
  public ParseContext transform(ParseContext pctx) throws SemanticException {

    Map<Rule, NodeProcessor> opRules = new LinkedHashMap<Rule, NodeProcessor>();
    BucketMapjoinOptProcCtx bucketMapJoinOptimizeCtx = new BucketMapjoinOptProcCtx();

    // R1 picks up candidate map joins; R2-R4 reject map joins that have a
    // reduce sink, a union, or another map join between the table scan and
    // the join
    opRules.put(new RuleRegExp("R1", "MAPJOIN%"), getBucketMapjoinProc(pctx));
    opRules.put(new RuleRegExp("R2", "RS%.*MAPJOIN"), getBucketMapjoinRejectProc(pctx));
    opRules.put(new RuleRegExp("R3", "UNION%.*MAPJOIN%"),
        getBucketMapjoinRejectProc(pctx));
    opRules.put(new RuleRegExp("R4", "MAPJOIN%.*MAPJOIN%"),
        getBucketMapjoinRejectProc(pctx));

    // The dispatcher fires the processor corresponding to the closest matching
    // rule and passes the context along
    Dispatcher disp = new DefaultRuleDispatcher(getDefaultProc(), opRules,
        bucketMapJoinOptimizeCtx);
    GraphWalker ogw = new DefaultGraphWalker(disp);

    // Create a list of top operator nodes
    ArrayList<Node> topNodes = new ArrayList<Node>();
    topNodes.addAll(pctx.getTopOps().values());
    ogw.startWalking(topNodes, null);

    return pctx;
  }

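  /**
   * Returns a processor that simply marks the matched map join as rejected in
   * the shared context, so that the bucket map join processor skips it.
   */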
  private NodeProcessor getBucketMapjoinRejectProc(ParseContext pctx) {
    return new NodeProcessor() {
      @Override
      public Object process(Node nd, Stack<Node> stack,
          NodeProcessorCtx procCtx, Object... nodeOutputs)
          throws SemanticException {
        MapJoinOperator mapJoinOp = (MapJoinOperator) nd;
        BucketMapjoinOptProcCtx context = (BucketMapjoinOptProcCtx) procCtx;
        context.listOfRejectedMapjoins.add(mapJoinOp);
        return null;
      }
    };
  }

  private NodeProcessor getBucketMapjoinProc(ParseContext pctx) {
    return new BucketMapjoinOptProc(pctx);
  }

  private NodeProcessor getDefaultProc() {
    return new NodeProcessor() {
      @Override
      public Object process(Node nd, Stack<Node> stack,
          NodeProcessorCtx procCtx, Object... nodeOutputs)
          throws SemanticException {
        return null;
      }
    };
  }

  class BucketMapjoinOptProc implements NodeProcessor {

    protected ParseContext pGraphContext;

    public BucketMapjoinOptProc(ParseContext pGraphContext) {
      super();
      this.pGraphContext = pGraphContext;
    }

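    /**
     * Decides whether this map join can be executed as a bucket map join and,
     * if so, fills the alias-to-bucket-file-name mapping in its MapJoinDesc.
     * The conversion is abandoned when the join keys are not the tables'
     * bucket columns, when a small table has more than one partition, or when
     * the bucket counts of the big table and the small tables do not divide
     * one another.
     */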
    @Override
    public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx,
        Object... nodeOutputs) throws SemanticException {
      MapJoinOperator mapJoinOp = (MapJoinOperator) nd;
      BucketMapjoinOptProcCtx context = (BucketMapjoinOptProcCtx) procCtx;

      if(context.getListOfRejectedMapjoins().contains(mapJoinOp)) {
        return null;
      }

      QBJoinTree joinCxt = this.pGraphContext.getMapJoinContext().get(mapJoinOp);
      if(joinCxt == null) {
        return null;
      }

      List<String> joinAliases = new ArrayList<String>();
      String[] srcs = joinCxt.getBaseSrc();
      String[] left = joinCxt.getLeftAliases();
      List<String> mapAlias = joinCxt.getMapAliases();
      String baseBigAlias = null;
      for(String s : left) {
        if(s != null && !joinAliases.contains(s)) {
          joinAliases.add(s);
          if(!mapAlias.contains(s)) {
            baseBigAlias = s;
          }
        }
      }
      for(String s : srcs) {
        if(s != null && !joinAliases.contains(s)) {
          joinAliases.add(s);
          if(!mapAlias.contains(s)) {
            baseBigAlias = s;
          }
        }
      }

      MapJoinDesc mjDecs = mapJoinOp.getConf();
      LinkedHashMap<String, Integer> aliasToBucketNumberMapping = new LinkedHashMap<String, Integer>();
      LinkedHashMap<String, List<String>> aliasToBucketFileNamesMapping = new LinkedHashMap<String, List<String>>();
      // right now this code does not work with "a join b on a.key = b.key and
      // a.ds = b.ds", where ds is a partition column. It only works with joins
      // where only one partition is present in each join source table.
      Map<String, Operator<? extends Serializable>> topOps = this.pGraphContext.getTopOps();
      Map<TableScanOperator, Table> topToTable = this.pGraphContext.getTopToTable();

      // (partition to bucket file names) and (partition to bucket number) for
      // the big table.
      LinkedHashMap<Partition, List<String>> bigTblPartsToBucketFileNames = new LinkedHashMap<Partition, List<String>>();
      LinkedHashMap<Partition, Integer> bigTblPartsToBucketNumber = new LinkedHashMap<Partition, Integer>();

      for (int index = 0; index < joinAliases.size(); index++) {
        String alias = joinAliases.get(index);
        TableScanOperator tso = (TableScanOperator) topOps.get(alias);
        if (tso == null) {
          return null;
        }
        Table tbl = topToTable.get(tso);
        if(tbl.isPartitioned()) {
          PrunedPartitionList prunedParts = null;
          try {
            prunedParts = pGraphContext.getOpToPartList().get(tso);
            if (prunedParts == null) {
              prunedParts = PartitionPruner.prune(tbl, pGraphContext.getOpToPartPruner().get(tso), pGraphContext.getConf(), alias,
                pGraphContext.getPrunedPartitions());
              pGraphContext.getOpToPartList().put(tso, prunedParts);
            }
          } catch (HiveException e) {
            // Has to use full name to make sure it does not conflict with
            // org.apache.commons.lang.StringUtils
            LOG.error(org.apache.hadoop.util.StringUtils.stringifyException(e));
            throw new SemanticException(e.getMessage(), e);
          }
          int partNumber = prunedParts.getConfirmedPartns().size()
              + prunedParts.getUnknownPartns().size();

          if (partNumber > 1) {
            // only allow one partition for small tables
            if (!alias.equals(baseBigAlias)) {
              return null;
            }
            // this is the big table, and it has more than one partition.
            // construct a mapping of (Partition -> bucket file names) and
            // (Partition -> bucket number)
            Iterator<Partition> iter = prunedParts.getConfirmedPartns()
                .iterator();
            while (iter.hasNext()) {
              Partition p = iter.next();
              if (!checkBucketColumns(p.getBucketCols(), mjDecs, index)) {
                return null;
              }
              List<String> fileNames = getOnePartitionBucketFileNames(p);
              bigTblPartsToBucketFileNames.put(p, fileNames);
              bigTblPartsToBucketNumber.put(p, p.getBucketCount());
            }
            iter = prunedParts.getUnknownPartns().iterator();
            while (iter.hasNext()) {
              Partition p = iter.next();
              if (!checkBucketColumns(p.getBucketCols(), mjDecs, index)) {
                return null;
              }
              List<String> fileNames = getOnePartitionBucketFileNames(p);
              bigTblPartsToBucketFileNames.put(p, fileNames);
              bigTblPartsToBucketNumber.put(p, p.getBucketCount());
            }
            // If there is more than one partition for the big table,
            // aliasToBucketFileNamesMapping and aliasToBucketNumberMapping will
            // not contain mappings for the big table. Instead, the mappings are
            // kept in bigTblPartsToBucketFileNames and
            // bigTblPartsToBucketNumber.

          } else {
            Partition part = null;
            Iterator<Partition> iter = prunedParts.getConfirmedPartns()
                .iterator();
            if (iter.hasNext()) {
              part = iter.next();
            }
            if (part == null) {
              iter = prunedParts.getUnknownPartns().iterator();
              if (iter.hasNext()) {
                part = iter.next();
              }
            }
            assert part != null;
            Integer num = Integer.valueOf(part.getBucketCount());
            aliasToBucketNumberMapping.put(alias, num);
            if (!checkBucketColumns(part.getBucketCols(), mjDecs, index)) {
              return null;
            }
            List<String> fileNames = getOnePartitionBucketFileNames(part);
            aliasToBucketFileNamesMapping.put(alias, fileNames);
            if (alias.equals(baseBigAlias)) {
              bigTblPartsToBucketFileNames.put(part, fileNames);
              bigTblPartsToBucketNumber.put(part, num);
            }
          }
        } else {
          if (!checkBucketColumns(tbl.getBucketCols(), mjDecs, index)) {
            return null;
          }
          Integer num = Integer.valueOf(tbl.getNumBuckets());
          aliasToBucketNumberMapping.put(alias, num);
          List<String> fileNames = new ArrayList<String>();
          try {
            FileSystem fs = FileSystem.get(tbl.getDataLocation(), this.pGraphContext.getConf());
            FileStatus[] files = fs.listStatus(new Path(tbl.getDataLocation().toString()));
            if(files != null) {
              for(FileStatus file : files) {
                fileNames.add(file.getPath().toString());
              }
            }
          } catch (IOException e) {
            throw new SemanticException(e);
          }
          aliasToBucketFileNamesMapping.put(alias, fileNames);
        }
      }

      // All tables or partitions are bucketed, and their bucket counts are
      // stored in 'aliasToBucketNumberMapping'. We need to check that the
      // number of buckets in the big table and in each small table divide one
      // another.
      if (bigTblPartsToBucketNumber.size() > 0) {
        Iterator<Entry<Partition, Integer>> bigTblPartToBucketNumber = bigTblPartsToBucketNumber
            .entrySet().iterator();
        while (bigTblPartToBucketNumber.hasNext()) {
          int bucketNumberInPart = bigTblPartToBucketNumber.next().getValue();
          if (!checkBucketNumberAgainstBigTable(aliasToBucketNumberMapping,
              bucketNumberInPart)) {
            return null;
          }
        }
      } else {
        int bucketNoInBigTbl = aliasToBucketNumberMapping.get(baseBigAlias).intValue();
        if (!checkBucketNumberAgainstBigTable(aliasToBucketNumberMapping,
            bucketNoInBigTbl)) {
          return null;
        }
      }

      MapJoinDesc desc = mapJoinOp.getConf();

      LinkedHashMap<String, LinkedHashMap<String, ArrayList<String>>> aliasBucketFileNameMapping =
        new LinkedHashMap<String, LinkedHashMap<String, ArrayList<String>>>();

      // sort bucket file names for the big table
      if(bigTblPartsToBucketNumber.size() > 0) {
        Collection<List<String>> bucketNamesAllParts = bigTblPartsToBucketFileNames.values();
        for(List<String> partBucketNames : bucketNamesAllParts) {
          Collections.sort(partBucketNames);
        }
      } else {
        Collections.sort(aliasToBucketFileNamesMapping.get(baseBigAlias));
      }

      // go through all small tables and get the mapping from each bucket file
      // name in the big table to the bucket file names in the small tables.
      for (int j = 0; j < joinAliases.size(); j++) {
        String alias = joinAliases.get(j);
        if(alias.equals(baseBigAlias)) {
          continue;
        }
        Collections.sort(aliasToBucketFileNamesMapping.get(alias));
        LinkedHashMap<String, ArrayList<String>> mapping = new LinkedHashMap<String, ArrayList<String>>();
        aliasBucketFileNameMapping.put(alias, mapping);

        // for each bucket file in the big table, get the corresponding bucket
        // file name in the small table.
        if (bigTblPartsToBucketNumber.size() > 0) {
          // more than one partition in the big table; do the mapping for each partition
          Iterator<Entry<Partition, List<String>>> bigTblPartToBucketNames = bigTblPartsToBucketFileNames
              .entrySet().iterator();
          Iterator<Entry<Partition, Integer>> bigTblPartToBucketNum = bigTblPartsToBucketNumber
              .entrySet().iterator();
          while (bigTblPartToBucketNames.hasNext()) {
            assert bigTblPartToBucketNum.hasNext();
            int bigTblBucketNum = bigTblPartToBucketNum.next().getValue().intValue();
            List<String> bigTblBucketNameList = bigTblPartToBucketNames.next().getValue();
            fillMapping(baseBigAlias, aliasToBucketNumberMapping,
                aliasToBucketFileNamesMapping, alias, mapping, bigTblBucketNum,
                bigTblBucketNameList, desc.getBucketFileNameMapping());
          }
        } else {
          List<String> bigTblBucketNameList = aliasToBucketFileNamesMapping.get(baseBigAlias);
          int bigTblBucketNum = aliasToBucketNumberMapping.get(baseBigAlias);
          fillMapping(baseBigAlias, aliasToBucketNumberMapping,
              aliasToBucketFileNamesMapping, alias, mapping, bigTblBucketNum,
              bigTblBucketNameList, desc.getBucketFileNameMapping());
        }
      }
      desc.setAliasBucketFileNameMapping(aliasBucketFileNameMapping);
      desc.setBigTableAlias(baseBigAlias);
      return null;
    }

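    /**
     * For each bucket file of the big table, computes the list of bucket files
     * of the small table 'alias' that must be loaded together with it, and
     * records the result in 'mapping'. When the small table has more buckets
     * than the big table, several small-table buckets map to one big-table
     * bucket: for example, with 2 big-table buckets and 4 small-table buckets,
     * big bucket 0 gets small buckets {0, 2} and big bucket 1 gets {1, 3}.
     */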
    private void fillMapping(String baseBigAlias,
        LinkedHashMap<String, Integer> aliasToBucketNumberMapping,
        LinkedHashMap<String, List<String>> aliasToBucketFileNamesMapping,
        String alias, LinkedHashMap<String, ArrayList<String>> mapping,
        int bigTblBucketNum, List<String> bigTblBucketNameList,
        LinkedHashMap<String, Integer> bucketFileNameMapping) {
      for (int index = 0; index < bigTblBucketNameList.size(); index++) {
        String inputBigTBLBucket = bigTblBucketNameList.get(index);
        int smallTblBucketNum = aliasToBucketNumberMapping.get(alias);
        ArrayList<String> resultFileNames = new ArrayList<String>();
        if (bigTblBucketNum >= smallTblBucketNum) {
          // if the big table has more buckets than the current small table,
          // use "MOD" to get small table bucket names. For example, if the big
          // table has 4 buckets and the small table has 2 buckets, then the
          // mapping should be 0->0, 1->1, 2->0, 3->1.
          int toAddSmallIndex = index % smallTblBucketNum;
          if(toAddSmallIndex < aliasToBucketFileNamesMapping.get(alias).size()) {
            resultFileNames.add(aliasToBucketFileNamesMapping.get(alias).get(toAddSmallIndex));
          }
        } else {
          int jump = smallTblBucketNum / bigTblBucketNum;
          for (int i = index; i < aliasToBucketFileNamesMapping.get(alias).size(); i = i + jump) {
            if(i < aliasToBucketFileNamesMapping.get(alias).size()) {
              resultFileNames.add(aliasToBucketFileNamesMapping.get(alias).get(i));
            }
          }
        }
        mapping.put(inputBigTBLBucket, resultFileNames);
        bucketFileNameMapping.put(inputBigTBLBucket, index);
      }
    }

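    /**
     * Returns true only if, for every join table, the table's bucket count and
     * the given big-table bucket count divide one another (e.g. 2 and 4, or 8
     * and 4, but not 3 and 4).
     */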
    private boolean checkBucketNumberAgainstBigTable(
        LinkedHashMap<String, Integer> aliasToBucketNumber,
        int bucketNumberInPart) {
      Iterator<Integer> iter = aliasToBucketNumber.values().iterator();
      while(iter.hasNext()) {
        int nxt = iter.next().intValue();
        boolean ok = (nxt >= bucketNumberInPart) ? nxt % bucketNumberInPart == 0
            : bucketNumberInPart % nxt == 0;
        if(!ok) {
          return false;
        }
      }
      return true;
    }

    private List<String> getOnePartitionBucketFileNames(Partition part)
        throws SemanticException {
      List<String> fileNames = new ArrayList<String>();
      try {
        FileSystem fs = FileSystem.get(part.getDataLocation(), this.pGraphContext.getConf());
        FileStatus[] files = fs.listStatus(new Path(part.getDataLocation()
            .toString()));
        if (files != null) {
          for (FileStatus file : files) {
            fileNames.add(file.getPath().toString());
          }
        }
      } catch (IOException e) {
        throw new SemanticException(e);
      }
      return fileNames;
    }

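    /**
     * Returns true only if the join keys of the table at position 'index' are
     * exactly the table's bucket columns. Join keys wrapped in deterministic
     * UDFs are unwrapped and checked column by column; any non-deterministic
     * or non-column expression disqualifies the join.
     */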
    private boolean checkBucketColumns(List<String> bucketColumns, MapJoinDesc mjDesc, int index) {
      List<ExprNodeDesc> keys = mjDesc.getKeys().get((byte)index);
      if (keys == null || bucketColumns == null || bucketColumns.size() == 0) {
        return false;
      }

      // get all join columns from the join keys stored in MapJoinDesc
      List<String> joinCols = new ArrayList<String>();
      List<ExprNodeDesc> joinKeys = new ArrayList<ExprNodeDesc>();
      joinKeys.addAll(keys);
      while (joinKeys.size() > 0) {
        ExprNodeDesc node = joinKeys.remove(0);
        if (node instanceof ExprNodeColumnDesc) {
          joinCols.addAll(node.getCols());
        } else if (node instanceof ExprNodeGenericFuncDesc) {
          ExprNodeGenericFuncDesc udfNode = ((ExprNodeGenericFuncDesc) node);
          GenericUDF udf = udfNode.getGenericUDF();
          if (!FunctionRegistry.isDeterministic(udf)) {
            return false;
          }
          joinKeys.addAll(0, udfNode.getChildExprs());
        } else {
          return false;
        }
      }

      // check whether the join columns of the table are exactly the same as
      // its bucket columns
      if (joinCols.size() == 0 || joinCols.size() != bucketColumns.size()) {
        return false;
      }

      for (String col : joinCols) {
        if (!bucketColumns.contains(col)) {
          return false;
        }
      }

      return true;
    }

  }

  class BucketMapjoinOptProcCtx implements NodeProcessorCtx {
    // we only convert map joins that follow a root table scan in the same
    // mapper, i.e. there is no reducer between the root table scan and the
    // map join.
    Set<MapJoinOperator> listOfRejectedMapjoins = new HashSet<MapJoinOperator>();

    public Set<MapJoinOperator> getListOfRejectedMapjoins() {
      return listOfRejectedMapjoins;
    }

  }
}