PageRenderTime 55ms CodeModel.GetById 23ms RepoModel.GetById 0ms app.codeStats 0ms

/mcs/class/referencesource/System.Data.Linq/SqlClient/Query/SqlExpander.cs

http://github.com/mono/mono
C# | 349 lines | 306 code | 25 blank | 18 comment | 81 complexity | 409dca123fcb759e071498b8f3a52f1d MD5 | raw file
Possible License(s): GPL-2.0, CC-BY-SA-3.0, LGPL-2.0, MPL-2.0-no-copyleft-exception, LGPL-2.1, Unlicense, Apache-2.0
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq.Expressions;
  4. using System.Reflection;
  5. using System.Data.Linq;
  6. using System.Data.Linq.Mapping;
  7. using System.Data.Linq.Provider;
  8. using System.Linq;
  9. using System.Data.Linq.SqlClient;
  10. using System.Diagnostics.CodeAnalysis;
  11. namespace System.Data.Linq.SqlClient {
  12. // duplicates an expression up until a column or column ref is encountered
  13. // goes 'deep' through alias ref's
  14. // assumes that columnizing has been done already
  15. internal class SqlExpander {
  16. SqlFactory factory;
  17. internal SqlExpander(SqlFactory factory) {
  18. this.factory = factory;
  19. }
  20. internal SqlExpression Expand(SqlExpression exp) {
  21. return (new Visitor(this.factory)).VisitExpression(exp);
  22. }
  23. class Visitor : SqlDuplicator.DuplicatingVisitor {
  24. SqlFactory factory;
  25. Expression sourceExpression;
  26. internal Visitor(SqlFactory factory)
  27. : base(true) {
  28. this.factory = factory;
  29. }
  30. internal override SqlExpression VisitColumnRef(SqlColumnRef cref) {
  31. return cref;
  32. }
  33. internal override SqlExpression VisitColumn(SqlColumn col) {
  34. return new SqlColumnRef(col);
  35. }
  36. internal override SqlExpression VisitSharedExpression(SqlSharedExpression shared) {
  37. return this.VisitExpression(shared.Expression);
  38. }
  39. internal override SqlExpression VisitSharedExpressionRef(SqlSharedExpressionRef sref) {
  40. return this.VisitExpression(sref.SharedExpression.Expression);
  41. }
  42. internal override SqlExpression VisitAliasRef(SqlAliasRef aref) {
  43. SqlNode node = aref.Alias.Node;
  44. if (node is SqlTable || node is SqlTableValuedFunctionCall) {
  45. return aref;
  46. }
  47. SqlUnion union = node as SqlUnion;
  48. if (union != null) {
  49. return this.ExpandUnion(union);
  50. }
  51. SqlSelect ss = node as SqlSelect;
  52. if (ss != null) {
  53. return this.VisitExpression(ss.Selection);
  54. }
  55. SqlExpression exp = node as SqlExpression;
  56. if (exp != null)
  57. return this.VisitExpression(exp);
  58. throw Error.CouldNotHandleAliasRef(node.NodeType);
  59. }
  60. internal override SqlExpression VisitSubSelect(SqlSubSelect ss) {
  61. return (SqlExpression)new SqlDuplicator().Duplicate(ss);
  62. }
  63. internal override SqlNode VisitLink(SqlLink link) {
  64. SqlExpression expansion = this.VisitExpression(link.Expansion);
  65. SqlExpression[] exprs = new SqlExpression[link.KeyExpressions.Count];
  66. for (int i = 0, n = exprs.Length; i < n; i++) {
  67. exprs[i] = this.VisitExpression(link.KeyExpressions[i]);
  68. }
  69. return new SqlLink(link.Id, link.RowType, link.ClrType, link.SqlType, link.Expression, link.Member, exprs, expansion, link.SourceExpression);
  70. }
  71. private SqlExpression ExpandUnion(SqlUnion union) {
  72. List<SqlExpression> exprs = new List<SqlExpression>(2);
  73. this.GatherUnionExpressions(union, exprs);
  74. this.sourceExpression = union.SourceExpression;
  75. SqlExpression result = this.ExpandTogether(exprs);
  76. return result;
  77. }
  78. private void GatherUnionExpressions(SqlNode node, List<SqlExpression> exprs) {
  79. SqlUnion union = node as SqlUnion;
  80. if (union != null) {
  81. this.GatherUnionExpressions(union.Left, exprs);
  82. this.GatherUnionExpressions(union.Right, exprs);
  83. }
  84. else {
  85. SqlSelect sel = node as SqlSelect;
  86. if (sel != null) {
  87. SqlAliasRef aref = sel.Selection as SqlAliasRef;
  88. if (aref != null) {
  89. this.GatherUnionExpressions(aref.Alias.Node, exprs);
  90. }
  91. else {
  92. exprs.Add(sel.Selection);
  93. }
  94. }
  95. }
  96. }
  97. [SuppressMessage("Microsoft.Performance", "CA1809:AvoidExcessiveLocals", Justification="These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
  98. [SuppressMessage("Microsoft.Maintainability", "CA1505:AvoidUnmaintainableCode", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
  99. [SuppressMessage("Microsoft.Maintainability", "CA1502:AvoidExcessiveComplexity", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
  100. private SqlExpression ExpandTogether(List<SqlExpression> exprs) {
  101. switch (exprs[0].NodeType) {
  102. case SqlNodeType.MethodCall: {
  103. SqlMethodCall[] mcs = new SqlMethodCall[exprs.Count];
  104. for (int i = 0; i < mcs.Length; ++i) {
  105. mcs[i] = (SqlMethodCall)exprs[i];
  106. }
  107. List<SqlExpression> expandedArgs = new List<SqlExpression>();
  108. for (int i = 0; i < mcs[0].Arguments.Count; ++i) {
  109. List<SqlExpression> args = new List<SqlExpression>();
  110. for (int j = 0; j < mcs.Length; ++j) {
  111. args.Add(mcs[j].Arguments[i]);
  112. }
  113. SqlExpression expanded = this.ExpandTogether(args);
  114. expandedArgs.Add(expanded);
  115. }
  116. return factory.MethodCall(mcs[0].Method, mcs[0].Object, expandedArgs.ToArray(), mcs[0].SourceExpression);
  117. }
  118. case SqlNodeType.ClientCase: {
  119. // Are they all the same?
  120. SqlClientCase[] scs = new SqlClientCase[exprs.Count];
  121. scs[0] = (SqlClientCase)exprs[0];
  122. for (int i = 1; i < scs.Length; ++i) {
  123. scs[i] = (SqlClientCase)exprs[i];
  124. }
  125. // Expand expressions together.
  126. List<SqlExpression> expressions = new List<SqlExpression>();
  127. for (int i = 0; i < scs.Length; ++i) {
  128. expressions.Add(scs[i].Expression);
  129. }
  130. SqlExpression expression = this.ExpandTogether(expressions);
  131. // Expand individual expressions together.
  132. List<SqlClientWhen> whens = new List<SqlClientWhen>();
  133. for (int i = 0; i < scs[0].Whens.Count; ++i) {
  134. List<SqlExpression> scos = new List<SqlExpression>();
  135. for (int j = 0; j < scs.Length; ++j) {
  136. SqlClientWhen when = scs[j].Whens[i];
  137. scos.Add(when.Value);
  138. }
  139. whens.Add(new SqlClientWhen(scs[0].Whens[i].Match, this.ExpandTogether(scos)));
  140. }
  141. return new SqlClientCase(scs[0].ClrType, expression, whens, scs[0].SourceExpression);
  142. }
  143. case SqlNodeType.TypeCase: {
  144. // Are they all the same?
  145. SqlTypeCase[] tcs = new SqlTypeCase[exprs.Count];
  146. tcs[0] = (SqlTypeCase)exprs[0];
  147. for (int i = 1; i < tcs.Length; ++i) {
  148. tcs[i] = (SqlTypeCase)exprs[i];
  149. }
  150. // Expand discriminators together.
  151. List<SqlExpression> discriminators = new List<SqlExpression>();
  152. for (int i = 0; i < tcs.Length; ++i) {
  153. discriminators.Add(tcs[i].Discriminator);
  154. }
  155. SqlExpression discriminator = this.ExpandTogether(discriminators);
  156. // Write expanded discriminators back in.
  157. for (int i = 0; i < tcs.Length; ++i) {
  158. tcs[i].Discriminator = discriminators[i];
  159. }
  160. // Expand individual type bindings together.
  161. List<SqlTypeCaseWhen> whens = new List<SqlTypeCaseWhen>();
  162. for (int i = 0; i < tcs[0].Whens.Count; ++i) {
  163. List<SqlExpression> scos = new List<SqlExpression>();
  164. for (int j = 0; j < tcs.Length; ++j) {
  165. SqlTypeCaseWhen when = tcs[j].Whens[i];
  166. scos.Add(when.TypeBinding);
  167. }
  168. SqlExpression expanded = this.ExpandTogether(scos);
  169. whens.Add(new SqlTypeCaseWhen(tcs[0].Whens[i].Match, expanded));
  170. }
  171. return factory.TypeCase(tcs[0].ClrType, tcs[0].RowType, discriminator, whens, tcs[0].SourceExpression);
  172. }
  173. case SqlNodeType.New: {
  174. // first verify all are similar client objects...
  175. SqlNew[] cobs = new SqlNew[exprs.Count];
  176. cobs[0] = (SqlNew)exprs[0];
  177. for (int i = 1, n = exprs.Count; i < n; i++) {
  178. if (exprs[i] == null || exprs[i].NodeType != SqlNodeType.New)
  179. throw Error.UnionIncompatibleConstruction();
  180. cobs[i] = (SqlNew)exprs[1];
  181. if (cobs[i].Members.Count != cobs[0].Members.Count)
  182. throw Error.UnionDifferentMembers();
  183. for (int m = 0, mn = cobs[0].Members.Count; m < mn; m++) {
  184. if (cobs[i].Members[m].Member != cobs[0].Members[m].Member) {
  185. throw Error.UnionDifferentMemberOrder();
  186. }
  187. }
  188. }
  189. SqlMemberAssign[] bindings = new SqlMemberAssign[cobs[0].Members.Count];
  190. for (int m = 0, mn = bindings.Length; m < mn; m++) {
  191. List<SqlExpression> mexprs = new List<SqlExpression>();
  192. for (int i = 0, n = exprs.Count; i < n; i++) {
  193. mexprs.Add(cobs[i].Members[m].Expression);
  194. }
  195. bindings[m] = new SqlMemberAssign(cobs[0].Members[m].Member, this.ExpandTogether(mexprs));
  196. for (int i = 0, n = exprs.Count; i < n; i++) {
  197. cobs[i].Members[m].Expression = mexprs[i];
  198. }
  199. }
  200. SqlExpression[] arguments = new SqlExpression[cobs[0].Args.Count];
  201. for (int m = 0, mn = arguments.Length; m < mn; ++m) {
  202. List<SqlExpression> mexprs = new List<SqlExpression>();
  203. for (int i = 0, n = exprs.Count; i < n; i++) {
  204. mexprs.Add(cobs[i].Args[m]);
  205. }
  206. arguments[m] = ExpandTogether(mexprs);
  207. }
  208. return factory.New(cobs[0].MetaType, cobs[0].Constructor, arguments, cobs[0].ArgMembers, bindings, exprs[0].SourceExpression);
  209. }
  210. case SqlNodeType.Link: {
  211. SqlLink[] links = new SqlLink[exprs.Count];
  212. links[0] = (SqlLink)exprs[0];
  213. for (int i = 1, n = exprs.Count; i < n; i++) {
  214. if (exprs[i] == null || exprs[i].NodeType != SqlNodeType.Link)
  215. throw Error.UnionIncompatibleConstruction();
  216. links[i] = (SqlLink)exprs[i];
  217. if (links[i].KeyExpressions.Count != links[0].KeyExpressions.Count ||
  218. links[i].Member != links[0].Member ||
  219. (links[i].Expansion != null) != (links[0].Expansion != null))
  220. throw Error.UnionIncompatibleConstruction();
  221. }
  222. SqlExpression[] kexprs = new SqlExpression[links[0].KeyExpressions.Count];
  223. List<SqlExpression> lexprs = new List<SqlExpression>();
  224. for (int k = 0, nk = links[0].KeyExpressions.Count; k < nk; k++) {
  225. lexprs.Clear();
  226. for (int i = 0, n = exprs.Count; i < n; i++) {
  227. lexprs.Add(links[i].KeyExpressions[k]);
  228. }
  229. kexprs[k] = this.ExpandTogether(lexprs);
  230. for (int i = 0, n = exprs.Count; i < n; i++) {
  231. links[i].KeyExpressions[k] = lexprs[i];
  232. }
  233. }
  234. SqlExpression expansion = null;
  235. if (links[0].Expansion != null) {
  236. lexprs.Clear();
  237. for (int i = 0, n = exprs.Count; i < n; i++) {
  238. lexprs.Add(links[i].Expansion);
  239. }
  240. expansion = this.ExpandTogether(lexprs);
  241. for (int i = 0, n = exprs.Count; i < n; i++) {
  242. links[i].Expansion = lexprs[i];
  243. }
  244. }
  245. return new SqlLink(links[0].Id, links[0].RowType, links[0].ClrType, links[0].SqlType, links[0].Expression, links[0].Member, kexprs, expansion, links[0].SourceExpression);
  246. }
  247. case SqlNodeType.Value: {
  248. /*
  249. * ExprSet of all literals of the same value reduce to just a single literal.
  250. */
  251. SqlValue val0 = (SqlValue)exprs[0];
  252. for (int i = 1; i < exprs.Count; ++i) {
  253. SqlValue val = (SqlValue)exprs[i];
  254. if (!object.Equals(val.Value, val0.Value))
  255. return this.ExpandIntoExprSet(exprs);
  256. }
  257. return val0;
  258. }
  259. case SqlNodeType.OptionalValue: {
  260. if (exprs[0].SqlType.CanBeColumn) {
  261. goto default;
  262. }
  263. List<SqlExpression> hvals = new List<SqlExpression>(exprs.Count);
  264. List<SqlExpression> vals = new List<SqlExpression>(exprs.Count);
  265. for (int i = 0, n = exprs.Count; i < n; i++) {
  266. if (exprs[i] == null || exprs[i].NodeType != SqlNodeType.OptionalValue) {
  267. throw Error.UnionIncompatibleConstruction();
  268. }
  269. SqlOptionalValue sov = (SqlOptionalValue)exprs[i];
  270. hvals.Add(sov.HasValue);
  271. vals.Add(sov.Value);
  272. }
  273. return new SqlOptionalValue(this.ExpandTogether(hvals), this.ExpandTogether(vals));
  274. }
  275. case SqlNodeType.OuterJoinedValue: {
  276. if (exprs[0].SqlType.CanBeColumn) {
  277. goto default;
  278. }
  279. List<SqlExpression> values = new List<SqlExpression>(exprs.Count);
  280. for (int i = 0, n = exprs.Count; i < n; i++) {
  281. if (exprs[i] == null || exprs[i].NodeType != SqlNodeType.OuterJoinedValue) {
  282. throw Error.UnionIncompatibleConstruction();
  283. }
  284. SqlUnary su = (SqlUnary)exprs[i];
  285. values.Add(su.Operand);
  286. }
  287. return factory.Unary(SqlNodeType.OuterJoinedValue, this.ExpandTogether(values));
  288. }
  289. case SqlNodeType.DiscriminatedType: {
  290. SqlDiscriminatedType sdt0 = (SqlDiscriminatedType)exprs[0];
  291. List<SqlExpression> foos = new List<SqlExpression>(exprs.Count);
  292. foos.Add(sdt0.Discriminator);
  293. for (int i = 1, n = exprs.Count; i < n; i++) {
  294. SqlDiscriminatedType sdtN = (SqlDiscriminatedType)exprs[i];
  295. if (sdtN.TargetType != sdt0.TargetType) {
  296. throw Error.UnionIncompatibleConstruction();
  297. }
  298. foos.Add(sdtN.Discriminator);
  299. }
  300. return factory.DiscriminatedType(this.ExpandTogether(foos), ((SqlDiscriminatedType)exprs[0]).TargetType);
  301. }
  302. case SqlNodeType.ClientQuery:
  303. case SqlNodeType.Multiset:
  304. case SqlNodeType.Element:
  305. case SqlNodeType.Grouping:
  306. throw Error.UnionWithHierarchy();
  307. default:
  308. return this.ExpandIntoExprSet(exprs);
  309. }
  310. }
  311. /// <summary>
  312. /// Expand a set of expressions into a single expr set.
  313. /// This is typically a fallback when there is no other way to unify a set of expressions.
  314. /// </summary>
  315. private SqlExpression ExpandIntoExprSet(List<SqlExpression> exprs) {
  316. SqlExpression[] rexprs = new SqlExpression[exprs.Count];
  317. for (int i = 0, n = exprs.Count; i < n; i++) {
  318. rexprs[i] = this.VisitExpression(exprs[i]);
  319. }
  320. return this.factory.ExprSet(rexprs, this.sourceExpression);
  321. }
  322. }
  323. }
  324. }