/JIT/PyTBAliasAnalysis.cc

http://unladen-swallow.googlecode.com/ · C++ · 439 lines · 312 code · 98 blank · 29 comment · 59 complexity · fca8323aae91be5bc81b6f94540aad55 MD5 · raw file

  1. #include "JIT/PyTBAliasAnalysis.h"
  2. #include "JIT/ConstantMirror.h"
  3. #include "JIT/PyTypeBuilder.h"
  4. #include "llvm/Analysis/AliasAnalysis.h"
  5. #include "llvm/ADT/SmallPtrSet.h"
  6. #include "llvm/ADT/DenseMap.h"
  7. #include "llvm/ADT/StringMap.h"
  8. #include <utility>
  9. namespace {
  10. using llvm::AliasAnalysis;
  11. using llvm::BasicBlock;
  12. using llvm::BranchInst;
  13. using llvm::CallInst;
  14. using llvm::Function;
  15. using llvm::FunctionPass;
  16. using llvm::ICmpInst;
  17. using llvm::Instruction;
  18. using llvm::LLVMContext;
  19. using llvm::MDNode;
  20. using llvm::Module;
  21. using llvm::Pass;
  22. using llvm::PassInfo;
  23. using llvm::Value;
  24. using llvm::dyn_cast;
  25. using llvm::isa;
  26. // This function unwinds casts and GEPs until it finds an instruction with
  27. // a TBAA Metadata node. Returns NULL if no metadata is found. Automatically
  28. // tags pointers to PyObject.
  29. static MDNode *getFirstMDNode(const PyGlobalLlvmData *const llvm_data,
  30. const unsigned kind, const Value *V)
  31. {
  32. unsigned MaxLookup = 10;
  33. bool is_pyobject = false;
  34. const llvm::Type *pyobject_type =
  35. PyTypeBuilder<PyObject*>::get(V->getContext());
  36. do {
  37. if (const Instruction *instr = dyn_cast<Instruction>(V))
  38. if (MDNode *n = instr->getMetadata(kind))
  39. return n;
  40. // TODO: This makes the assumption that there is only one struct
  41. // with the structure of a PyObject. If this changes, this may mark
  42. // Values wrongly.
  43. if (V->getType() == pyobject_type)
  44. is_pyobject = true;
  45. // It's save to remove any GEPs. Even if the type of the value changes,
  46. // it is still within some outer structure about which we can make
  47. // aliasing assumptions.
  48. if (const llvm::GEPOperator *GEP = dyn_cast<llvm::GEPOperator>(V)) {
  49. V = GEP->getPointerOperand();
  50. } else if (llvm::Operator::getOpcode(V) == llvm::Instruction::BitCast) {
  51. V = llvm::cast<llvm::Operator>(V)->getOperand(0);
  52. } else if (const llvm::GlobalAlias *GA =
  53. dyn_cast<llvm::GlobalAlias>(V)) {
  54. if (GA->mayBeOverridden()) {
  55. break;
  56. }
  57. V = GA->getAliasee();
  58. } else {
  59. break;
  60. }
  61. } while (--MaxLookup);
  62. if (is_pyobject) {
  63. return llvm_data->tbaa_PyObject.type();
  64. }
  65. return NULL;
  66. }
  67. class PyTBAliasAnalysis : public FunctionPass, public AliasAnalysis {
  68. public:
  69. static char ID;
  70. PyTBAliasAnalysis(PyGlobalLlvmData &global_data)
  71. : FunctionPass(&ID), context_(&global_data.context()),
  72. llvm_data_(&global_data),
  73. kind_(global_data.GetTBAAKind())
  74. {}
  75. PyTBAliasAnalysis()
  76. : FunctionPass(&ID), context_(NULL), llvm_data_(NULL),
  77. kind_(0)
  78. {}
  79. virtual void getAnalysisUsage(llvm::AnalysisUsage &usage) const {
  80. AliasAnalysis::getAnalysisUsage(usage);
  81. usage.setPreservesAll();
  82. }
  83. virtual bool runOnFunction(Function&);
  84. virtual AliasResult alias(const Value *V1, unsigned V1Size,
  85. const Value *V2, unsigned V2Size);
  86. virtual void *getAdjustedAnalysisPointer(const PassInfo *PI) {
  87. if (PI->isPassID(&AliasAnalysis::ID))
  88. return (AliasAnalysis*)this;
  89. return this;
  90. }
  91. private:
  92. bool typesMayAlias(MDNode *T1, MDNode *T2) const;
  93. const LLVMContext *const context_;
  94. const PyGlobalLlvmData *const llvm_data_;
  95. const unsigned kind_;
  96. };
  97. // The address of this variable identifies the pass. See
  98. // http://llvm.org/docs/WritingAnLLVMPass.html#basiccode.
  99. char PyTBAliasAnalysis::ID = 0;
  100. // Register this pass.
  101. static llvm::RegisterPass<PyTBAliasAnalysis>
  102. U("python-tbaa", "Python-specific Type Based Alias Analysis", false, true);
  103. // Declare that we implement the AliasAnalysis interface.
  104. static llvm::RegisterAnalysisGroup<AliasAnalysis> V(U);
  105. bool
  106. PyTBAliasAnalysis::runOnFunction(Function &f)
  107. {
  108. AliasAnalysis::InitializeAliasAnalysis(this);
  109. return false;
  110. }
  111. AliasAnalysis::AliasResult
  112. PyTBAliasAnalysis::alias(const Value *V1, unsigned V1Size,
  113. const Value *V2, unsigned V2Size)
  114. {
  115. MDNode *T1 = getFirstMDNode(this->llvm_data_, this->kind_,
  116. const_cast<Value*>(V1));
  117. MDNode *T2 = getFirstMDNode(this->llvm_data_, this->kind_,
  118. const_cast<Value*>(V2));
  119. if (T1 == NULL || T2 == NULL) {
  120. return AliasAnalysis::alias(V1, V1Size, V2, V2Size);
  121. }
  122. if (!this->typesMayAlias(T1, T2)) {
  123. return NoAlias;
  124. }
  125. return AliasAnalysis::alias(V1, V1Size, V2, V2Size);
  126. }
  127. bool
  128. PyTBAliasAnalysis::typesMayAlias(MDNode *T1, MDNode *T2) const
  129. {
  130. if (T1 == T2)
  131. return true;
  132. if (llvm_data_->IsTBAASubtype(T1, T2))
  133. return true;
  134. if (llvm_data_->IsTBAASubtype(T2, T1))
  135. return true;
  136. return false;
  137. }
  138. } // End of anonymous namespace
  139. Pass *
  140. CreatePyTBAliasAnalysis(PyGlobalLlvmData &global_data)
  141. {
  142. return new PyTBAliasAnalysis(global_data);
  143. }
  144. namespace {
  145. class PyTypeMarkingPass : public FunctionPass {
  146. public:
  147. static char ID;
  148. PyTypeMarkingPass(PyGlobalLlvmData &global_data);
  149. PyTypeMarkingPass()
  150. : FunctionPass(&ID), context_(NULL), llvm_data_(NULL), kind_(0)
  151. {}
  152. virtual void getAnalysisUsage(llvm::AnalysisUsage &usage) const {
  153. usage.setPreservesAll();
  154. }
  155. virtual bool runOnFunction(Function&);
  156. private:
  157. void addMark(const char *name, const PyTBAAType &type);
  158. bool markFunction(CallInst *callInst);
  159. const LLVMContext *const context_;
  160. const PyGlobalLlvmData *const llvm_data_;
  161. const unsigned kind_;
  162. llvm::ValueMap<Function *, const PyTBAAType *> func_map_;
  163. };
  164. // The address of this variable identifies the pass. See
  165. // http://llvm.org/docs/WritingAnLLVMPass.html#basiccode.
  166. char PyTypeMarkingPass::ID = 0;
  167. // Register this pass.
  168. static llvm::RegisterPass<PyTypeMarkingPass>
  169. W("python-typemarking", "Python-specific Type Marking pass", false, true);
  170. PyTypeMarkingPass::PyTypeMarkingPass(PyGlobalLlvmData &global_data)
  171. : FunctionPass(&ID), context_(&global_data.context()),
  172. llvm_data_(&global_data),
  173. kind_(global_data.GetTBAAKind())
  174. {
  175. }
  176. bool
  177. PyTypeMarkingPass::runOnFunction(Function &F)
  178. {
  179. if (this->func_map_.empty()) {
  180. // This functions should get loaded with the BC file.
  181. addMark("PyInt_FromLong", llvm_data_->tbaa_PyIntObject);
  182. addMark("PyInt_FromSsize_t", llvm_data_->tbaa_PyIntObject);
  183. // PyBoolObject is a subtype of PyIntObject
  184. addMark("PyBool_FromLong", llvm_data_->tbaa_PyIntObject);
  185. addMark("PyFloat_FromDouble", llvm_data_->tbaa_PyFloatObject);
  186. addMark("PyString_Format", llvm_data_->tbaa_PyStringObject);
  187. // getFunction needs the real function name. Expand macros.
  188. #ifndef Py_UNICODE_WIDE
  189. addMark("PyUnicodeUCS2_Format", llvm_data_->tbaa_PyUnicodeObject);
  190. #else
  191. addMark("PyUnicodeUCS4_Format", llvm_data_->tbaa_PyUnicodeObject);
  192. #endif
  193. }
  194. bool changed = false;
  195. for (Function::iterator b = F.begin(), be = F.end(); b != be; ++b) {
  196. for (BasicBlock::iterator i = b->begin(), ie = b->end(); i != ie; ++i) {
  197. if (CallInst* callInst = dyn_cast<CallInst>(&*i)) {
  198. changed |= this->markFunction(callInst);
  199. }
  200. }
  201. }
  202. return changed;
  203. }
  204. void
  205. PyTypeMarkingPass::addMark(const char *name, const PyTBAAType &type)
  206. {
  207. const Module *module = this->llvm_data_->module();
  208. Function *func = module->getFunction(name);
  209. // There should already be GVs for functions in the runtime library.
  210. assert(func != NULL);
  211. this->func_map_[func] = &type;
  212. }
  213. bool
  214. PyTypeMarkingPass::markFunction(CallInst *callInst)
  215. {
  216. Function *called = callInst->getCalledFunction();
  217. if (called == NULL)
  218. return false;
  219. if (callInst->getMetadata(this->kind_) != NULL)
  220. return false;
  221. const PyTBAAType *type = this->func_map_.lookup(called);
  222. if (type == NULL)
  223. return false;
  224. type->MarkInstruction(callInst);
  225. return true;
  226. }
  227. } // End of anonymous namespace
  228. Pass *
  229. CreatePyTypeMarkingPass(PyGlobalLlvmData &global_data)
  230. {
  231. return new PyTypeMarkingPass(global_data);
  232. }
  233. namespace {
  234. class PyTypeGuardRemovalPass : public FunctionPass {
  235. public:
  236. static char ID;
  237. PyTypeGuardRemovalPass(PyGlobalLlvmData &global_data);
  238. PyTypeGuardRemovalPass()
  239. : FunctionPass(&ID), context_(NULL), llvm_data_(NULL), kind_(0)
  240. {}
  241. virtual bool runOnFunction(Function&);
  242. private:
  243. void addGuardType(PyObject *obj, const PyTBAAType &type);
  244. bool checkICmp(ICmpInst *icmpIns);
  245. LLVMContext *const context_;
  246. const PyGlobalLlvmData *const llvm_data_;
  247. const unsigned kind_;
  248. typedef llvm::ValueMap<const Value *,
  249. llvm::TrackingVH<MDNode> > GuardTypes;
  250. // Type checks in LlvmIR compare the type field of a PyObject with a
  251. // instance of PyTypeObject. This maps the Python type to a TBAA MDNode.
  252. GuardTypes type_map_;
  253. };
  254. // The address of this variable identifies the pass. See
  255. // http://llvm.org/docs/WritingAnLLVMPass.html#basiccode.
  256. char PyTypeGuardRemovalPass::ID = 0;
  257. // Register this pass.
  258. static llvm::RegisterPass<PyTypeGuardRemovalPass>
  259. X("python-typeguard", "Python-specific Type Guard Removal Pass", false, true);
  260. PyTypeGuardRemovalPass::PyTypeGuardRemovalPass(PyGlobalLlvmData &global_data)
  261. : FunctionPass(&ID), context_(&global_data.context()),
  262. llvm_data_(&global_data),
  263. kind_(global_data.GetTBAAKind())
  264. {
  265. }
  266. void
  267. PyTypeGuardRemovalPass::addGuardType(PyObject *obj, const PyTBAAType &type)
  268. {
  269. const llvm::Value *value =
  270. this->llvm_data_->constant_mirror().GetGlobalVariableFor(obj);
  271. type_map_[value] = type.type();
  272. }
  273. bool
  274. PyTypeGuardRemovalPass::runOnFunction(Function &F)
  275. {
  276. if (type_map_.empty()) {
  277. // Lazy initialisation. GetGlobalVariable does not work during init.
  278. // Connects type objects with Metadata. Do this for every
  279. // *_CheckExact you want to remove.
  280. this->addGuardType((PyObject *)&PyInt_Type,
  281. llvm_data_->tbaa_PyIntObject);
  282. this->addGuardType((PyObject *)&PyFloat_Type,
  283. llvm_data_->tbaa_PyFloatObject);
  284. }
  285. bool changed = false;
  286. for (Function::iterator b = F.begin(), be = F.end(); b != be; ++b) {
  287. for (BasicBlock::iterator i = b->begin(), ie = b->end(); i != ie; ++i) {
  288. if (ICmpInst *icmpInst = dyn_cast<ICmpInst>(&*i)) {
  289. changed |= this->checkICmp(icmpInst);
  290. }
  291. }
  292. }
  293. return changed;
  294. }
  295. bool
  296. PyTypeGuardRemovalPass::checkICmp(ICmpInst *icmpInst)
  297. {
  298. const llvm::Type *type_object =
  299. PyTypeBuilder<PyTypeObject*>::get(*this->context_);
  300. if (icmpInst->getPredicate() != ICmpInst::ICMP_EQ) {
  301. return false;
  302. }
  303. Value *op1 = icmpInst->getOperand(0);
  304. if (op1->getType() != type_object) {
  305. return false;
  306. }
  307. Value *load = op1->getUnderlyingObject();
  308. llvm::LoadInst *loadInst = dyn_cast<llvm::LoadInst>(load);
  309. if (loadInst == NULL) {
  310. return false;
  311. }
  312. Value *src = loadInst->getOperand(0);
  313. MDNode *type_hint = getFirstMDNode(this->llvm_data_, this->kind_, src);
  314. if (type_hint == NULL)
  315. return false;
  316. GuardTypes::iterator it = type_map_.find(icmpInst->getOperand(1));
  317. if (it == type_map_.end())
  318. return false;
  319. MDNode *req_type = it->second;
  320. /* This only removes type checks which would result in true */
  321. if (type_hint != req_type)
  322. return false;
  323. bool changed = false;
  324. for (Value::use_iterator i = icmpInst->use_begin(), e = icmpInst->use_end();
  325. i != e; ++i) {
  326. if (BranchInst *branch = dyn_cast<BranchInst>(*i)) {
  327. if (branch->isConditional()) {
  328. changed = true;
  329. BasicBlock *true_block = branch->getSuccessor(0);
  330. branch->setUnconditionalDest(true_block);
  331. }
  332. }
  333. }
  334. return changed;
  335. }
  336. } // End of anonymous namespace
  337. Pass *
  338. CreatePyTypeGuardRemovalPass(PyGlobalLlvmData &global_data)
  339. {
  340. return new PyTypeGuardRemovalPass(global_data);
  341. }