/src/LinFu.AOP/MethodCallInterception/InterceptMethodCalls.cs

http://github.com/philiplaureano/LinFu · C# · 408 lines · 277 code · 95 blank · 36 comment · 12 complexity · bc8d953038238662efe3cca100d0b4ab MD5 · raw file

  1. using System;
  2. using System.Collections.Generic;
  3. using System.Diagnostics;
  4. using System.Reflection;
  5. using LinFu.AOP.Cecil.Interfaces;
  6. using LinFu.AOP.Interfaces;
  7. using LinFu.Reflection.Emit;
  8. using Mono.Cecil;
  9. using Mono.Cecil.Cil;
  10. namespace LinFu.AOP.Cecil
  11. {
  /// <summary>
  /// An <see cref="IMethodWeaver"/> that rewrites <c>call</c> and <c>callvirt</c> instructions inside a
  /// host method. Each selected call site is replaced with IL that moves the call's target and arguments
  /// into an <c>InvocationInfo</c>, looks for a replacement <see cref="IInterceptor"/> (first from the host
  /// instance's <see cref="IMethodReplacementHost"/> provider, then from
  /// <c>MethodCallReplacementProviderRegistry</c>), and either invokes that interceptor or pushes the
  /// arguments back and executes the original call instruction.
  /// </summary>
  internal class InterceptMethodCalls : InstructionSwapper, IMethodWeaver
  {
      private readonly IMethodCallFilter _callFilter;
      private VariableDefinition _aroundInvokeProvider;
      private MethodReference _canReplace;
      private VariableDefinition _canReplaceFlag;
      private VariableDefinition _currentArgument;
      private VariableDefinition _currentArguments;
      private MethodReference _getProvider;
      private MethodReference _getReplacement;
      private MethodReference _getStaticProvider;
      private TypeReference _hostInterfaceType;
      private VariableDefinition _instanceProvider;
      private MethodReference _intercept;
      private VariableDefinition _interceptionDisabled;
      private VariableDefinition _invocationInfo;
      private MethodReference _invocationInfoCtor;
      private VariableDefinition _methodReplacementProvider;
      private VariableDefinition _parameterTypes;
      private MethodReference _popMethod;
      private MethodReference _pushMethod;
      private VariableDefinition _replacement;
      private VariableDefinition _returnValue;
      private MethodReference _stackCtor;
      private VariableDefinition _staticProvider;
      private VariableDefinition _target;
      private MethodReference _toArray;
      private VariableDefinition _typeArguments;

      /// <summary>
      /// Initializes a weaver whose call-site selection is driven by two predicates, combined through
      /// a <c>MethodCallFilterAdapter</c>.
      /// </summary>
      /// <param name="hostMethodFilter">Predicate intended for the method whose body contains the call.</param>
      /// <param name="methodCallFilter">Predicate intended for the method being called.</param>
      public InterceptMethodCalls(Func<MethodReference, bool> hostMethodFilter,
                                  Func<MethodReference, bool> methodCallFilter)
      {
          _callFilter = new MethodCallFilterAdapter(hostMethodFilter, methodCallFilter);
      }

      /// <summary>
      /// Initializes a weaver that uses <paramref name="callFilter"/> to decide which call sites to rewrite.
      /// </summary>
      /// <param name="callFilter">The filter consulted by <see cref="ShouldReplace"/>. Must not be null.</param>
      /// <exception cref="ArgumentNullException"><paramref name="callFilter"/> is null.</exception>
      public InterceptMethodCalls(IMethodCallFilter callFilter)
      {
          // Fail fast here rather than with a NullReferenceException later in ShouldReplace.
          if (callFilter == null)
          {
              throw new ArgumentNullException("callFilter");
          }

          _callFilter = callFilter;
      }

      /// <summary>
      /// Imports into <paramref name="module"/> every constructor, method and type the emitted IL refers to.
      /// </summary>
      /// <param name="module">The module being rewritten.</param>
      public override void ImportReferences(ModuleDefinition module)
      {
          // InvocationInfo constructor signature, in the order SaveInvocationInfo pushes the values:
          // target, method, stack trace, parameter types, type arguments, return type, arguments.
          var types = new[]
                          {
                              typeof(object),
                              typeof(MethodBase),
                              typeof(StackTrace),
                              typeof(Type[]),
                              typeof(Type[]),
                              typeof(Type),
                              typeof(object[])
                          };

          _invocationInfoCtor = module.ImportConstructor<InvocationInfo>(types);
          _stackCtor = module.ImportConstructor<Stack<object>>();
          _pushMethod = module.ImportMethod<Stack<object>>("Push");
          _popMethod = module.ImportMethod<Stack<object>>("Pop");
          _toArray = module.ImportMethod<Stack<object>>("ToArray");
          _getProvider = module.ImportMethod<IMethodReplacementHost>("get_MethodCallReplacementProvider");
          _getStaticProvider = module.ImportMethod("GetProvider", typeof(MethodCallReplacementProviderRegistry));
          _canReplace = module.ImportMethod<IMethodReplacementProvider>("CanReplace");
          _getReplacement = module.ImportMethod<IMethodReplacementProvider>("GetMethodReplacement");
          _hostInterfaceType = module.ImportType<IMethodReplacementHost>();
          _intercept = module.ImportMethod<IInterceptor>("Intercept");
      }

      /// <summary>
      /// Declares the locals shared by every rewritten call site in <paramref name="hostMethod"/> and
      /// enables InitLocals so they start zeroed (null / false).
      /// </summary>
      /// <param name="hostMethod">The method whose body is being rewritten.</param>
      public override void AddLocals(MethodDefinition hostMethod)
      {
          var body = hostMethod.Body;
          body.InitLocals = true;

          _currentArguments = hostMethod.AddLocal<Stack<object>>("__arguments");
          _currentArgument = hostMethod.AddLocal<object>("__currentArgument");
          _parameterTypes = hostMethod.AddLocal<Type[]>("__parameterTypes");
          _typeArguments = hostMethod.AddLocal<Type[]>("__typeArguments");
          _invocationInfo = hostMethod.AddLocal<IInvocationInfo>("___invocationInfo");
          _target = hostMethod.AddLocal<object>("__target");
          _replacement = hostMethod.AddLocal<IInterceptor>("__interceptor");
          _canReplaceFlag = hostMethod.AddLocal<bool>("__canReplace");
          _staticProvider = hostMethod.AddLocal<IMethodReplacementProvider>("__staticProvider");
          _instanceProvider = hostMethod.AddLocal<IMethodReplacementProvider>("__instanceProvider");
          _interceptionDisabled = hostMethod.AddLocal<bool>();
          _methodReplacementProvider = hostMethod.AddLocal<IMethodReplacementProvider>();
          _aroundInvokeProvider = hostMethod.AddLocal<IAroundInvokeProvider>();
          _returnValue = hostMethod.AddLocal<object>();
      }

      /// <summary>
      /// Replaces a single <c>call</c>/<c>callvirt</c> instruction with the interception sequence.
      /// </summary>
      /// <param name="oldInstruction">The call being replaced; it is re-emitted as the fallback path.</param>
      /// <param name="hostMethod">The method whose body contains <paramref name="oldInstruction"/>.</param>
      /// <param name="IL">The processor used to emit the replacement instructions.</param>
      protected override void Replace(Instruction oldInstruction, MethodDefinition hostMethod,
                                       ILProcessor IL)
      {
          var targetMethod = (MethodReference) oldInstruction.Operand;
          var callOriginalMethod = IL.Create(OpCodes.Nop);
          var returnType = targetMethod.ReturnType;
          var endLabel = IL.Create(OpCodes.Nop);
          var module = hostMethod.DeclaringType.Module;

          // Create the stack that will hold the method arguments
          IL.Emit(OpCodes.Newobj, _stackCtor);
          IL.Emit(OpCodes.Stloc, _currentArguments);

          // Make sure that the argument stack doesn't show up in
          // any of the other interception routines
          IgnoreLocal(IL, _currentArguments, module);

          // Move the call's arguments and target off the evaluation stack and build the InvocationInfo
          SaveInvocationInfo(IL, targetMethod, module, returnType);

          var getInterceptionDisabled = new GetInterceptionDisabled(hostMethod, _interceptionDisabled);
          getInterceptionDisabled.Emit(IL);

          var surroundMethodBody = new SurroundMethodBody(_methodReplacementProvider, _aroundInvokeProvider,
                                                          _invocationInfo, _interceptionDisabled, _returnValue,
                                                          typeof(AroundInvokeMethodCallRegistry),
                                                          "AroundMethodCallProvider");
          surroundMethodBody.AddProlog(IL);

          // Either run a replacement interceptor or fall back to the original call;
          // both paths converge on endLabel.
          Replace(IL, oldInstruction, targetMethod, hostMethod, endLabel, callOriginalMethod);

          IL.Append(endLabel);
          surroundMethodBody.AddEpilog(IL);
      }

      /// <summary>
      /// Emits <c>IgnoredInstancesRegistry.AddInstance(targetVariable)</c> so the object stored in
      /// <paramref name="targetVariable"/> is registered as an ignored instance.
      /// </summary>
      private void IgnoreLocal(ILProcessor IL, VariableDefinition targetVariable, ModuleDefinition module)
      {
          IL.Emit(OpCodes.Ldloc, targetVariable);
          var addInstance = module.Import(typeof(IgnoredInstancesRegistry).GetMethod("AddInstance"));
          IL.Emit(OpCodes.Call, addInstance);
      }

      /// <summary>
      /// Emits the replacement-provider lookup and the two outcomes of a call site.
      /// Either the replacement interceptor is invoked and control jumps to <paramref name="endLabel"/>,
      /// or the arguments are rebuilt and <paramref name="oldInstruction"/> is executed.
      /// </summary>
      private void Replace(ILProcessor IL, Instruction oldInstruction, MethodReference targetMethod,
                           MethodDefinition hostMethod, Instruction endLabel, Instruction callOriginalMethod)
      {
          var returnType = targetMethod.ReturnType;
          var module = hostMethod.DeclaringType.Module;

          // __instanceProvider = (this as IMethodReplacementHost).MethodCallReplacementProvider (when applicable)
          if (!hostMethod.IsStatic)
          {
              GetInstanceProvider(IL);
          }

          // If all else fails, use the static method replacement provider:
          // __staticProvider = MethodCallReplacementProviderRegistry.GetProvider(this-or-null, info);
          var pushInstance = hostMethod.HasThis ? IL.Create(OpCodes.Ldarg_0) : IL.Create(OpCodes.Ldnull);
          IL.Append(pushInstance);
          IL.Emit(OpCodes.Ldloc, _invocationInfo);
          IL.Emit(OpCodes.Call, _getStaticProvider);
          IL.Emit(OpCodes.Stloc, _staticProvider);

          var restoreArgumentStack = IL.Create(OpCodes.Nop);
          var callReplacement = IL.Create(OpCodes.Nop);
          var useStaticProvider = IL.Create(OpCodes.Nop);

          // Instance provider first:
          // if (__instanceProvider != null && __instanceProvider.CanReplace(host, info)
          //     && (__interceptor = __instanceProvider.GetMethodReplacement(host, info)) != null)
          //     goto callReplacement;
          IL.Emit(OpCodes.Ldloc, _instanceProvider);
          IL.Emit(OpCodes.Brfalse, useStaticProvider);
          EmitCanReplace(IL, hostMethod, _instanceProvider);
          IL.Emit(OpCodes.Ldloc, _canReplaceFlag);
          IL.Emit(OpCodes.Brfalse, useStaticProvider);
          EmitGetMethodReplacement(IL, hostMethod, _instanceProvider);
          IL.Emit(OpCodes.Ldloc, _replacement);
          IL.Emit(OpCodes.Brtrue, callReplacement);

          // Static provider next. EmitCanReplace leaves the flag false when the provider is null.
          // if (!__staticProvider.CanReplace(host, info)) goto restoreArgumentStack;
          IL.Append(useStaticProvider);
          EmitCanReplace(IL, hostMethod, _staticProvider);
          IL.Emit(OpCodes.Ldloc, _canReplaceFlag);
          IL.Emit(OpCodes.Brfalse, restoreArgumentStack);
          EmitGetMethodReplacement(IL, hostMethod, _staticProvider);

          // if (__interceptor == null) goto restoreArgumentStack;
          IL.Append(callReplacement);
          IL.Emit(OpCodes.Ldloc, _replacement);
          IL.Emit(OpCodes.Brfalse, restoreArgumentStack);

          // Invoke the interceptor, hand its result to PackageReturnValue for the call's return type,
          // and skip the original call.
          EmitInterceptorCall(IL);
          IL.PackageReturnValue(module, returnType);
          IL.Emit(OpCodes.Br, endLabel);

          // No interceptor: push the target and arguments back onto the evaluation stack...
          IL.Append(restoreArgumentStack);
          ReconstructMethodArguments(IL, targetMethod);

          // ...and execute the original call instruction.
          IL.Append(callOriginalMethod);
          IL.Append(oldInstruction);
      }

      /// <summary>
      /// Emits <c>if (this is IMethodReplacementHost) __instanceProvider = host.MethodCallReplacementProvider;</c>.
      /// When the host does not implement the interface, <c>__instanceProvider</c> is left untouched.
      /// </summary>
      private void GetInstanceProvider(ILProcessor IL)
      {
          var skipInstanceProvider = IL.Create(OpCodes.Nop);

          IL.Emit(OpCodes.Ldarg_0);
          IL.Emit(OpCodes.Isinst, _hostInterfaceType);
          IL.Emit(OpCodes.Brfalse, skipInstanceProvider);

          IL.Emit(OpCodes.Ldarg_0);
          IL.Emit(OpCodes.Isinst, _hostInterfaceType);
          IL.Emit(OpCodes.Callvirt, _getProvider);
          IL.Emit(OpCodes.Stloc, _instanceProvider);

          IL.Append(skipInstanceProvider);
      }

      /// <summary>
      /// Pushes the saved target (instance calls only) and the saved arguments back onto the
      /// evaluation stack in declaration order, ready for the original call instruction.
      /// </summary>
      private void ReconstructMethodArguments(ILProcessor IL, MethodReference targetMethod)
      {
          if (targetMethod.HasThis)
          {
              IL.Emit(OpCodes.Ldloc, _target);
          }

          // SaveInvocationInfo pushes the last argument first, so the first Pop yields the first argument.
          // For reference types, unbox.any behaves like castclass.
          foreach (ParameterReference param in targetMethod.Parameters)
          {
              IL.Emit(OpCodes.Ldloc, _currentArguments);
              IL.Emit(OpCodes.Callvirt, _popMethod);
              IL.Emit(OpCodes.Unbox_Any, param.ParameterType);
          }
      }

      /// <summary>
      /// Moves the pending call's arguments and target off the evaluation stack into
      /// <c>__arguments</c> / <c>__target</c>, then constructs the InvocationInfo for the call and stores it
      /// in <c>___invocationInfo</c>.
      /// </summary>
      private void SaveInvocationInfo(ILProcessor IL, MethodReference targetMethod, ModuleDefinition module,
                                      TypeReference returnType)
      {
          // The arguments sit on the evaluation stack with the LAST argument on top, so walk the parameters
          // from last to first. That way each box/stloc uses the type of the value actually being consumed.
          // (Walking forward would box with the wrong parameter's type whenever value and reference types
          // are mixed.) The Stack<object> ends up with the first argument on top, so Pop and ToArray
          // both yield declaration order.
          var parameters = targetMethod.Parameters;
          for (var index = parameters.Count - 1; index >= 0; index--)
          {
              var parameterType = parameters[index].ParameterType;
              if (parameterType.IsValueType || parameterType is GenericParameter)
              {
                  IL.Emit(OpCodes.Box, parameterType);
              }

              IL.Emit(OpCodes.Stloc, _currentArgument);
              IL.Emit(OpCodes.Ldloc, _currentArguments);
              IL.Emit(OpCodes.Ldloc, _currentArgument);
              IL.Emit(OpCodes.Callvirt, _pushMethod);
          }

          // Static methods will always have a null reference as the target
          if (!targetMethod.HasThis)
          {
              IL.Emit(OpCodes.Ldnull);
          }

          // Box the target, if necessary.
          // NOTE(review): for a value-type declaring type, the call site normally holds a managed pointer here
          // (from ldloca/ldarga), not a value. Confirm that 'box' is valid in that case, and that reloading
          // the boxed __target in ReconstructMethodArguments matches what the original call expects.
          var declaringType = targetMethod.GetDeclaringType();
          if (targetMethod.HasThis && (declaringType.IsValueType || declaringType is GenericParameter))
          {
              IL.Emit(OpCodes.Box, declaringType);
          }

          // Save the target and leave it on the stack as the first constructor argument
          IL.Emit(OpCodes.Stloc, _target);
          IL.Emit(OpCodes.Ldloc, _target);

          // Push the current method
          IL.PushMethod(targetMethod, module);

          // Push the stack trace
          PushStackTrace(IL, module);

          var systemType = module.Import(typeof(Type));

          // Save the parameter types
          var parameterCount = targetMethod.Parameters.Count;
          IL.Emit(OpCodes.Ldc_I4, parameterCount);
          IL.Emit(OpCodes.Newarr, systemType);
          IL.Emit(OpCodes.Stloc, _parameterTypes);
          IL.SaveParameterTypes(targetMethod, module, _parameterTypes);
          IL.Emit(OpCodes.Ldloc, _parameterTypes);

          // Save the type arguments
          var genericParameterCount = targetMethod.GenericParameters.Count;
          IL.Emit(OpCodes.Ldc_I4, genericParameterCount);
          IL.Emit(OpCodes.Newarr, systemType);
          IL.Emit(OpCodes.Stloc, _typeArguments);
          IL.PushGenericArguments(targetMethod, module, _typeArguments);
          IL.Emit(OpCodes.Ldloc, _typeArguments);

          // Push the return type
          IL.PushType(returnType, module);

          // Save the method arguments
          IL.Emit(OpCodes.Ldloc, _currentArguments);
          IL.Emit(OpCodes.Callvirt, _toArray);

          IL.Emit(OpCodes.Newobj, _invocationInfoCtor);
          IL.Emit(OpCodes.Stloc, _invocationInfo);

          IgnoreLocal(IL, _invocationInfo, module);
      }

      /// <summary>
      /// Emits the instructions that push the current stack trace (delegates to the IL extension).
      /// </summary>
      private void PushStackTrace(ILProcessor IL, ModuleDefinition module)
      {
          IL.PushStackTrace(module);
      }

      /// <summary>
      /// Emits <c>__interceptor.Intercept(___invocationInfo)</c>, leaving its result on the evaluation stack.
      /// </summary>
      private void EmitInterceptorCall(ILProcessor IL)
      {
          IL.Emit(OpCodes.Ldloc, _replacement);
          IL.Emit(OpCodes.Ldloc, _invocationInfo);
          IL.Emit(OpCodes.Callvirt, _intercept);
      }

      /// <summary>
      /// Emits <c>__canReplace = provider != null &amp;&amp; provider.CanReplace(host, ___invocationInfo);</c>.
      /// The host argument is <c>this</c> for instance host methods and <c>null</c> otherwise.
      /// </summary>
      private void EmitCanReplace(ILProcessor IL, IMethodSignature hostMethod, VariableDefinition provider)
      {
          // Reset the flag first. The local is shared by the instance and static checks at this call site
          // and by every other rewritten call site in the host method. A null provider must therefore not
          // leave a stale 'true' behind; otherwise the caller would go on to invoke GetMethodReplacement
          // on the null provider.
          IL.Emit(OpCodes.Ldc_I4_0);
          IL.Emit(OpCodes.Stloc, _canReplaceFlag);

          var skipCanReplace = IL.Create(OpCodes.Nop);
          IL.Emit(OpCodes.Ldloc, provider);
          IL.Emit(OpCodes.Brfalse, skipCanReplace);

          IL.Emit(OpCodes.Ldloc, provider);

          // Push the host instance
          var pushInstance = hostMethod.HasThis ? IL.Create(OpCodes.Ldarg_0) : IL.Create(OpCodes.Ldnull);
          IL.Append(pushInstance);
          IL.Emit(OpCodes.Ldloc, _invocationInfo);
          IL.Emit(OpCodes.Callvirt, _canReplace);
          IL.Emit(OpCodes.Stloc, _canReplaceFlag);

          IL.Append(skipCanReplace);
      }

      /// <summary>
      /// Emits <c>__interceptor = provider.GetMethodReplacement(host, ___invocationInfo);</c>.
      /// The caller must only reach this code when <paramref name="provider"/> is non-null.
      /// </summary>
      private void EmitGetMethodReplacement(ILProcessor IL, IMethodSignature hostMethod, VariableDefinition provider)
      {
          IL.Emit(OpCodes.Ldloc, provider);

          // Push the host instance
          var pushInstance = hostMethod.HasThis ? IL.Create(OpCodes.Ldarg_0) : IL.Create(OpCodes.Ldnull);
          IL.Append(pushInstance);
          IL.Emit(OpCodes.Ldloc, _invocationInfo);
          IL.Emit(OpCodes.Callvirt, _getReplacement);
          IL.Emit(OpCodes.Stloc, _replacement);
      }

      /// <summary>
      /// Selects <c>call</c>/<c>callvirt</c> instructions that the configured call filter accepts.
      /// </summary>
      protected override bool ShouldReplace(Instruction oldInstruction, MethodDefinition hostMethod)
      {
          // Intercept only the call and callvirt instructions
          var opCode = oldInstruction.OpCode;
          if (opCode != OpCodes.Callvirt && opCode != OpCodes.Call)
          {
              return false;
          }

          var targetMethod = (MethodReference) oldInstruction.Operand;
          return _callFilter.ShouldWeave(hostMethod.DeclaringType, hostMethod, targetMethod);
      }

      /// <summary>
      /// Accepts every method that has a body; the per-instruction filtering happens in ShouldReplace.
      /// </summary>
      public bool ShouldWeave(MethodDefinition item)
      {
          return item.HasBody;
      }

      /// <summary>
      /// Rewrites <paramref name="item"/>, passing a snapshot (array copy) of its current instructions.
      /// </summary>
      public void Weave(MethodDefinition item)
      {
          Rewrite(item, item.GetILGenerator(), item.Body.Instructions.ToArray());
      }
  }
  313. }