PageRenderTime 61ms CodeModel.GetById 14ms RepoModel.GetById 0ms app.codeStats 0ms

/mcs/class/referencesource/System.Data.Linq/SqlClient/Query/QueryConverter.cs

http://github.com/mono/mono
C# | 2887 lines | 2342 code | 301 blank | 244 comment | 771 complexity | 74b75f6e2c6ad9ddcfb9453582b373da MD5 | raw file
Possible License(s): GPL-2.0, CC-BY-SA-3.0, LGPL-2.0, MPL-2.0-no-copyleft-exception, LGPL-2.1, Unlicense, Apache-2.0
  1. using System;
  2. using System.Globalization;
  3. using System.Collections;
  4. using System.Collections.Generic;
  5. using System.Data;
  6. using System.Reflection;
  7. using System.Text;
  8. using System.Linq;
  9. using System.Linq.Expressions;
  10. using System.Data.Linq;
  11. using System.Data.Linq.Mapping;
  12. using System.Data.Linq.Provider;
  13. using System.Collections.ObjectModel;
  14. using System.Diagnostics.CodeAnalysis;
  15. namespace System.Data.Linq.SqlClient {
  /// <summary>
  /// Placeholder types that stand in for intermediate shapes (a single row or a whole
  /// table) while a query is being built.
  /// </summary>
  internal enum ConverterSpecialTypes {
      Row = 0,
      Table = 1
  }
  /// <summary>
  /// Combinable bit flags that switch optional translation strategies on or off.
  /// For example, CanUseOuterApply permits OUTER APPLY joins and CanUseJoinOn permits
  /// join conditions in an ON clause instead of a WHERE clause.
  /// </summary>
  [Flags]
  internal enum ConverterStrategy {
      Default = 0,
      SkipWithRowNumber = 1 << 0,
      CanUseScopeIdentity = 1 << 1,
      CanUseOuterApply = 1 << 2,
      CanUseRowStatus = 1 << 3,
      CanUseJoinOn = 1 << 4, // Whether or not to use ON clause of JOIN.
      CanOutputFromInsert = 1 << 5
  }
  34. [SuppressMessage("Microsoft.Maintainability", "CA1506:AvoidExcessiveClassCoupling", Justification="Unknown reason.")]
  35. internal class QueryConverter {
  // Supplies the DataContext (Context) and mapping model (Model) used to resolve tables and types.
  IDataServices services;
  // Builds the default query for a mapped table (used by TranslateConstantTable).
  Translator translator;
  // Factory used to construct the SQL node tree.
  SqlFactory sql;
  // Maps CLR types to provider (SQL) types via From(Type).
  TypeSystemProvider typeProvider;
  // Set to true by ConvertOuter for the root expression. Visit and ConvertInner clear it
  // while converting nested expressions and then restore the previous value.
  bool outerNode;
  // Lambda parameter -> SQL expression it stands for (e.g. the alias reference of the
  // source sequence in Select/SelectMany/Where/Join).
  Dictionary<ParameterExpression, SqlExpression> map;
  // Lambda parameter -> argument expression; filled when an invoked lambda is inlined.
  Dictionary<ParameterExpression, Expression> exprMap;
  // Lambda parameter -> SQL node substituted for it (client parameters for parameterized
  // lambdas, the multiset sub-select for a GroupJoin result parameter).
  Dictionary<ParameterExpression, SqlNode> dupMap;
  // Keyed by SqlNode; not used in this section of the file (presumably GroupBy
  // bookkeeping, judging by GroupInfo's field names; verify).
  Dictionary<SqlNode, GroupInfo> gmap;
  // Expression attached as the source expression of newly created nodes and used for
  // user-facing error text; maintained by VisitInner/ChooseBestDominatingExpression.
  Expression dominatingExpression;
  // Passed through to Translator.BuildDefaultQuery; initialized to true by the constructor.
  bool allowDeferred;
  // Optional-strategy flags, exposed through the ConverterStrategy property.
  ConverterStrategy converterStrategy = ConverterStrategy.Default;
  // Record type stored as the value in gmap (Dictionary<SqlNode, GroupInfo>).
  // NOTE(review): it is populated outside this section; the field roles below are
  // inferred from their names only.
  class GroupInfo {
      // Presumably the SELECT that carries the GROUP BY.
      internal SqlSelect SelectWithGroup;
      // Presumably the group element expression, expressed over the group's source.
      internal SqlExpression ElementOnGroupSource;
  }
  /// <summary>
  /// Gets or sets the strategy flags that enable optional SQL constructs during conversion.
  /// </summary>
  internal ConverterStrategy ConverterStrategy {
      get {
          ConverterStrategy current = this.converterStrategy;
          return current;
      }
      set {
          this.converterStrategy = value;
      }
  }
  /// <summary>
  /// Returns true when every bit set in <paramref name="strategy"/> is also set in the
  /// current strategy flags. This is trivially true for ConverterStrategy.Default.
  /// </summary>
  private bool UseConverterStrategy(ConverterStrategy strategy) {
      ConverterStrategy enabledBits = strategy & this.converterStrategy;
      return enabledBits == strategy;
  }
  /// <summary>
  /// Creates a converter that uses the given data services, type provider, translator
  /// and SQL factory. Every argument is required; a null argument is reported through
  /// Error.ArgumentNull.
  /// </summary>
  internal QueryConverter(IDataServices services, TypeSystemProvider typeProvider, Translator translator, SqlFactory sql) {
      // Arguments are checked in this fixed order so the reported name stays stable.
      if (services == null) {
          throw Error.ArgumentNull("services");
      }
      if (sql == null) {
          throw Error.ArgumentNull("sql");
      }
      if (translator == null) {
          throw Error.ArgumentNull("translator");
      }
      if (typeProvider == null) {
          throw Error.ArgumentNull("typeProvider");
      }

      // Mutable conversion state.
      this.allowDeferred = true;
      this.gmap = new Dictionary<SqlNode, GroupInfo>();
      this.dupMap = new Dictionary<ParameterExpression, SqlNode>();
      this.exprMap = new Dictionary<ParameterExpression, Expression>();
      this.map = new Dictionary<ParameterExpression, SqlExpression>();

      // Collaborators.
      this.typeProvider = typeProvider;
      this.sql = sql;
      this.translator = translator;
      this.services = services;
  }
  /// <summary>
  /// Converts the root (outermost) expression of a query into a SQL node tree.
  /// ITable-typed roots are coerced to a sequence. A root that converts to a bare method
  /// call is rejected, because it can only be a mapped stored procedure or function.
  /// A scalar result is wrapped in a SELECT. The final tree is wrapped in a SqlIncludeScope.
  /// </summary>
  /// <param name="node">The root expression to convert.</param>
  /// <returns>The converted SQL query.</returns>
  internal SqlNode ConvertOuter(Expression node) {
      this.dominatingExpression = node;
      this.outerNode = true;

      SqlNode converted = typeof(ITable).IsAssignableFrom(node.Type)
          ? (SqlNode)this.VisitSequence(node)
          : this.VisitInner(node);

      if (converted.NodeType == SqlNodeType.MethodCall) {
          // A tree made of a single method call must be a mapped stored procedure or
          // function, which cannot be executed as a query here.
          SqlMethodCall call = (SqlMethodCall)converted;
          throw Error.InvalidMethodExecution(call.Method.Name);
      }

      // A scalar expression result has to be wrapped in a SELECT to form a query.
      SqlExpression scalar = converted as SqlExpression;
      SqlNode query = (scalar == null)
          ? converted
          : new SqlSelect(scalar, null, this.dominatingExpression);
      return new SqlIncludeScope(query, this.dominatingExpression);
  }
  /// <summary>
  /// Converts a nested (non-root) expression. The outerNode flag is cleared while the
  /// expression is converted and restored afterwards.
  /// </summary>
  /// <param name="node">The expression to convert; null yields null.</param>
  /// <returns>The converted SQL node.</returns>
  internal SqlNode Visit(Expression node) {
      bool savedOuterNode = this.outerNode;
      this.outerNode = false;
      try {
          return this.VisitInner(node);
      }
      finally {
          // Restore in finally so a failed conversion cannot leave a stale outerNode
          // value behind. This mirrors how VisitInner restores dominatingExpression.
          this.outerNode = savedOuterNode;
      }
  }
  /// <summary>
  /// Converts a nested expression from outside the normal visit (the caller supplies the
  /// dominating expression). The dominating expression is left set to
  /// <paramref name="dominantExpression"/> afterwards. The outerNode flag is cleared during
  /// conversion and restored afterwards, even if conversion throws.
  /// </summary>
  /// <param name="node">The expression to convert.</param>
  /// <param name="dominantExpression">Current dominating expression, used for producing meaningful exception text.</param>
  /// <returns>The converted SQL query.</returns>
  internal SqlNode ConvertInner(Expression node, Expression dominantExpression) {
      this.dominatingExpression = dominantExpression;
      bool savedOuterNode = this.outerNode;
      this.outerNode = false;
      try {
          return this.VisitInner(node);
      }
      finally {
          // Restore even on failure, consistent with VisitInner's try/finally.
          this.outerNode = savedOuterNode;
      }
  }
  /// <summary>
  /// Central dispatcher: converts one expression node by switching on node.NodeType and
  /// calling the matching Visit* method. It returns null for a null node.
  /// For the duration of the call, dominatingExpression is replaced by
  /// ChooseBestDominatingExpression(current, node). The finally block restores the
  /// previous value.
  /// Unsupported node kinds throw via Error.UnrecognizedExpressionNode or
  /// Error.UnsupportedNodeType.
  /// </summary>
  [SuppressMessage("Microsoft.Performance", "CA1800:DoNotCastUnnecessarily", Justification = "Microsoft: Cast is dependent on node type and casts do not happen unecessarily in a single code path.")]
  [SuppressMessage("Microsoft.Maintainability", "CA1502:AvoidExcessiveComplexity", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
  private SqlNode VisitInner(Expression node) {
      if (node == null) return null;
      // Remember the caller's dominating expression; restored in finally below.
      Expression save = this.dominatingExpression;
      this.dominatingExpression = ChooseBestDominatingExpression(this.dominatingExpression, node);
      try {
          switch (node.NodeType) {
              case ExpressionType.New:
                  return this.VisitNew((NewExpression)node);
              case ExpressionType.MemberInit:
                  return this.VisitMemberInit((MemberInitExpression)node);
              case ExpressionType.Negate:
              case ExpressionType.NegateChecked:
              case ExpressionType.Not:
                  return this.VisitUnary((UnaryExpression)node);
              case ExpressionType.UnaryPlus:
                  // Unary plus is only accepted on TimeSpan-typed nodes.
                  if (node.Type == typeof(TimeSpan))
                      return this.VisitUnary((UnaryExpression)node);
                  throw Error.UnrecognizedExpressionNode(node.NodeType);
              case ExpressionType.Add:
              case ExpressionType.AddChecked:
              case ExpressionType.Subtract:
              case ExpressionType.SubtractChecked:
              case ExpressionType.Multiply:
              case ExpressionType.MultiplyChecked:
              case ExpressionType.Divide:
              case ExpressionType.Modulo:
              case ExpressionType.And:
              case ExpressionType.AndAlso:
              case ExpressionType.Or:
              case ExpressionType.OrElse:
              case ExpressionType.Power:
              case ExpressionType.LessThan:
              case ExpressionType.LessThanOrEqual:
              case ExpressionType.GreaterThan:
              case ExpressionType.GreaterThanOrEqual:
              case ExpressionType.Equal:
              case ExpressionType.NotEqual:
              case ExpressionType.Coalesce:
              case ExpressionType.ExclusiveOr:
                  return this.VisitBinary((BinaryExpression)node);
              case ExpressionType.ArrayIndex:
                  return this.VisitArrayIndex((BinaryExpression)node);
              case ExpressionType.TypeIs:
                  return this.VisitTypeBinary((TypeBinaryExpression)node);
              case ExpressionType.Convert:
              case ExpressionType.ConvertChecked:
                  return this.VisitCast((UnaryExpression)node);
              case ExpressionType.TypeAs:
                  return this.VisitAs((UnaryExpression)node);
              case ExpressionType.Conditional:
                  return this.VisitConditional((ConditionalExpression)node);
              case ExpressionType.Constant:
                  return this.VisitConstant((ConstantExpression)node);
              case ExpressionType.Parameter:
                  return this.VisitParameter((ParameterExpression)node);
              case ExpressionType.MemberAccess:
                  return this.VisitMemberAccess((MemberExpression)node);
              case ExpressionType.Call:
                  return this.VisitMethodCall((MethodCallExpression)node);
              case ExpressionType.ArrayLength:
                  return this.VisitArrayLength((UnaryExpression)node);
              case ExpressionType.NewArrayInit:
                  return this.VisitNewArrayInit((NewArrayExpression)node);
              case ExpressionType.ListInit:
                  return this.VisitListInit((ListInitExpression)node);
              case ExpressionType.Quote:
                  // Strip the quote and convert the operand. Visit (not VisitInner) is used,
                  // so outerNode is cleared for the operand.
                  return this.Visit(((UnaryExpression)node).Operand);
              case ExpressionType.Invoke:
                  return this.VisitInvocation((InvocationExpression)node);
              case ExpressionType.Lambda:
                  return this.VisitLambda((LambdaExpression)node);
              case ExpressionType.RightShift:
              case ExpressionType.LeftShift:
                  // Shift operators are explicitly reported as unsupported.
                  throw Error.UnsupportedNodeType(node.NodeType);
              case (ExpressionType)InternalExpressionType.Known:
                  // Internal wrapper around an already-converted SqlNode: return it as-is.
                  return ((KnownExpression)node).Node;
              case (ExpressionType)InternalExpressionType.LinkedTable:
                  return this.VisitLinkedTable((LinkedTableExpression)node);
              default:
                  throw Error.UnrecognizedExpressionNode(node.NodeType);
          }
      }
      finally {
          this.dominatingExpression = save;
      }
  }
  /// <summary>
  /// Heuristic that picks which of two candidate expressions to report in user messages
  /// and exception text. Null candidates are ignored. A method call is preferred over any
  /// other node; when both or neither are method calls, the newer one
  /// (<paramref name="next"/>) wins.
  /// </summary>
  private static Expression ChooseBestDominatingExpression(Expression last, Expression next) {
      if (last == null) {
          return next;
      }
      if (next == null) {
          return last;
      }
      // Keep the older expression only when it is a method call and the newer one is not.
      bool keepLast = !(next is MethodCallExpression) && (last is MethodCallExpression);
      return keepLast ? last : next;
  }
  /// <summary>
  /// Returns a select whose projection is a bare alias reference and that has no WHERE,
  /// ORDER BY, GROUP BY, HAVING, TOP, non-default ordering type or DISTINCT. Callers such
  /// as VisitWhere and VisitOfType can then attach their own clauses. If
  /// <paramref name="sel"/> already qualifies, it is returned unchanged. Otherwise it is
  /// wrapped as "SELECT alias FROM (sel) alias".
  /// </summary>
  private SqlSelect LockSelect(SqlSelect sel) {
      bool isPlainAliasProjection =
          sel.Selection.NodeType == SqlNodeType.AliasRef &&
          sel.Where == null &&
          sel.OrderBy.Count == 0 &&
          sel.GroupBy.Count == 0 &&
          sel.Having == null &&
          sel.Top == null &&
          sel.OrderingType == SqlOrderingType.Default &&
          !sel.IsDistinct;
      if (isPlainAliasProjection) {
          return sel;
      }
      SqlAlias wrapper = new SqlAlias(sel);
      return new SqlSelect(new SqlAliasRef(wrapper), wrapper, this.dominatingExpression);
  }
  /// <summary>
  /// Converts <paramref name="exp"/> and coerces the result into a SqlSelect.
  /// </summary>
  private SqlSelect VisitSequence(Expression exp) {
      SqlNode converted = this.Visit(exp);
      return this.CoerceToSequence(converted);
  }
  /// <summary>
  /// Turns a converted node into a SqlSelect.
  /// - A SqlSelect is returned as-is.
  /// - A captured value that is an ITable is translated to its default table query.
  /// - A captured IQueryable is converted from its funcletized expression. A query whose
  ///   expression is a constant of itself is rejected to avoid infinite recursion.
  /// - Other captured values are rejected.
  /// - Multiset/Element sub-selects yield their inner select.
  /// - Client arrays and client parameters are rejected.
  /// - Any other expression is wrapped as "SELECT alias FROM (expr) alias".
  /// </summary>
  private SqlSelect CoerceToSequence(SqlNode node) {
      SqlSelect select = node as SqlSelect;
      if (select != null) {
          return select;
      }
      switch (node.NodeType) {
          case SqlNodeType.Value: {
              object captured = ((SqlValue)node).Value;
              ITable table = captured as ITable;
              if (table != null) {
                  return this.CoerceToSequence(this.TranslateConstantTable(table, null));
              }
              IQueryable queryable = captured as IQueryable;
              if (queryable == null) {
                  throw Error.CapturedValuesCannotBeSequences();
              }
              Expression funcletized = Funcletizer.Funcletize(queryable.Expression);
              // IQueryables that return self-referencing Constant expressions cause infinite recursion.
              bool selfReferencing = funcletized.NodeType == ExpressionType.Constant
                  && ((ConstantExpression)funcletized).Value == queryable;
              if (selfReferencing) {
                  throw Error.IQueryableCannotReturnSelfReferencingConstantExpression();
              }
              return this.VisitSequence(funcletized);
          }
          case SqlNodeType.Multiset:
          case SqlNodeType.Element:
              return ((SqlSubSelect)node).Select;
          case SqlNodeType.ClientArray:
              throw Error.ConstructedArraysNotSupported();
          case SqlNodeType.ClientParameter:
              throw Error.ParametersCannotBeSequences();
          default: {
              // This needs to be a sequence expression.
              SqlAlias alias = new SqlAlias((SqlExpression)node);
              return new SqlSelect(new SqlAliasRef(alias), alias, this.dominatingExpression);
          }
      }
  }
  // NoInlining is kept from the original code. This method calls itself when a captured
  // CompiledQuery delegate is invoked.
  /// <summary>
  /// Converts an InvocationExpression.
  /// - Invoking a (possibly quoted) lambda inlines it. Each argument expression is bound
  ///   to the matching lambda parameter in exprMap, then the lambda body is converted.
  /// - Invoking a captured CompiledQuery delegate converts an invocation of the compiled
  ///   query's expression instead.
  /// - Invoking any other captured delegate with no arguments evaluates it immediately and
  ///   embeds the result as a value. Exceptions thrown by the delegate are rethrown with
  ///   their original stack trace.
  /// - Anything else becomes a client-side call to Delegate.DynamicInvoke with the
  ///   converted arguments packed into an object[].
  /// </summary>
  [System.Runtime.CompilerServices.MethodImpl(System.Runtime.CompilerServices.MethodImplOptions.NoInlining)]
  private SqlNode VisitInvocation(InvocationExpression invoke) {
      LambdaExpression lambda =
          (invoke.Expression.NodeType == ExpressionType.Quote)
          ? (LambdaExpression)((UnaryExpression)invoke.Expression).Operand
          : (invoke.Expression as LambdaExpression);
      if (lambda != null) {
          // Just map arg values into the lambda's parameters and evaluate the lambda's body.
          for (int i = 0, n = invoke.Arguments.Count; i < n; i++) {
              this.exprMap[lambda.Parameters[i]] = invoke.Arguments[i];
          }
          return this.VisitInner(lambda.Body);
      }

      // Check for invocation of a captured delegate (compiled query or plain delegate).
      SqlExpression expr = this.VisitExpression(invoke.Expression);
      if (expr.NodeType == SqlNodeType.Value) {
          Delegate d = ((SqlValue)expr).Value as Delegate;
          if (d != null) {
              CompiledQuery cq = d.Target as CompiledQuery;
              if (cq != null) {
                  return this.VisitInvocation(Expression.Invoke(cq.Expression, invoke.Arguments));
              }
              if (invoke.Arguments.Count == 0) {
                  object invokeResult;
                  try {
                      invokeResult = d.DynamicInvoke(null);
                  }
                  catch (System.Reflection.TargetInvocationException e) {
                      if (e.InnerException == null) {
                          throw;
                      }
                      // Surface the delegate's own exception with its original stack trace.
                      // "throw e.InnerException;" would reset the trace to this frame.
                      System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(e.InnerException).Throw();
                      throw; // Not reached: Throw() always throws. Keeps definite assignment happy.
                  }
                  return this.sql.ValueFromObject(invokeResult, invoke.Type, true, this.dominatingExpression);
              }
          }
      }

      // Fall back to a client-side delegate.DynamicInvoke(new object[] { args... }) call.
      SqlExpression[] args = new SqlExpression[invoke.Arguments.Count];
      for (int i = 0; i < args.Length; ++i) {
          args[i] = (SqlExpression)this.Visit(invoke.Arguments[i]);
      }
      var sca = new SqlClientArray(typeof(object[]), this.typeProvider.From(typeof(object[])), args, this.dominatingExpression);
      return sql.MethodCall(invoke.Type, typeof(Delegate).GetMethod("DynamicInvoke"), expr, new SqlExpression[] { sca }, this.dominatingExpression);
  }
  // inline lambda expressions w/o invocation are parameterized queries
  /// <summary>
  /// Converts a lambda that is not being invoked by treating it as a parameterized query.
  /// Each lambda parameter becomes a SqlClientParameter whose accessor reads
  /// args[i] from an object[] and converts it to the parameter's type. The parameter is
  /// mapped to that client parameter in dupMap. The lambda body is then converted.
  /// Throws Error.BadParameterType if a parameter's type is System.Type.
  /// </summary>
  private SqlNode VisitLambda(LambdaExpression lambda) {
      // turn lambda parameters into client parameters
      for (int i = 0, n = lambda.Parameters.Count; i < n; i++) {
          ParameterExpression p = lambda.Parameters[i];
          if (p.Type == typeof(Type)) {
              throw Error.BadParameterType(p.Type);
          }
          // construct accessor for parameter: (object[] args) => (T)args[i]
          ParameterExpression pa = Expression.Parameter(typeof(object[]), "args");
          LambdaExpression accessor =
              Expression.Lambda(
                  typeof(Func<,>).MakeGenericType(typeof(object[]), p.Type),
                  Expression.Convert(
  #pragma warning disable 618 // Disable the 'obsolete' warning
                      Expression.ArrayIndex(pa, Expression.Constant(i)),
                      p.Type
                  ),
  #pragma warning restore 618
                  pa
              );
          SqlClientParameter cp = new SqlClientParameter(p.Type, this.typeProvider.From(p.Type), accessor, this.dominatingExpression);
          // map references to lambda's parameter to client parameter node
          this.dupMap[p] = cp;
      }
      // call this so we don't erase 'outerNode' setting
      return this.VisitInner(lambda.Body);
  }
  /// <summary>
  /// Converts <paramref name="exp"/> and guarantees a SqlExpression result (or null for a
  /// null conversion).
  /// - Expressions are returned as-is.
  /// - A SqlSelect is wrapped as a Multiset sub-select typed as exp.Type.
  /// - Any other node kind raises Error.UnrecognizedExpressionNode.
  /// </summary>
  private SqlExpression VisitExpression(Expression exp) {
      SqlNode converted = this.Visit(exp);
      if (converted == null) {
          return null;
      }
      SqlExpression asExpression = converted as SqlExpression;
      if (asExpression != null) {
          return asExpression;
      }
      SqlSelect asSelect = converted as SqlSelect;
      if (asSelect == null) {
          throw Error.UnrecognizedExpressionNode(converted);
      }
      return sql.SubSelect(SqlNodeType.Multiset, asSelect, exp.Type);
  }
  /// <summary>
  /// Converts sequence.Select(selector). The selector's parameter is bound (in map) to an
  /// alias reference over the converted source, then the selector body is converted.
  /// - A sequence-valued projection becomes a Multiset sub-select per source row.
  /// - An Element or ScalarSubSelect projection is hoisted into an OUTER APPLY join when
  ///   ConverterStrategy.CanUseOuterApply is enabled. Element results additionally get an
  ///   optional-value wrapper with a "test" column, presumably to tell a missing row apart
  ///   from a null value.
  /// - Any other SqlExpression is projected directly.
  /// Any other node raises Error.BadProjectionInSelect.
  /// </summary>
  private SqlSelect VisitSelect(Expression sequence, LambdaExpression selector) {
      SqlSelect source = this.VisitSequence(sequence);
      SqlAlias alias = new SqlAlias(source);
      SqlAliasRef aref = new SqlAliasRef(alias);
      this.map[selector.Parameters[0]] = aref;
      SqlNode project = this.Visit(selector.Body);
      SqlSelect pselect = project as SqlSelect;
      if (pselect != null) {
          return new SqlSelect(sql.SubSelect(SqlNodeType.Multiset, pselect, selector.Body.Type), alias, this.dominatingExpression);
      }
      // Use the shared flag helper rather than an ad-hoc bit test (same result for a single flag).
      else if ((project.NodeType == SqlNodeType.Element || project.NodeType == SqlNodeType.ScalarSubSelect) &&
               this.UseConverterStrategy(ConverterStrategy.CanUseOuterApply)) {
          SqlSubSelect sub = (SqlSubSelect)project;
          SqlSelect inner = sub.Select;
          SqlAlias innerAlias = new SqlAlias(inner);
          SqlAliasRef innerRef = new SqlAliasRef(innerAlias);
          if (project.NodeType == SqlNodeType.Element) {
              inner.Selection = new SqlOptionalValue(
                  new SqlColumn(
                      "test",
                      sql.Unary(
                          SqlNodeType.OuterJoinedValue,
                          sql.Value(typeof(int?), this.typeProvider.From(typeof(int)), 1, false, this.dominatingExpression)
                      )
                  ),
                  sql.Unary(SqlNodeType.OuterJoinedValue, inner.Selection)
              );
          }
          else {
              inner.Selection = sql.Unary(SqlNodeType.OuterJoinedValue, inner.Selection);
          }
          SqlJoin join = new SqlJoin(SqlJoinType.OuterApply, alias, innerAlias, null, this.dominatingExpression);
          return new SqlSelect(innerRef, join, this.dominatingExpression);
      }
      else {
          SqlExpression expr = project as SqlExpression;
          if (expr != null) {
              return new SqlSelect(expr, alias, this.dominatingExpression);
          }
          else {
              throw Error.BadProjectionInSelect();
          }
      }
  }
  /// <summary>
  /// Converts SelectMany as a CROSS APPLY between the source sequence and the collection
  /// produced by <paramref name="colSelector"/> for each source row. Without a result
  /// selector, the applied collection's rows are projected. Otherwise the result
  /// selector's two parameters are bound to the outer and inner alias references and its
  /// body is the projection.
  /// </summary>
  private SqlSelect VisitSelectMany(Expression sequence, LambdaExpression colSelector, LambdaExpression resultSelector) {
      SqlAlias outerAlias = new SqlAlias(this.VisitSequence(sequence));
      SqlAliasRef outerRef = new SqlAliasRef(outerAlias);
      this.map[colSelector.Parameters[0]] = outerRef;

      SqlNode innerSource = this.VisitSequence(colSelector.Body);
      SqlAlias innerAlias = new SqlAlias(innerSource);
      SqlAliasRef innerRef = new SqlAliasRef(innerAlias);
      SqlJoin crossApply = new SqlJoin(SqlJoinType.CrossApply, outerAlias, innerAlias, null, this.dominatingExpression);

      SqlExpression projection;
      if (resultSelector == null) {
          projection = innerRef;
      }
      else {
          this.map[resultSelector.Parameters[0]] = outerRef;
          this.map[resultSelector.Parameters[1]] = innerRef;
          projection = this.VisitExpression(resultSelector.Body);
      }
      return new SqlSelect(projection, crossApply, this.dominatingExpression);
  }
  /// <summary>
  /// Converts an equi-join of two sequences. The key selectors are converted with their
  /// parameter bound to the matching side's alias reference, and the result selector with
  /// both. When ConverterStrategy.CanUseJoinOn is enabled, an INNER JOIN ... ON key = key
  /// is emitted. Otherwise a CROSS JOIN is emitted with the key equality in the WHERE clause.
  /// </summary>
  private SqlSelect VisitJoin(Expression outerSequence, Expression innerSequence, LambdaExpression outerKeySelector, LambdaExpression innerKeySelector, LambdaExpression resultSelector) {
      SqlSelect outerSelect = this.VisitSequence(outerSequence);
      SqlSelect innerSelect = this.VisitSequence(innerSequence);
      SqlAlias outerAlias = new SqlAlias(outerSelect);
      SqlAliasRef outerRef = new SqlAliasRef(outerAlias);
      SqlAlias innerAlias = new SqlAlias(innerSelect);
      SqlAliasRef innerRef = new SqlAliasRef(innerAlias);
      this.map[outerKeySelector.Parameters[0]] = outerRef;
      SqlExpression outerKey = this.VisitExpression(outerKeySelector.Body);
      this.map[innerKeySelector.Parameters[0]] = innerRef;
      SqlExpression innerKey = this.VisitExpression(innerKeySelector.Body);
      this.map[resultSelector.Parameters[0]] = outerRef;
      this.map[resultSelector.Parameters[1]] = innerRef;
      SqlExpression result = this.VisitExpression(resultSelector.Body);
      SqlExpression condition = sql.Binary(SqlNodeType.EQ, outerKey, innerKey);

      // Use the shared flag helper, consistent with UseConverterStrategy's contract.
      if (this.UseConverterStrategy(ConverterStrategy.CanUseJoinOn)) {
          SqlJoin innerJoin = new SqlJoin(SqlJoinType.Inner, outerAlias, innerAlias, condition, this.dominatingExpression);
          return new SqlSelect(result, innerJoin, this.dominatingExpression);
      }

      // No ON clause available: cross join and filter in WHERE instead.
      SqlJoin crossJoin = new SqlJoin(SqlJoinType.Cross, outerAlias, innerAlias, null, this.dominatingExpression);
      SqlSelect select = new SqlSelect(result, crossJoin, this.dominatingExpression);
      select.Where = condition;
      return select;
  }
  /// <summary>
  /// Converts GroupJoin. Each outer row is paired with a Multiset sub-select of the inner
  /// rows whose key equals the outer row's key. The result selector's first parameter is
  /// bound (in map) to the outer alias reference. Its second parameter is bound (in dupMap)
  /// to that multiset. The result is projected over the outer source.
  /// </summary>
  private SqlSelect VisitGroupJoin(Expression outerSequence, Expression innerSequence, LambdaExpression outerKeySelector, LambdaExpression innerKeySelector, LambdaExpression resultSelector) {
      SqlSelect outerSource = this.VisitSequence(outerSequence);
      SqlSelect innerSource = this.VisitSequence(innerSequence);
      SqlAlias outerAlias = new SqlAlias(outerSource);
      SqlAliasRef outerRef = new SqlAliasRef(outerAlias);
      SqlAlias innerAlias = new SqlAlias(innerSource);
      SqlAliasRef innerRef = new SqlAliasRef(innerAlias);

      // Each key selector sees its own side of the join.
      this.map[outerKeySelector.Parameters[0]] = outerRef;
      SqlExpression outerKeyExpr = this.VisitExpression(outerKeySelector.Body);
      this.map[innerKeySelector.Parameters[0]] = innerRef;
      SqlExpression innerKeyExpr = this.VisitExpression(innerKeySelector.Body);

      // SELECT inner FROM inner WHERE outerKey = innerKey, used as a multiset per outer row.
      SqlExpression keysMatch = sql.Binary(SqlNodeType.EQ, outerKeyExpr, innerKeyExpr);
      SqlSelect matchingRows = new SqlSelect(innerRef, innerAlias, this.dominatingExpression);
      matchingRows.Where = keysMatch;
      SqlSubSelect group = sql.SubSelect(SqlNodeType.Multiset, matchingRows);

      this.map[resultSelector.Parameters[0]] = outerRef;
      this.dupMap[resultSelector.Parameters[1]] = group;
      SqlExpression projection = this.VisitExpression(resultSelector.Body);
      return new SqlSelect(projection, outerAlias, this.dominatingExpression);
  }
  /// <summary>
  /// Converts DefaultIfEmpty. The source rows are wrapped in an optional value that carries
  /// a "test" marker column. That selection is OUTER APPLYed to a one-row SELECT of a typed
  /// NULL literal, so the outer query always has a row to apply against.
  /// </summary>
  private SqlSelect VisitDefaultIfEmpty(Expression sequence) {
      SqlAlias sourceAlias = new SqlAlias(this.VisitSequence(sequence));
      SqlAliasRef sourceRef = new SqlAliasRef(sourceAlias);

      SqlExpression marker = sql.Unary(SqlNodeType.OuterJoinedValue,
          sql.Value(typeof(int?), this.typeProvider.From(typeof(int)), 1, false, this.dominatingExpression));
      SqlExpression optional = new SqlOptionalValue(
          new SqlColumn("test", marker),
          sql.Unary(SqlNodeType.OuterJoinedValue, sourceRef));

      SqlAlias optionalAlias = new SqlAlias(new SqlSelect(optional, sourceAlias, this.dominatingExpression));
      SqlAliasRef optionalRef = new SqlAliasRef(optionalAlias);

      SqlExpression nullLiteral = sql.TypedLiteralNull(typeof(string), this.dominatingExpression);
      SqlAlias nullRowAlias = new SqlAlias(new SqlSelect(nullLiteral, null, this.dominatingExpression));

      SqlJoin outerApply = new SqlJoin(SqlJoinType.OuterApply, nullRowAlias, optionalAlias, null, this.dominatingExpression);
      return new SqlSelect(optionalRef, outerApply, this.dominatingExpression);
  }
  /// <summary>
  /// Rewrites seq.OfType&lt;T&gt;() as seq.Select(s =&gt; s as T).Where(p =&gt; p != null).
  /// The projection is replaced with a Treat node, and an IS NOT NULL test is and-ed into
  /// the WHERE clause of a locked select over it.
  /// </summary>
  private SqlSelect VisitOfType(Expression sequence, Type ofType) {
      SqlSelect treated = this.LockSelect(this.VisitSequence(sequence));
      SqlAliasRef sourceRef = (SqlAliasRef)treated.Selection;
      treated.Selection = new SqlUnary(SqlNodeType.Treat, ofType, typeProvider.From(ofType), sourceRef, this.dominatingExpression);

      SqlSelect filtered = this.LockSelect(treated);
      SqlAliasRef treatedRef = (SqlAliasRef)filtered.Selection;
      SqlExpression isNotNull = sql.Unary(SqlNodeType.IsNotNull, treatedRef, this.dominatingExpression);
      filtered.Where = sql.AndAccumulate(filtered.Where, isNotNull);
      return filtered;
  }
  /// <summary>
  /// Rewrites seq.Cast&lt;T&gt;() as Enumerable.Select(seq, pc =&gt; (T)pc) and converts that call.
  /// </summary>
  private SqlNode VisitSequenceCast(Expression sequence, Type type) {
      Type elementType = TypeSystem.GetElementType(sequence.Type);
      ParameterExpression item = Expression.Parameter(elementType, "pc");
      LambdaExpression castItem = Expression.Lambda(
          Expression.Convert(item, type),
          new ParameterExpression[] { item });
      // Type arguments: TSource = element type, TResult = target type.
      MethodCallExpression selectCall = Expression.Call(
          typeof(Enumerable), "Select",
          new Type[] { elementType, type },
          sequence,
          castItem);
      return this.Visit(selectCall);
  }
  /// <summary>
  /// Translates the 'is' operator: "x is T" becomes IS NOT NULL applied to a Treat of x as
  /// T. Any TypeBinary node kind other than TypeIs raises
  /// Error.TypeBinaryOperatorNotRecognized, after the operand has been converted.
  /// </summary>
  private SqlNode VisitTypeBinary(TypeBinaryExpression b) {
      SqlExpression operand = this.VisitExpression(b.Expression);
      if (b.NodeType != ExpressionType.TypeIs) {
          throw Error.TypeBinaryOperatorNotRecognized();
      }
      Type ofType = b.TypeOperand;
      SqlUnary treated = new SqlUnary(SqlNodeType.Treat, ofType, typeProvider.From(ofType), operand, this.dominatingExpression);
      return sql.Unary(SqlNodeType.IsNotNull, treated, this.dominatingExpression);
  }
  /// <summary>
  /// Converts sequence.Where(predicate). The predicate's parameter is bound to the locked
  /// select's alias reference, and the predicate body becomes that select's WHERE clause.
  /// </summary>
  private SqlSelect VisitWhere(Expression sequence, LambdaExpression predicate) {
      SqlSelect filtered = this.LockSelect(this.VisitSequence(sequence));
      ParameterExpression row = predicate.Parameters[0];
      this.map[row] = (SqlAliasRef)filtered.Selection;
      filtered.Where = this.VisitExpression(predicate.Body);
      return filtered;
  }
  /// <summary>
  /// Translates "operand as T" into a Treat node of type T. A sequence operand (SqlSelect)
  /// is first wrapped as a Multiset sub-select. Any other result raises Error.DidNotExpectAs.
  /// </summary>
  private SqlNode VisitAs(UnaryExpression a) {
      SqlNode operand = this.Visit(a.Operand);
      SqlExpression target = operand as SqlExpression;
      if (target == null) {
          SqlSelect sequence = operand as SqlSelect;
          if (sequence == null) {
              throw Error.DidNotExpectAs(a);
          }
          target = sql.SubSelect(SqlNodeType.Multiset, sequence);
      }
      return new SqlUnary(SqlNodeType.Treat, a.Type, typeProvider.From(a.Type), target, a);
  }
  /// <summary>
  /// Translates array/string Length. String and char SQL types use CLRLENGTH; any other
  /// type uses DATALENGTH.
  /// </summary>
  private SqlNode VisitArrayLength(UnaryExpression c) {
      SqlExpression operand = this.VisitExpression(c.Operand);
      bool isText = operand.SqlType.IsString || operand.SqlType.IsChar;
      if (!isText) {
          return sql.DATALENGTH(operand);
      }
      return sql.CLRLENGTH(operand);
  }
  /// <summary>
  /// Translates array indexing. Only one form is supported: a client parameter indexed by
  /// a constant value. It becomes a new client parameter whose accessor indexes the
  /// original accessor's result. Any other form raises Error.UnrecognizedExpressionNode.
  /// </summary>
  private SqlNode VisitArrayIndex(BinaryExpression b) {
      SqlExpression array = this.VisitExpression(b.Left);
      SqlExpression index = this.VisitExpression(b.Right);
      bool constantIndexIntoParameter = array.NodeType == SqlNodeType.ClientParameter
          && index.NodeType == SqlNodeType.Value;
      if (!constantIndexIntoParameter) {
          throw Error.UnrecognizedExpressionNode(b.NodeType);
      }
      SqlClientParameter arrayParam = (SqlClientParameter)array;
      SqlValue indexValue = (SqlValue)index;
      Expression indexConstant = Expression.Constant(indexValue.Value, indexValue.ClrType);
  #pragma warning disable 618 // Disable the 'obsolete' warning
      Expression element = Expression.ArrayIndex(arrayParam.Accessor.Body, indexConstant);
  #pragma warning restore 618
      LambdaExpression elementAccessor = Expression.Lambda(element, arrayParam.Accessor.Parameters.ToArray());
      return new SqlClientParameter(b.Type, sql.TypeProvider.From(b.Type), elementAccessor, this.dominatingExpression);
  }
  /// <summary>
  /// Translates Convert/ConvertChecked. With no implementing method, this is a type change
  /// of the operand. When the node has an explicit implementing method (UnaryExpression.Method,
  /// e.g. a user-defined conversion operator), it becomes a static call to that method
  /// with the converted operand.
  /// </summary>
  private SqlNode VisitCast(UnaryExpression c) {
      if (c.Method == null) {
          return this.VisitChangeType(c.Operand, c.Type);
      }
      SqlExpression operand = this.VisitExpression(c.Operand);
      SqlExpression[] callArgs = new SqlExpression[] { operand };
      return sql.MethodCall(c.Type, c.Method, null, callArgs, dominatingExpression);
  }
  /// <summary>
  /// Converts <paramref name="expression"/> and changes the result to <paramref name="type"/>.
  /// </summary>
  private SqlNode VisitChangeType(Expression expression, Type type) {
      return this.ChangeType(this.VisitExpression(expression), type);
  }
  /// <summary>
  /// Wraps <paramref name="expr"/> as CONVERT(DATETIME2, expr), typed as DateTime. The
  /// DATETIME2 target is emitted as a SqlVariable that shares expr's CLR/SQL types and
  /// source expression.
  /// </summary>
  private SqlNode ConvertDateToDateTime2(SqlExpression expr) {
      SqlExpression[] convertArgs = new SqlExpression[] {
          new SqlVariable(expr.ClrType, expr.SqlType, "DATETIME2", expr.SourceExpression),
          expr
      };
      return sql.FunctionCall(typeof(DateTime), "CONVERT", convertArgs, expr.SourceExpression);
  }
  /// <summary>
  /// Changes an already-converted SQL expression to the CLR type <paramref name="type"/>.
  /// - Conversion to object returns the expression unchanged.
  /// - A null value becomes a typed NULL literal.
  /// - A client parameter gets a new accessor that converts its value client-side.
  /// - Otherwise ChooseConversionMethod selects a SQL conversion, a nullable lift, a no-op
  ///   or a Treat node.
  /// Date-typed operands (per SqlFactory.IsSqlDateType) that are lifted or ignored are
  /// first wrapped in CONVERT(DATETIME2, ...).
  /// Throws Error.UnhandledExpressionType for an unknown ConversionMethod.
  /// </summary>
  private SqlNode ChangeType(SqlExpression expr, Type type) {
      if (type == typeof(object)) {
          return expr; // Boxing conversion?
      }
      else if (expr.NodeType == SqlNodeType.Value && ((SqlValue)expr).Value == null) {
          // Null constant: emit a NULL literal typed as the target type.
          return sql.TypedLiteralNull(type, expr.SourceExpression);
      }
      else if (expr.NodeType == SqlNodeType.ClientParameter) {
          // Client parameter: perform the conversion inside its accessor lambda.
          SqlClientParameter cp = (SqlClientParameter)expr;
          return new SqlClientParameter(
              type, sql.TypeProvider.From(type),
              Expression.Lambda(Expression.Convert(cp.Accessor.Body, type), cp.Accessor.Parameters.ToArray()),
              cp.SourceExpression
          );
      }
      ConversionMethod cm = ChooseConversionMethod(expr.ClrType, type);
      switch (cm) {
          case ConversionMethod.Convert:
              return sql.UnaryConvert(type, typeProvider.From(type), expr, expr.SourceExpression);
          case ConversionMethod.Lift:
              if (SqlFactory.IsSqlDateType(expr)) {
                  expr = (SqlExpression) ConvertDateToDateTime2(expr);
              }
              // Note: this branch uses dominatingExpression, not expr.SourceExpression.
              return new SqlLift(type, expr, this.dominatingExpression);
          case ConversionMethod.Ignore:
              if (SqlFactory.IsSqlDateType(expr)) {
                  return ConvertDateToDateTime2(expr);
              }
              return expr;
          case ConversionMethod.Treat:
              return new SqlUnary(SqlNodeType.Treat, type, typeProvider.From(type), expr, expr.SourceExpression);
          default:
              throw Error.UnhandledExpressionType(cm);
      }
  }
  // How ChangeType realizes a type change (see ChooseConversionMethod).
  enum ConversionMethod {
      Treat = 0,   // re-type via a Treat node
      Ignore = 1,  // no conversion emitted
      Convert = 2, // SQL conversion
      Lift = 3     // nullable lift/unlift
  }
  /// <summary>
  /// Decides how to convert from <paramref name="fromType"/> to <paramref name="toType"/>:
  /// - Lift when the types differ only in nullability.
  /// - Ignore when either non-nullable type is a sequence type.
  /// - Treat when either provider type is runtime-only.
  /// - Ignore when the non-nullable CLR types match, both map to the same SQL string type,
  ///   or either side is an enum.
  /// - Convert otherwise.
  /// </summary>
  private ConversionMethod ChooseConversionMethod(Type fromType, Type toType) {
      Type sourceCore = TypeSystem.GetNonNullableType(fromType);
      Type targetCore = TypeSystem.GetNonNullableType(toType);

      bool differsOnlyByNullability = fromType != toType && sourceCore == targetCore;
      if (differsOnlyByNullability) {
          return ConversionMethod.Lift;
      }
      if (TypeSystem.IsSequenceType(sourceCore) || TypeSystem.IsSequenceType(targetCore)) {
          return ConversionMethod.Ignore;
      }

      ProviderType sourceSql = typeProvider.From(sourceCore);
      ProviderType targetSql = typeProvider.From(targetCore);
      bool sourceRuntimeOnly = sourceSql.IsRuntimeOnlyType;
      bool targetRuntimeOnly = targetSql.IsRuntimeOnlyType;
      if (sourceRuntimeOnly || targetRuntimeOnly) {
          return ConversionMethod.Treat;
      }

      bool noConversionNeeded =
          sourceCore == targetCore                                   // same non-nullable .NET types
          || (sourceSql.IsString && sourceSql.Equals(targetSql))     // same SQL string types
          || sourceCore.IsEnum || targetCore.IsEnum;                 // any .NET enum type
      return noConversionNeeded ? ConversionMethod.Ignore : ConversionMethod.Convert;
  }
  /// <summary>
  /// Translates an ITable into the translator's default query for the mapped row type of
  /// its element type, passing along allowDeferred and the optional link. According to the
  /// original notes, an inheritance hierarchy yields a type case: a CASE with one WHEN per
  /// type binding that may be instantiated.
  /// Throws Error.WrongDataContext if the table belongs to a different DataContext.
  /// </summary>
  private SqlNode TranslateConstantTable(ITable table, SqlLink link) {
      bool sameContext = table.Context == this.services.Context;
      if (!sameContext) {
          throw Error.WrongDataContext();
      }
      return this.translator.BuildDefaultQuery(
          this.services.Model.GetTable(table.ElementType).RowType,
          this.allowDeferred,
          link,
          this.dominatingExpression);
  }
  /// <summary>
  /// Translates a LinkedTableExpression by building the default query for its table and
  /// link.
  /// </summary>
  private SqlNode VisitLinkedTable(LinkedTableExpression linkedTable) {
      return this.TranslateConstantTable(linkedTable.Table, linkedTable.Link);
  }
  /// <summary>
  /// Translates a constant. A null constant becomes a typed NULL literal of the
  /// expression's static type; any other value becomes a SQL value.
  /// </summary>
  private SqlNode VisitConstant(ConstantExpression cons) {
      object value = cons.Value;
      if (value == null) {
          return sql.TypedLiteralNull(cons.Type, this.dominatingExpression);
      }
      // A constant statically typed as object carries no useful type information,
      // so the value's runtime type is used instead.
      Type valueType = cons.Type == typeof(object) ? value.GetType() : cons.Type;
      return sql.ValueFromObject(value, valueType, true, this.dominatingExpression);
  }
  /// <summary>
  /// Translates a conditional (test ? a : b) into a searched CASE. A searched CASE
  /// produced for the false branch is flattened into the same CASE, so chained
  /// conditionals become a single CASE with several WHEN arms.
  /// </summary>
  private SqlExpression VisitConditional(ConditionalExpression cond) {
      SqlExpression test = this.VisitExpression(cond.Test);
      SqlExpression whenTrue = this.VisitExpression(cond.IfTrue);
      List<SqlWhen> arms = new List<SqlWhen>(1) { new SqlWhen(test, whenTrue) };

      SqlExpression fallback = this.VisitExpression(cond.IfFalse);
      // Merge nested searched cases from the else branch into a single search case.
      for (; fallback.NodeType == SqlNodeType.SearchedCase; ) {
          SqlSearchedCase nested = (SqlSearchedCase)fallback;
          arms.AddRange(nested.Whens);
          fallback = nested.Else;
      }
      return sql.SearchedCase(arms.ToArray(), fallback, this.dominatingExpression);
  }
  /// <summary>
  /// Translates a NewExpression. Wrapping a value in its Nullable type, and building a
  /// decimal from a single argument, are both translated as casts. Entity types cannot
  /// be constructed inside a query. Any other construction becomes a SqlNew.
  /// </summary>
  private SqlExpression VisitNew(NewExpression qn) {
      int argCount = qn.Arguments.Count;
      bool wrapsInNullable = TypeSystem.IsNullableType(qn.Type) && argCount == 1 &&
          TypeSystem.GetNonNullableType(qn.Type) == qn.Arguments[0].Type;
      if (wrapsInNullable) {
          return this.VisitCast(Expression.Convert(qn.Arguments[0], qn.Type)) as SqlExpression;
      }
      if (qn.Type == typeof(decimal) && argCount == 1) {
          return this.VisitCast(Expression.Convert(qn.Arguments[0], typeof(decimal))) as SqlExpression;
      }

      MetaType mt = this.services.Model.GetMetaType(qn.Type);
      if (mt.IsEntity) {
          throw Error.CannotMaterializeEntityType(qn.Type);
      }

      // Leave args null (rather than empty) when the constructor takes no arguments.
      SqlExpression[] args = null;
      if (argCount > 0) {
          args = new SqlExpression[argCount];
          int position = 0;
          foreach (Expression argument in qn.Arguments) {
              args[position++] = this.VisitExpression(argument);
          }
      }
      return sql.New(mt, qn.Constructor, args, PropertyOrFieldOf(qn.Members), null, this.dominatingExpression);
  }
  /// <summary>
  /// Translates an object initializer (new T(args) { M = expr, ... }). Entity types
  /// cannot be constructed inside a query. A decimal built from a single constructor
  /// argument is translated as a cast. Member assignments are emitted sorted by each
  /// member's ordinal from the type's metadata (the type's declaration order).
  /// </summary>
  /// <remarks>
  /// Only MemberAssignment bindings are supported. Any other binding kind throws the
  /// exception from Error.UnhandledBindingType.
  /// </remarks>
  private SqlExpression VisitMemberInit(MemberInitExpression init) {
      MetaType mt = this.services.Model.GetMetaType(init.Type);
      if (mt.IsEntity) {
          throw Error.CannotMaterializeEntityType(init.Type);
      }
      NewExpression ctorCall = init.NewExpression;
      if (ctorCall.Type == typeof(decimal) && ctorCall.Arguments.Count == 1) {
          return this.VisitCast(Expression.Convert(ctorCall.Arguments[0], typeof(decimal))) as SqlExpression;
      }

      SqlExpression[] ctorArgs = null;
      if (ctorCall.Arguments.Count > 0) {
          ctorArgs = new SqlExpression[ctorCall.Arguments.Count];
          for (int a = 0; a < ctorArgs.Length; a++) {
              ctorArgs[a] = this.VisitExpression(ctorCall.Arguments[a]);
          }
      }

      int bindingCount = init.Bindings.Count;
      SqlMemberAssign[] assignments = new SqlMemberAssign[bindingCount];
      int[] ordinals = new int[bindingCount];
      for (int b = 0; b < bindingCount; b++) {
          MemberAssignment assignment = init.Bindings[b] as MemberAssignment;
          if (assignment == null) {
              throw Error.UnhandledBindingType(init.Bindings[b].BindingType);
          }
          assignments[b] = new SqlMemberAssign(assignment.Member, this.VisitExpression(assignment.Expression));
          ordinals[b] = mt.GetDataMember(assignment.Member).Ordinal;
      }

      // Order the assignments by member ordinal, i.e. the type's declaration order.
      Array.Sort(ordinals, assignments, 0, assignments.Length);
      return sql.New(mt, ctorCall.Constructor, ctorArgs, PropertyOrFieldOf(ctorCall.Members), assignments, this.dominatingExpression);
  }
  /// <summary>
  /// Normalizes the members of a NewExpression to properties and fields.
  /// NewExpression.Members may list a property's getter method instead of the property
  /// itself. Each such getter is mapped back to the property that owns it.
  /// </summary>
  /// <param name="members">The members to normalize. May be null.</param>
  /// <returns>
  /// null when <paramref name="members"/> is null. Otherwise, exactly one property or
  /// field per input member, in input order.
  /// </returns>
  /// <remarks>
  /// Throws the exception from Error.CouldNotConvertToPropertyOrField for a member that
  /// is neither a field, a property, nor the getter of a property of its declaring type.
  /// </remarks>
  private static IEnumerable<MemberInfo> PropertyOrFieldOf(IEnumerable<MemberInfo> members) {
      if (members == null) {
          return null;
      }
      List<MemberInfo> result = new List<MemberInfo>();
      foreach (MemberInfo mi in members) {
          switch (mi.MemberType) {
              case MemberTypes.Method: {
                  MethodInfo getter = (MethodInfo)mi;
                  PropertyInfo owner = null;
                  foreach (PropertyInfo pi in mi.DeclaringType.GetProperties(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic)) {
                      // GetGetMethod(true) is required here. The parameterless overload returns
                      // only public accessors, so non-public properties could never match.
                      if (pi.CanRead && pi.GetGetMethod(true) == getter) {
                          owner = pi;
                          break;
                      }
                  }
                  if (owner == null) {
                      // Callers pair members with constructor arguments by index (see
                      // CanSkipOnSelection). Silently dropping a member would shift that
                      // pairing, so reject it the same way as other unsupported members.
                      throw Error.CouldNotConvertToPropertyOrField(mi);
                  }
                  result.Add(owner);
                  break;
              }
              case MemberTypes.Field:
              case MemberTypes.Property: {
                  result.Add(mi);
                  break;
              }
              default: {
                  throw Error.CouldNotConvertToPropertyOrField(mi);
              }
          }
      }
      return result;
  }
  /// <summary>
  /// Translates Distinct(): the source select is locked, marked DISTINCT, and its
  /// OrderingType is set to Blocked.
  /// </summary>
  private SqlSelect VisitDistinct(Expression sequence) {
      SqlSelect source = this.VisitSequence(sequence);
      SqlSelect locked = this.LockSelect(source);
      locked.IsDistinct = true;
      locked.OrderingType = SqlOrderingType.Blocked;
      return locked;
  }
  /// <summary>
  /// Translates Take(count). When the source is itself a two-argument Skip call, the
  /// skip and take are combined into one GenerateSkipTake call.
  /// </summary>
  /// <param name="sequence">The source sequence, possibly a Skip call.</param>
  /// <param name="count">Number of elements to take.</param>
  /// <remarks>
  /// A constant negative take or skip count throws the exception from
  /// Error.ArgumentOutOfRange ("takeCount" / "skipCount").
  /// </remarks>
  private SqlSelect VisitTake(Expression sequence, Expression count) {
      SqlExpression takeExp = this.VisitExpression(count);
      // Reject a negative constant count. Unlike Value.GetType(), "is int" does not
      // throw NullReferenceException when the constant's value is null.
      if (takeExp.NodeType == SqlNodeType.Value) {
          object takeValue = ((SqlValue)takeExp).Value;
          if (takeValue is int && (int)takeValue < 0) {
              throw Error.ArgumentOutOfRange("takeCount");
          }
      }

      MethodCallExpression mce = sequence as MethodCallExpression;
      bool sourceIsSkip = mce != null && IsSequenceOperatorCall(mce) && mce.Method.Name == "Skip" && mce.Arguments.Count == 2;
      if (!sourceIsSkip) {
          SqlSelect plain = this.VisitSequence(sequence);
          return this.GenerateSkipTake(plain, null, takeExp);
      }

      // Skip(n).Take(m): translate both in a single step.
      SqlExpression skipExp = this.VisitExpression(mce.Arguments[1]);
      if (skipExp.NodeType == SqlNodeType.Value) {
          object skipValue = ((SqlValue)skipExp).Value;
          if (skipValue is int && (int)skipValue < 0) {
              throw Error.ArgumentOutOfRange("skipCount");
          }
      }
      SqlSelect skipSource = this.VisitSequence(mce.Arguments[0]);
      return this.GenerateSkipTake(skipSource, skipExp, takeExp);
  }
  /// <summary>
  /// In order for elements of a sequence to be skipped, they must have identity
  /// that can be compared. This excludes elements that are sequences and elements
  /// that contain sequences.
  /// </summary>
  /// <returns>
  /// true for groupings (comparable by key) and for types mapped to a table
  /// (comparable by primary key). false for sequences that cannot be a column.
  /// For alias references and constructions (SqlNew), the result is decided
  /// recursively; anything else is considered skippable.
  /// </returns>
  private bool CanSkipOnSelection(SqlExpression selection) {
      // Groupings can be compared by their key.
      if (IsGrouping(selection.ClrType)) {
          return true;
      }
      // Entities can be compared by their primary key.
      if (this.services.Model.GetTable(selection.ClrType) != null) {
          return true;
      }
      // Sequences that cannot be stored in a column are not skippable.
      if (TypeSystem.IsSequenceType(selection.ClrType) && !selection.SqlType.CanBeColumn) {
          return false;
      }

      switch (selection.NodeType) {
          case SqlNodeType.AliasRef: {
              SqlNode aliased = ((SqlAliasRef)selection).Alias.Node;
              SqlSelect innerSelect = aliased as SqlSelect;
              if (innerSelect != null) {
                  return CanSkipOnSelection(innerSelect.Selection);
              }
              SqlUnion union = aliased as SqlUnion;
              if (union != null) {
                  // Both arms must be selects whose selections are skippable.
                  SqlSelect leftArm = union.Left as SqlSelect;
                  bool leftOk = leftArm != null && CanSkipOnSelection(leftArm.Selection);
                  SqlSelect rightArm = union.Right as SqlSelect;
                  bool rightOk = rightArm != null && CanSkipOnSelection(rightArm.Selection);
                  return leftOk && rightOk;
              }
              return CanSkipOnSelection((SqlExpression)aliased);
          }
          case SqlNodeType.New: {
              SqlNew construction = (SqlNew)selection;
              // Check every assigned member of the projection for sequences...
              foreach (SqlMemberAssign assign in construction.Members) {
                  if (!CanSkipOnSelection(assign.Expression)) {
                      return false;
                  }
              }
              // ...and every constructor argument that is bound to a member.
              if (construction.ArgMembers != null) {
                  int boundArgCount = construction.ArgMembers.Count;
                  for (int i = 0; i < boundArgCount; i++) {
                      if (!CanSkipOnSelection(construction.Args[i])) {
                          return false;
                      }
                  }
              }
              return true;
          }
          default:
              return true;
      }
  }
  /// <summary>
  /// Translates Skip(count). The shape of the generated SQL depends on whether the
  /// SkipWithRowNumber converter strategy is enabled (see GenerateSkipTake):
  ///
  /// SQL2000 (NOT EXISTS):
  /// SELECT *
  /// FROM sequence AS s
  /// WHERE NOT EXISTS (
  ///     SELECT *
  ///     FROM (SELECT TOP count * FROM sequence) AS s2
  ///     WHERE s2 = s)
  ///
  /// SQL2005 (ROW_NUMBER):
  /// SELECT *
  /// FROM (SELECT sequence.*,
  ///         ROW_NUMBER() OVER (ORDER BY order) AS ROW_NUMBER
  ///       FROM sequence)
  /// WHERE ROW_NUMBER > count
  /// </summary>
  /// <param name="sequence">Sequence containing elements to skip</param>
  /// <param name="skipCount">Number of elements to skip</param>
  /// <returns>SELECT node</returns>
  /// <remarks>
  /// A constant negative count throws the exception from Error.ArgumentOutOfRange("skipCount").
  /// </remarks>
  private SqlSelect VisitSkip(Expression sequence, Expression skipCount) {
      SqlExpression skipExp = this.VisitExpression(skipCount);
      // Reject a negative constant count. Unlike Value.GetType(), "is int" does not
      // throw NullReferenceException when the constant's value is null.
      if (skipExp.NodeType == SqlNodeType.Value) {
          object skipValue = ((SqlValue)skipExp).Value;
          if (skipValue is int && (int)skipValue < 0) {
              throw Error.ArgumentOutOfRange("skipCount");
          }
      }
      SqlSelect select = this.VisitSequence(sequence);
      return this.GenerateSkipTake(select, skipExp, null);
  }
  /// <summary>
  /// Applies an optional skip and an optional take to a sequence.
  /// With no skip, a take simply becomes TOP on the locked select. With a skip, the
  /// rows are filtered either by ROW_NUMBER() (when the SkipWithRowNumber strategy is
  /// enabled) or by a correlated NOT EXISTS over the first skip rows.
  /// </summary>
  /// <param name="sequence">The select to page.</param>
  /// <param name="skipExp">Number of rows to skip, or null for none.</param>
  /// <param name="takeExp">Number of rows to take, or null for no limit.</param>
  private SqlSelect GenerateSkipTake(SqlSelect sequence, SqlExpression skipExp, SqlExpression takeExp) {
      SqlSelect locked = this.LockSelect(sequence);
      if (skipExp == null) {
          // Take only: TOP on the locked select is sufficient.
          if (takeExp != null) {
              locked.Top = takeExp;
          }
          return locked;
      }

      SqlAlias pagedAlias = new SqlAlias(locked);
      SqlAliasRef pagedRef = new SqlAliasRef(pagedAlias);

      if (!this.UseConverterStrategy(ConverterStrategy.SkipWithRowNumber)) {
          // NOT EXISTS strategy. The elements must be skippable.
          if (!CanSkipOnSelection(locked.Selection)) {
              throw Error.SkipNotSupportedForSequenceTypes();
          }
          // Supported: entities, or projections containing all PK columns, where the
          // sequence can be traced back to a single-table query, Distinct, Except,
          // Intersect, or a Union with union.All == false. Joins are not supported.
          // The sequence should also be ordered, but that cannot be tested at this
          // point in processing, and it is not known later that it needs testing.
          SingleTableQueryVisitor singleTableCheck = new SingleTableQueryVisitor();
          singleTableCheck.Visit(locked);
          if (!singleTableCheck.IsValid) {
              throw Error.SkipRequiresSingleTableQueryWithPKs();
          }
          // Inner query: TOP skip rows of a copy of the sequence, correlated with the
          // outer row by an EQ2V comparison.
          SqlSelect skippedRows = (SqlSelect)SqlDuplicator.Copy(locked);
          skippedRows.Top = skipExp;
          SqlAlias skippedAlias = new SqlAlias(skippedRows);
          SqlAliasRef skippedRef = new SqlAliasRef(skippedAlias);
          SqlSelect match = new SqlSelect(skippedRef, skippedAlias, this.dominatingExpression);
          match.Where = sql.Binary(SqlNodeType.EQ2V, pagedRef, skippedRef);
          SqlSubSelect wasSkipped = sql.SubSelect(SqlNodeType.Exists, match);
          SqlSelect remaining = new SqlSelect(pagedRef, pagedAlias, this.dominatingExpression);
          remaining.Where = sql.Unary(SqlNodeType.Not, wasSkipped, this.dominatingExpression);
          remaining.Top = takeExp;
          return remaining;
      }

      // ROW_NUMBER() strategy (preferred): number the rows of the locked select (with an
      // empty ordering list at this point) and filter on that number.
      SqlColumn rowNumber = new SqlColumn("ROW_NUMBER", sql.RowNumber(new List<SqlOrderExpression>(), this.dominatingExpression));
      SqlColumnRef rowNumberRef = new SqlColumnRef(rowNumber);
      locked.Row.Columns.Add(rowNumber);
      SqlSelect paged = new SqlSelect(pagedRef, pagedAlias, this.dominatingExpression);
      if (takeExp == null) {
          paged.Where = sql.Binary(SqlNodeType.GT, rowNumberRef, skipExp);
      }
      else {
          // Skip + take: ROW_NUMBER BETWEEN skip + 1 AND skip + take (much faster).
          SqlExpression firstRow = sql.Add(skipExp, 1);
          SqlExpression lastRow = sql.Binary(SqlNodeType.Add, (SqlExpression)SqlDuplicator.Copy(skipExp), takeExp);
          paged.Where = sql.Between(rowNumberRef, firstRow, lastRow, this.dominatingExpression);
      }
      return paged;
  }
  /// <summary>
  /// Resolves a lambda parameter. The lookup order is: a direct SQL mapping, then an
  /// expression substitution (which is visited), then a node to duplicate.
  /// </summary>
  /// <remarks>Throws the exception from Error.ParameterNotInScope when none of these apply.</remarks>
  private SqlNode VisitParameter(ParameterExpression p) {
      SqlExpression mapped;
      if (this.map.TryGetValue(p, out mapped)) {
          return mapped;
      }
      Expression substitute;
      if (this.exprMap.TryGetValue(p, out substitute)) {
          return this.Visit(substitute);
      }
      SqlNode template;
      if (!this.dupMap.TryGetValue(p, out template)) {
          throw Error.ParameterNotInScope(p.Name);
      }
      return new SqlDuplicator(true).Duplicate(template);
  }
  /// <summary>
  /// Translates a call to a table-valued function into a SQL select over the function
  /// call. The default projection uses the inheritance root of the function's first
  /// result row type.
  /// </summary>
  private SqlNode TranslateTableValuedFunction(MethodCallExpression mce, MetaFunction function) {
      List<SqlExpression> arguments = GetFunctionParameters(mce, function);
      MetaType rootType = function.ResultRowTypes[0].InheritanceRoot;
      SqlTableValuedFunctionCall call = sql.TableValuedFunctionCall(rootType, mce.Method.ReturnType, function.MappedName, arguments, mce);
      SqlAlias callAlias = new SqlAlias(call);
      SqlAliasRef callRef = new SqlAliasRef(callAlias);
      SqlExpression projection = this.translator.BuildProjection(callRef, rootType, this.allowDeferred, null, mce);
      return new SqlSelect(projection, callAlias, mce);
  }
  /// <summary>
  /// Translates a call to a stored procedure. Only allowed at the outermost node.
  /// A sproc returning a generic IEnumerable or ISingleResult gets a projection over
  /// its single rowset. IMultipleResults, int and int? return types need no
  /// projection. Any other return type is rejected.
  /// </summary>
  /// <remarks>
  /// Throws the exceptions from Error.SprocsCannotBeComposed (not the outer node) and
  /// Error.InvalidReturnFromSproc (unsupported return type).
  /// </remarks>
  private SqlNode TranslateStoredProcedureCall(MethodCallExpression mce, MetaFunction function) {
      if (!this.outerNode) {
          throw Error.SprocsCannotBeComposed();
      }
      List<SqlExpression> arguments = GetFunctionParameters(mce, function);
      SqlStoredProcedureCall call = new SqlStoredProcedureCall(function, null, arguments, mce);

      Type returnType = mce.Method.ReturnType;
      Type genericDefinition = returnType.IsGenericType ? returnType.GetGenericTypeDefinition() : null;
      bool returnsSingleRowset = genericDefinition == typeof(IEnumerable<>) || genericDefinition == typeof(ISingleResult<>);
      if (returnsSingleRowset) {
          // A single-rowset sproc is projected through the one and only root meta type.
          MetaType rowType = function.ResultRowTypes[0].InheritanceRoot;
          SqlUserRow row = new SqlUserRow(rowType, this.typeProvider.GetApplicationType((int)ConverterSpecialTypes.Row), call, mce);
          call.Projection = this.translator.BuildProjection(row, rowType, this.allowDeferred, null, mce);
      }
      else {
          bool returnNeedsNoProjection = typeof(IMultipleResults).IsAssignableFrom(returnType)
              || returnType == typeof(int)
              || returnType == typeof(int?);
          if (!returnNeedsNoProjection) {
              throw Error.InvalidReturnFromSproc(returnType);
          }
      }
      return call;
  }
  /// <summary>
  /// Creates the SQL argument list for a function or stored procedure call. When the
  /// metadata for a method parameter specifies an explicit DbType and the translated
  /// argument is a SqlSimpleTypeExpression, the argument's SQL type is set to the
  /// parsed DbType.
  /// </summary>
  private List<SqlExpression> GetFunctionParameters(MethodCallExpression mce, MetaFunction function) {
      int argumentCount = mce.Arguments.Count;
      List<SqlExpression> sqlParams = new List<SqlExpression>(argumentCount);
      for (int index = 0; index < argumentCount; index++) {
          SqlExpression argument = this.VisitExpression(mce.Arguments[index]);
          MetaParameter metaParam = function.Parameters[index];
          if (!string.IsNullOrEmpty(metaParam.DbType)) {
              // An explicit DbType overrides the inferred provider type.
              SqlSimpleTypeExpression simple = argument as SqlSimpleTypeExpression;
              if (simple != null) {
                  simple.SetSqlType(typeProvider.Parse(metaParam.DbType));
              }
          }
          sqlParams.Add(argument);
      }
      return sqlParams;
  }
  /// <summary>
  /// Translates a user-supplied SQL query with arguments. For a non-void result type,
  /// a projection is attached. A simple element type gets a single unnamed column;
  /// any other element type gets a default projection over a user row.
  /// </summary>
  private SqlUserQuery VisitUserQuery(string query, Expression[] arguments, Type resultType) {
      SqlExpression[] args = new SqlExpression[arguments.Length];
      for (int i = 0; i < args.Length; i++) {
          args[i] = this.VisitExpression(arguments[i]);
      }
      SqlUserQuery suq = new SqlUserQuery(query, null, args, this.dominatingExpression);
      if (resultType == typeof(void)) {
          return suq;
      }

      Type elementType = TypeSystem.GetElementType(resultType);
      MetaType mType = this.services.Model.GetMetaType(elementType);
      if (!TypeSystem.IsSimpleType(elementType)) {
          // Complex element: default projection over a user row.
          SqlUserRow row = new SqlUserRow(mType.InheritanceRoot, this.typeProvider.GetApplicationType((int)ConverterSpecialTypes.Row), suq, this.dominatingExpression);
          suq.Projection = this.translator.BuildProjection(row, mType, this.allowDeferred, null, this.dominatingExpression);
          return suq;
      }

      // Simple element (int, bool, ...): bind a single column.
      SqlUserColumn column = new SqlUserColumn(elementType, typeProvider.From(elementType), suq, "", false, this.dominatingExpression);
      suq.Columns.Add(column);
      suq.Projection = column;
      return suq;
  }
  /// <summary>
  /// Translates a unary expression. A user-defined operator becomes a method call.
  /// Negate and NegateChecked map to Negate. Not maps to a logical NOT for bool and
  /// bool? operands, and to a bitwise NOT otherwise. TypeAs maps to Treat.
  /// </summary>
  /// <returns>The translated node, or null for any other unary node type.</returns>
  private SqlNode VisitUnary(UnaryExpression u) {
      SqlExpression operand = this.VisitExpression(u.Operand);
      if (u.Method != null) {
          return sql.MethodCall(u.Type, u.Method, null, new SqlExpression[] { operand }, dominatingExpression);
      }
      SqlNodeType op;
      switch (u.NodeType) {
          case ExpressionType.Negate:
          case ExpressionType.NegateChecked:
              op = SqlNodeType.Negate;
              break;
          case ExpressionType.Not: {
              Type operandType = u.Operand.Type;
              bool isLogical = operandType == typeof(bool) || operandType == typeof(bool?);
              op = isLogical ? SqlNodeType.Not : SqlNodeType.BitNot;
              break;
          }
          case ExpressionType.TypeAs:
              op = SqlNodeType.Treat;
              break;
          default:
              // NOTE(review): unrecognized unary node types yield null here, whereas
              // VisitBinary throws for unrecognized operators. Confirm callers expect this.
              return null;
      }
      return sql.Unary(op, operand, this.dominatingExpression);
  }
  /// <summary>
  /// Translates a binary expression. A user-defined operator becomes a method call, and
  /// Coalesce is delegated to MakeCoalesce. Every other supported operator maps to a
  /// single SQL binary node. And and Or are logical for bool and bool? left operands,
  /// and bitwise otherwise.
  /// </summary>
  /// <remarks>Throws the exception from Error.BinaryOperatorNotRecognized for unsupported operators.</remarks>
  [SuppressMessage("Microsoft.Maintainability", "CA1502:AvoidExcessiveComplexity", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
  private SqlNode VisitBinary(BinaryExpression b) {
      SqlExpression lhs = this.VisitExpression(b.Left);
      SqlExpression rhs = this.VisitExpression(b.Right);
      if (b.Method != null) {
          return sql.MethodCall(b.Type, b.Method, null, new SqlExpression[] { lhs, rhs }, dominatingExpression);
      }
      if (b.NodeType == ExpressionType.Coalesce) {
          return this.MakeCoalesce(lhs, rhs, b.Type);
      }

      bool logicalOperands = b.Left.Type == typeof(bool) || b.Left.Type == typeof(bool?);
      SqlNodeType op;
      switch (b.NodeType) {
          case ExpressionType.Add:
          case ExpressionType.AddChecked:
              op = SqlNodeType.Add; break;
          case ExpressionType.Subtract:
          case ExpressionType.SubtractChecked:
              op = SqlNodeType.Sub; break;
          case ExpressionType.Multiply:
          case ExpressionType.MultiplyChecked:
              op = SqlNodeType.Mul; break;
          case ExpressionType.Divide:
              op = SqlNodeType.Div; break;
          case ExpressionType.Modulo:
              op = SqlNodeType.Mod; break;
          case ExpressionType.And:
              op = logicalOperands ? SqlNodeType.And : SqlNodeType.BitAnd; break;
          case ExpressionType.AndAlso:
              op = SqlNodeType.And; break;
          case ExpressionType.Or:
              op = logicalOperands ? SqlNodeType.Or : SqlNodeType.BitOr; break;
          case ExpressionType.OrElse:
              op = SqlNodeType.Or; break;
          case ExpressionType.LessThan:
              op = SqlNodeType.LT; break;
          case ExpressionType.LessThanOrEqual:
              op = SqlNodeType.LE; break;
          case ExpressionType.GreaterThan:
              op = SqlNodeType.GT; break;
          case ExpressionType.GreaterThanOrEqual:
              op = SqlNodeType.GE; break;
          case ExpressionType.Equal:
              op = SqlNodeType.EQ; break;
          case ExpressionType.NotEqual:
              op = SqlNodeType.NE; break;
          case ExpressionType.ExclusiveOr:
              op = SqlNodeType.BitXor; break;
          default:
              throw Error.BinaryOperatorNotRecognized(b.NodeType);
      }
      return sql.Binary(op, lhs, rhs, b.Type);
  }
  /// <summary>
  /// Builds the SQL for left ?? right. Simple result types use COALESCE. Otherwise the
  /// result is CASE WHEN left IS NULL THEN right ELSE (duplicate of left) END.
  /// </summary>
  private SqlExpression MakeCoalesce(SqlExpression left, SqlExpression right, Type resultType) {
      CompensateForLowerPrecedenceOfDateType(ref left, ref right); // DevDiv 176874
      if (!TypeSystem.IsSimpleType(resultType)) {
          SqlWhen whenLeftIsNull = new SqlWhen(sql.Unary(SqlNodeType.IsNull, left, left.SourceExpression), right);
          SqlExpression leftCopy = (SqlExpression)new SqlDuplicator(true).Duplicate(left);
          return sql.SearchedCase(new SqlWhen[] { whenLeftIsNull }, leftCopy, this.dominatingExpression);
      }
      return sql.Binary(SqlNodeType.Coalesce, left, right, resultType);
  }
  // The result *type* of a COALESCE call is that of the operand with the highest data
  // type precedence. The SQL DATE type, however, ranks below DATETIME and SMALLDATETIME
  // despite having a higher range. To compensate, when one operand is a DATE and the
  // other satisfies SqlFactory.IsSqlDateTimeType, that other operand is passed through
  // ConvertDateToDateTime2.
  private void CompensateForLowerPrecedenceOfDateType(ref SqlExpression left, ref SqlExpression right) {
      if (SqlFactory.IsSqlDateType(left) && SqlFactory.IsSqlDateTimeType(right)) {
          right = (SqlExpression)ConvertDateToDateTime2(right);
          return;
      }
      if (SqlFactory.IsSqlDateType(right) && SqlFactory.IsSqlDateTimeType(left)) {
          left = (SqlExpression)ConvertDateToDateTime2(left);
      }
  }
  /// <summary>
  /// Translates Concat as a UNION ALL of both sequences, wrapped in an outer select
  /// whose OrderingType is Blocked.
  /// </summary>
  private SqlNode VisitConcat(Expression source1, Expression source2) {
      SqlSelect first = this.VisitSequence(source1);
      SqlSelect second = this.VisitSequence(source2);
      SqlAlias unionAlias = new SqlAlias(new SqlUnion(first, second, true));
      return new SqlSelect(new SqlAliasRef(unionAlias), unionAlias, this.dominatingExpression) {
          OrderingType = SqlOrderingType.Blocked
      };
  }
  /// <summary>
  /// Translates Union as a UNION (not ALL) of both sequences, wrapped in an outer select
  /// whose OrderingType is Blocked.
  /// </summary>
  private SqlNode VisitUnion(Expression source1, Expression source2) {
      SqlSelect first = this.VisitSequence(source1);
      SqlSelect second = this.VisitSequence(source2);
      SqlAlias unionAlias = new SqlAlias(new SqlUnion(first, second, false));
      return new SqlSelect(new SqlAliasRef(unionAlias), unionAlias, this.dominatingExpression) {
          OrderingType = SqlOrderingType.Blocked
      };
  }
  /// <summary>
  /// Translates Intersect. It selects the DISTINCT rows of the first sequence filtered
  /// by an "any" quantifier over the second sequence, using an EQ2V comparison against
  /// the current row. Sequences of IGrouping elements are rejected.
  /// </summary>
  private SqlNode VisitIntersect(Expression source1, Expression source2) {
      if (IsGrouping(TypeSystem.GetElementType(source1.Type))) {
          throw Error.IntersectNotSupportedForHierarchicalTypes();
      }
      SqlSelect outer = this.LockSelect(this.VisitSequence(source1));
      SqlSelect inner = this.VisitSequence(source2);
      SqlAlias outerAlias = new SqlAlias(outer);
      SqlAliasRef outerRef = new SqlAliasRef(outerAlias);
      SqlAlias innerAlias = new SqlAlias(inner);
      SqlAliasRef innerRef = new SqlAliasRef(innerAlias);
      SqlExpression hasMatch = this.GenerateQuantifier(innerAlias, sql.Binary(SqlNodeType.EQ2V, outerRef, innerRef), true);
      return new SqlSelect(outerRef, outerAlias, outer.SourceExpression) {
          Where = hasMatch,
          IsDistinct = true,
          OrderingType = SqlOrderingType.Blocked
      };
  }
  /// <summary>
  /// Translates Except. It selects the DISTINCT rows of the first sequence for which
  /// the "any" quantifier over the second sequence (EQ2V against the current row) is
  /// negated. Sequences of IGrouping elements are rejected.
  /// </summary>
  private SqlNode VisitExcept(Expression source1, Expression source2) {
      if (IsGrouping(TypeSystem.GetElementType(source1.Type))) {
          throw Error.ExceptNotSupportedForHierarchicalTypes();
      }
      SqlSelect outer = this.LockSelect(this.VisitSequence(source1));
      SqlSelect inner = this.VisitSequence(source2);
      SqlAlias outerAlias = new SqlAlias(outer);
      SqlAliasRef outerRef = new SqlAliasRef(outerAlias);
      SqlAlias innerAlias = new SqlAlias(inner);
      SqlAliasRef innerRef = new SqlAliasRef(innerAlias);
      SqlExpression hasMatch = this.GenerateQuantifier(innerAlias, sql.Binary(SqlNodeType.EQ2V, outerRef, innerRef), true);
      return new SqlSelect(outerRef, outerAlias, outer.SourceExpression) {
          Where = sql.Unary(SqlNodeType.Not, hasMatch),
          IsDistinct = true,
          OrderingType = SqlOrderingType.Blocked
      };
  }
  /// <summary>
  /// Returns true if the type is IGrouping&lt;TKey, TElement&gt;, either the open
  /// generic definition or any constructed form of it.
  /// </summary>
  [SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
  private bool IsGrouping(Type t) {
      return t.IsGenericType && t.GetGenericTypeDefinition() == typeof(IGrouping<,>);
  }
  /// <summary>
  /// Translates OrderBy/OrderByDescending. The key may not be an IGrouping and must be
  /// orderable by the type provider. When the locked source select does not select an
  /// alias reference, or already has an ordering, it is first wrapped in a new select,
  /// and the new ordering is added to that wrapper.
  /// </summary>
  private SqlSelect VisitOrderBy(Expression sequence, LambdaExpression expression, SqlOrderType orderType) {
      Type keyType = expression.Body.Type;
      if (IsGrouping(keyType)) {
          throw Error.GroupingNotSupportedAsOrderCriterion();
      }
      if (!this.typeProvider.From(keyType).IsOrderable) {
          throw Error.TypeCannotBeOrdered(keyType);
      }

      SqlSelect select = this.LockSelect(this.VisitSequence(sequence));
      bool needsWrapper = select.Selection.NodeType != SqlNodeType.AliasRef || select.OrderBy.Count > 0;
      if (needsWrapper) {
          SqlAlias innerAlias = new SqlAlias(select);
          select = new SqlSelect(new SqlAliasRef(innerAlias), innerAlias, this.dominatingExpression);
      }

      // Bind the key selector's parameter to the current row, then translate the key.
      this.map[expression.Parameters[0]] = (SqlAliasRef)select.Selection;
      SqlExpression key = this.VisitExpression(expression.Body);
      select.OrderBy.Add(new SqlOrderExpression(orderType, key));
      return select;
  }
  /// <summary>
  /// Translates ThenBy/ThenByDescending by appending another ordering to the select
  /// produced for the source sequence. That select's selection is expected (and
  /// debug-asserted) to be an alias reference.
  /// </summary>
  private SqlSelect VisitThenBy(Expression sequence, LambdaExpression expression, SqlOrderType orderType) {
      Type keyType = expression.Body.Type;
      if (IsGrouping(keyType)) {
          throw Error.GroupingNotSupportedAsOrderCriterion();
      }
      if (!this.typeProvider.From(keyType).IsOrderable) {
          throw Error.TypeCannotBeOrdered(keyType);
      }

      SqlSelect select = this.VisitSequence(sequence);
      System.Diagnostics.Debug.Assert(select.Selection.NodeType == SqlNodeType.AliasRef);
      SqlAliasRef currentRow = (SqlAliasRef)select.Selection;
      this.map[expression.Parameters[0]] = currentRow;
      SqlExpression key = this.VisitExpression(expression.Body);
      select.OrderBy.Add(new SqlOrderExpression(orderType, key));
      return select;
  }
  1341. private SqlNode VisitGroupBy(Expression sequence, LambdaExpression keyLambda, LambdaExpression elemLambda, LambdaExpression resultSelector) {
  1342. // Convert seq.Group(elem, key) into
  1343. //
  1344. // SELECT s.key, MULTISET(select s2.elem from seq AS s2 where s.key == s2.key)
  1345. // FROM seq AS s
  1346. //
  1347. // where key and elem can be either simple scalars or object constructions
  1348. //
  1349. SqlSelect seq = this.VisitSequence(sequence);
  1350. seq = this.LockSelect(seq);
  1351. SqlAlias seqAlias = new SqlAlias(seq);
  1352. SqlAliasRef seqAliasRef = new SqlAliasRef(seqAlias);
  1353. // evaluate the key expression relative to original sequence
  1354. this.map[keyLambda.Parameters[0]] = seqAliasRef;
  1355. SqlExpression keyExpr = this.VisitExpression(keyLambda.Body);
  1356. // make a duplicate of the original sequence to use as a foundation of our group multiset
  1357. SqlDuplicator sd = new SqlDuplicator();
  1358. SqlSelect selDup = (SqlSelect)sd.Duplicate(seq);
  1359. // rebind key in relative to the duplicate sequence
  1360. SqlAlias selDupAlias = new SqlAlias(selDup);
  1361. SqlAliasRef selDupRef = new SqlAliasRef(selDupAlias);
  1362. this.map[keyLambda.Parameters[0]] = selDupRef;
  1363. SqlExpression keyDup = this.VisitExpression(keyLambda.Body);
  1364. SqlExpression elemExpr = null;
  1365. SqlExpression elemOnGroupSource = null;
  1366. if (elemLambda != null) {
  1367. // evaluate element expression relative to the duplicate sequence
  1368. this.map[elemLambda.Parameters[0]] = selDupRef;
  1369. elemExpr = this.VisitExpression(elemLambda.Body);
  1370. // evaluate element expression relative to original sequence
  1371. this.map[elemLambda.Parameters[0]] = seqAliasRef;
  1372. elemOnGroupSource = this.VisitExpression(elemLambda.Body);
  1373. }
  1374. else {
  1375. // no elem expression supplied, so just use an alias ref to the duplicate sequence.
  1376. // this will resolve to whatever was being produced by the sequence
  1377. elemExpr = selDupRef;
  1378. elemOnGroupSource = seqAliasRef;
  1379. }
  1380. // Make a sub expression out of the key. This will allow a single definition of the
  1381. // expression to be shared at multiple points in the tree (via SqlSharedExpressionRef's)
  1382. SqlSharedExpression keySubExpr = new SqlSharedExpression(keyExpr);
  1383. keyExpr = new SqlSharedExpressionRef(keySubExpr);
  1384. // construct the select clause that picks out the elements (this may be redundant...)
  1385. SqlSelect selElem = new SqlSelect(elemExpr, selDupAlias, this.dominatingExpression);
  1386. selElem.Where = sql.Binary(SqlNodeType.EQ2V, keyExpr, keyDup);
  1387. // Finally, make the MULTISET node. this will be used as part of the final select
  1388. SqlSubSelect ss = sql.SubSelect(SqlNodeType.Multiset, selElem);
  1389. // add a layer to the original sequence before applying the actual group-by clause
  1390. SqlSelect gsel = new SqlSelect(new SqlSharedExpressionRef(keySubExpr), seqAlias, this.dominatingExpression);
  1391. gsel.GroupBy.Add(keySubExpr);
  1392. SqlAlias gselAlias = new SqlAlias(gsel);
  1393. SqlSelect result = null;
  1394. if (resultSelector != null) {
  1395. // Create final select to include construction of group multiset
  1396. // select new Grouping { Key = key, Group = Multiset(select elem from seq where match) } from ...
  1397. Type elementType = typeof(IGrouping<,>).MakeGenericType(keyExpr.ClrType, elemExpr.ClrType);
  1398. SqlExpression keyGroup = new SqlGrouping(elementType, this.typeProvider.From(elementType), keyExpr, ss, this.dominatingExpression);
  1399. SqlSelect keyGroupSel = new SqlSelect(keyGroup, gselAlias, this.dominatingExpression);
  1400. SqlAlias kgAlias = new SqlAlias(keyGroupSel);
  1401. SqlAliasRef kgAliasRef = new SqlAliasRef(kgAlias);
  1402. this.map[resultSelector.Parameters[0]] = sql.Member(kgAliasRef, elementType.GetProperty("Key"));
  1403. this.map[resultSelector.Parameters[1]] = kgAliasRef;
  1404. // remember the select that has the actual group (for optimizing aggregates later)
  1405. this.gmap[kgAliasRef] = new GroupInfo { SelectWithGroup = gsel, ElementOnGroupSource = elemOnGroupSource };
  1406. SqlExpression resultExpr = this.VisitExpression(resultSelector.Body);
  1407. result = new SqlSelect(resultExpr, kgAlias, this.dominatingExpression);
  1408. // remember the select that has the actual group (for optimizing aggregates later)
  1409. this.gmap[resultExpr] = new GroupInfo { SelectWithGroup = gsel, ElementOnGroupSource = elemOnGroupSource };
  1410. }
  1411. else {
  1412. // Create final select to include construction of group multiset
  1413. // select new Grouping { Key = key, Group = Multiset(select elem from seq where match) } from ...
  1414. Type elementType = typeof(IGrouping<,>).MakeGenericType(keyExpr.ClrType, elemExpr.ClrType);
  1415. SqlExpression resultExpr = new SqlGrouping(elementType, this.typeProvider.From(elementType), keyExpr, ss, this.dominatingExpression);
  1416. result = new SqlSelect(resultExpr, gselAlias, this.dominatingExpression);
  1417. // remember the select that has the actual group (for optimizing aggregates later)
  1418. this.gmap[resultExpr] = new GroupInfo { SelectWithGroup = gsel, ElementOnGroupSource = elemOnGroupSource };
  1419. }
  1420. return result;
  1421. }
        [SuppressMessage("Microsoft.Maintainability", "CA1502:AvoidExcessiveComplexity", Justification="These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
        // Translates an aggregate operator (Count, LongCount, Sum, Min, Max, Average) applied to
        // 'sequence', optionally through a selector/predicate 'lambda'.
        //   sequence   - the source sequence expression.
        //   lambda     - selector (Sum/Min/Max/Avg) or predicate (Count/LongCount); may be null.
        //   aggType    - the SQL aggregate node type to produce.
        //   returnType - CLR type of the aggregate result.
        // Returns a SqlSelect (outer-most query), a SqlColumnRef into the group-by select, or a
        // scalar sub-select, depending on where the aggregate appears.
        private SqlNode VisitAggregate(Expression sequence, LambdaExpression lambda, SqlNodeType aggType, Type returnType) {
            // Convert seq.Agg(exp) into
            //
            // 1) SELECT Agg(exp) FROM seq
            // 2) SELECT Agg1 FROM (SELECT Agg(exp) as Agg1 FROM group-seq GROUP BY ...)
            // 3) SCALAR(SELECT Agg(exp) FROM seq)
            //
            bool isCount = aggType == SqlNodeType.Count || aggType == SqlNodeType.LongCount;
            SqlNode source = this.Visit(sequence);
            SqlSelect select = this.CoerceToSequence(source);
            SqlAlias alias = new SqlAlias(select);
            SqlAliasRef aref = new SqlAliasRef(alias);
            // If the sequence is of the form x.Select(expr).Agg() and the lambda for the aggregate is null,
            // or is a no-op parameter expression (like u=>u), clone the Select's selection lambda
            // expression, and use it for the aggregate.
            // Final form should be x.Agg(expr)
            MethodCallExpression mce = sequence as MethodCallExpression;
            if (!outerNode && !isCount && (lambda == null || (lambda.Parameters.Count == 1 && lambda.Parameters[0] == lambda.Body)) &&
                (mce != null) && IsSequenceOperatorCall(mce, "Select") && select.From is SqlAlias) {
                LambdaExpression selectionLambda = GetLambda(mce.Arguments[1]);
                lambda = Expression.Lambda(selectionLambda.Type, selectionLambda.Body, selectionLambda.Parameters);
                // aggregate directly over the Select's source instead of over the projection
                alias = (SqlAlias)select.From;
                aref = new SqlAliasRef(alias);
            }
            if (lambda != null && !TypeSystem.IsSimpleType(lambda.Body.Type)) {
                throw Error.CannotAggregateType(lambda.Body.Type);
            }
            //Empty parameter aggregates are not allowed on anonymous types
            //i.e. db.Customers.Select(c=>new{c.Age}).Max() instead it should be
            // db.Customers.Select(c=>new{c.Age}).Max(c=>c.Age)
            if (select.Selection.SqlType.IsRuntimeOnlyType && !IsGrouping(sequence.Type) && !isCount && lambda == null) {
                throw Error.NonCountAggregateFunctionsAreNotValidOnProjections(aggType);
            }
            // bind the lambda parameter to the (possibly re-targeted) source alias
            if (lambda != null)
                this.map[lambda.Parameters[0]] = aref;
            if (this.outerNode) {
                // If this aggregate is basically the last/outer-most operator of the query
                //
                // produce SELECT Agg(exp) FROM seq
                //
                SqlExpression exp = (lambda != null) ? this.VisitExpression(lambda.Body) : null;
                SqlExpression where = null;
                if (isCount && exp != null) {
                    // Count(pred): the predicate becomes the WHERE clause, and COUNT takes no argument
                    where = exp;
                    exp = null;
                }
                else if (exp == null && !isCount) {
                    // Sum()/Min()/... without a selector aggregates the source row itself
                    exp = aref;
                }
                if (exp != null) {
                    // in case this contains another aggregate
                    exp = new SqlSimpleExpression(exp);
                }
                SqlSelect sel = new SqlSelect(
                    this.GetAggregate(aggType, returnType, exp),
                    alias,
                    this.dominatingExpression
                    );
                sel.Where = where;
                sel.OrderingType = SqlOrderingType.Never;
                return sel;
            }
            else if (!isCount || lambda == null) {
                // Look to optimize aggregate by pushing its evaluation down to the select node that has the
                // actual group-by operator.
                //
                // Produce: SELECT Agg1 FROM (SELECT Agg(exp) as Agg1 FROM seq GROUP BY ...)
                //
                GroupInfo info = this.FindGroupInfo(source);
                if (info != null) {
                    SqlExpression exp = null;
                    if (lambda != null) {
                        // evaluate expression relative to the group-by select node
                        this.map[lambda.Parameters[0]] = (SqlExpression)SqlDuplicator.Copy(info.ElementOnGroupSource);
                        exp = this.VisitExpression(lambda.Body);
                    } else if (!isCount) {
                        // support aggregates w/o an explicit selector specified
                        exp = info.ElementOnGroupSource;
                    }
                    if (exp != null) {
                        // in case this contains another aggregate
                        exp = new SqlSimpleExpression(exp);
                    }
                    // add the aggregate as a new column of the group-by select and refer to that column
                    SqlExpression agg = this.GetAggregate(aggType, returnType, exp);
                    SqlColumn c = new SqlColumn(agg.ClrType, agg.SqlType, null, null, agg, this.dominatingExpression);
                    info.SelectWithGroup.Row.Columns.Add(c);
                    return new SqlColumnRef(c);
                }
            }
            // Otherwise, if we cannot optimize then fall back to generating a nested aggregate in a correlated sub query
            //
            // SCALAR(SELECT Agg(exp) FROM seq)
            {
                SqlExpression exp = (lambda != null) ? this.VisitExpression(lambda.Body) : null;
                if (exp != null) {
                    // in case this contains another aggregate
                    exp = new SqlSimpleExpression(exp);
                }
                // Count: no aggregate argument, the lambda (if any) filters via WHERE.
                // Others: aggregate the selector, or the source row when no selector was given.
                SqlSelect sel = new SqlSelect(
                    this.GetAggregate(aggType, returnType, isCount ? null : (lambda == null) ? aref : exp),
                    alias,
                    this.dominatingExpression
                    );
                sel.Where = isCount ? exp : null;
                return sql.SubSelect(SqlNodeType.ScalarSubSelect, sel);
            }
        }
  1530. private GroupInfo FindGroupInfo(SqlNode source) {
  1531. GroupInfo info = null;
  1532. this.gmap.TryGetValue(source, out info);
  1533. if (info != null) {
  1534. return info;
  1535. }
  1536. SqlAlias alias = source as SqlAlias;
  1537. if (alias != null) {
  1538. SqlSelect select = alias.Node as SqlSelect;
  1539. if (select != null) {
  1540. return this.FindGroupInfo(select.Selection);
  1541. }
  1542. // it might be an expression (not yet fully resolved)
  1543. source = alias.Node;
  1544. }
  1545. SqlExpression expr = source as SqlExpression;
  1546. if (expr != null) {
  1547. switch (expr.NodeType) {
  1548. case SqlNodeType.AliasRef:
  1549. return this.FindGroupInfo(((SqlAliasRef)expr).Alias);
  1550. case SqlNodeType.Member:
  1551. return this.FindGroupInfo(((SqlMember)expr).Expression);
  1552. default:
  1553. this.gmap.TryGetValue(expr, out info);
  1554. return info;
  1555. }
  1556. }
  1557. return null;
  1558. }
  1559. private SqlExpression GetAggregate(SqlNodeType aggType, Type clrType, SqlExpression exp) {
  1560. ProviderType sqlType = this.typeProvider.From(clrType);
  1561. return new SqlUnary(aggType, clrType, sqlType, exp, this.dominatingExpression);
  1562. }
  1563. private SqlNode VisitContains(Expression sequence, Expression value) {
  1564. Type elemType = TypeSystem.GetElementType(sequence.Type);
  1565. SqlNode seqNode = this.Visit(sequence);
  1566. if (seqNode.NodeType == SqlNodeType.ClientArray) {
  1567. SqlClientArray array = (SqlClientArray)seqNode;
  1568. return this.GenerateInExpression(this.VisitExpression(value), array.Expressions);
  1569. }
  1570. else if (seqNode.NodeType == SqlNodeType.Value) {
  1571. IEnumerable values = ((SqlValue)seqNode).Value as IEnumerable;
  1572. IQueryable query = values as IQueryable;
  1573. if (query == null) {
  1574. SqlExpression expr = this.VisitExpression(value);
  1575. List<SqlExpression> list = values.OfType<object>().Select(v => sql.ValueFromObject(v, elemType, true, this.dominatingExpression)).ToList();
  1576. return this.GenerateInExpression(expr, list);
  1577. }
  1578. seqNode = this.Visit(query.Expression);
  1579. }
  1580. ParameterExpression p = Expression.Parameter(value.Type, "p");
  1581. LambdaExpression lambda = Expression.Lambda(Expression.Equal(p, value), p);
  1582. return this.VisitQuantifier(this.CoerceToSequence(seqNode), lambda, true);
  1583. }
  1584. private SqlExpression GenerateInExpression(SqlExpression expr, List<SqlExpression> list) {
  1585. if (list.Count == 0) {
  1586. return sql.ValueFromObject(false, this.dominatingExpression);
  1587. }
  1588. else if (list[0].SqlType.CanBeColumn) {
  1589. return sql.In(expr, list, this.dominatingExpression);
  1590. }
  1591. else {
  1592. SqlExpression pred = sql.Binary(SqlNodeType.EQ, expr, list[0]);
  1593. for (int i = 1, n = list.Count; i < n; i++) {
  1594. pred = sql.Binary(SqlNodeType.Or, pred, sql.Binary(SqlNodeType.EQ, (SqlExpression)SqlDuplicator.Copy(expr), list[i]));
  1595. }
  1596. return pred;
  1597. }
  1598. }
  1599. private SqlNode VisitQuantifier(Expression sequence, LambdaExpression lambda, bool isAny) {
  1600. return this.VisitQuantifier(this.VisitSequence(sequence), lambda, isAny);
  1601. }
  1602. private SqlNode VisitQuantifier(SqlSelect select, LambdaExpression lambda, bool isAny) {
  1603. SqlAlias alias = new SqlAlias(select);
  1604. SqlAliasRef aref = new SqlAliasRef(alias);
  1605. if (lambda != null) {
  1606. this.map[lambda.Parameters[0]] = aref;
  1607. }
  1608. SqlExpression cond = lambda != null ? this.VisitExpression(lambda.Body) : null;
  1609. return this.GenerateQuantifier(alias, cond, isAny);
  1610. }
  1611. private SqlExpression GenerateQuantifier(SqlAlias alias, SqlExpression cond, bool isAny) {
  1612. SqlAliasRef aref = new SqlAliasRef(alias);
  1613. if (isAny) {
  1614. SqlSelect sel = new SqlSelect(aref, alias, this.dominatingExpression);
  1615. sel.Where = cond;
  1616. sel.OrderingType = SqlOrderingType.Never;
  1617. SqlSubSelect exists = sql.SubSelect(SqlNodeType.Exists, sel);
  1618. return exists;
  1619. }
  1620. else {
  1621. SqlSelect sel = new SqlSelect(aref, alias, this.dominatingExpression);
  1622. SqlSubSelect ss = sql.SubSelect(SqlNodeType.Exists, sel);
  1623. sel.Where = sql.Unary(SqlNodeType.Not2V, cond, this.dominatingExpression);
  1624. return sql.Unary(SqlNodeType.Not, ss, this.dominatingExpression);
  1625. }
  1626. }
  1627. private void CheckContext(SqlExpression expr) {
  1628. // try to catch use of incorrect context if we can
  1629. SqlValue value = expr as SqlValue;
  1630. if (value != null) {
  1631. DataContext dc = value.Value as DataContext;
  1632. if (dc != null) {
  1633. if (dc != this.services.Context) {
  1634. throw Error.WrongDataContext();
  1635. }
  1636. }
  1637. }
  1638. }
  1639. private SqlNode VisitMemberAccess(MemberExpression ma) {
  1640. Type memberType = TypeSystem.GetMemberType(ma.Member);
  1641. if (memberType.IsGenericType && memberType.GetGenericTypeDefinition() == typeof(Table<>)) {
  1642. Type rowType = memberType.GetGenericArguments()[0];
  1643. CheckContext(this.VisitExpression(ma.Expression));
  1644. ITable table = this.services.Context.GetTable(rowType);
  1645. if (table != null)
  1646. return this.Visit(Expression.Constant(table));
  1647. }
  1648. if (ma.Member.Name == "Count" && TypeSystem.IsSequenceType(ma.Expression.Type)) {
  1649. return this.VisitAggregate(ma.Expression, null, SqlNodeType.Count, typeof(int));
  1650. }
  1651. return sql.Member(VisitExpression(ma.Expression), ma.Member);
  1652. }
        [SuppressMessage("Microsoft.Maintainability", "CA1502:AvoidExcessiveComplexity", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
        // Translates a method call expression. Dispatch order:
        //  static:   Enumerable/Queryable operators, DataManipulation calls, Convert/DBConvert.ChangeType;
        //  instance on a DataContext subclass: GetTable<T>(), ExecuteCommand/ExecuteQuery, mapped functions;
        //  instance IList.Contains (non-string): translated as a Contains over the list.
        // Any call not handled above (including the fall-through cases) becomes a generic SQL
        // method-call node built from the visited receiver and arguments.
        private SqlNode VisitMethodCall(MethodCallExpression mc) {
            Type declType = mc.Method.DeclaringType;
            if (mc.Method.IsStatic) {
                if (this.IsSequenceOperatorCall(mc)) {
                    return this.VisitSequenceOperatorCall(mc);
                }
                else if (IsDataManipulationCall(mc)) {
                    return this.VisitDataManipulationCall(mc);
                }
                // why is this handled here and not in SqlMethodCallConverter?
                else if (declType == typeof(DBConvert) || declType == typeof(Convert)) {
                    if (mc.Method.Name == "ChangeType") {
                        // only ChangeType(value, <constant Type>) is supported
                        SqlNode sn = null;
                        if (mc.Arguments.Count == 2) {
                            object value = GetValue(mc.Arguments[1], "ChangeType");
                            if (value != null && typeof(Type).IsAssignableFrom(value.GetType())) {
                                sn = this.VisitChangeType(mc.Arguments[0], (Type)value);
                            }
                        }
                        if(sn == null) {
                            throw Error.MethodFormHasNoSupportConversionToSql(mc.Method.Name, mc.Method);
                        }
                        return sn;
                    }
                    // other Convert/DBConvert methods fall through to the generic method-call node
                }
            }
            else if (typeof(DataContext).IsAssignableFrom(mc.Method.DeclaringType)) {
                switch (mc.Method.Name) {
                    case "GetTable": {
                        // calls to GetTable<T> can be translated directly as table references
                        if (mc.Method.IsGenericMethod) {
                            Type[] typeArgs = mc.Method.GetGenericArguments();
                            if (typeArgs.Length == 1 && mc.Method.GetParameters().Length == 0) {
                                CheckContext(this.VisitExpression(mc.Object));
                                ITable table = this.services.Context.GetTable(typeArgs[0]);
                                if (table != null) {
                                    return this.Visit(Expression.Constant(table));
                                }
                            }
                        }
                        break;
                    }
                    case "ExecuteCommand":
                    case "ExecuteQuery":
                        // NOTE(review): assumes the (string query, object[] parameters) shape. The
                        // DataContext.ExecuteQuery(Type, string, params object[]) overload would put a
                        // Type in Arguments[0] and fail the (string) cast - confirm whether it can reach here.
                        return this.VisitUserQuery((string)GetValue(mc.Arguments[0], mc.Method.Name), GetArray(mc.Arguments[1]), mc.Type);
                }
                if (this.IsMappedFunctionCall(mc)) {
                    return this.VisitMappedFunctionCall(mc);
                }
            }
            else if (
                mc.Method.DeclaringType != typeof(string)
                && mc.Method.Name == "Contains"
                && !mc.Method.IsStatic  // always true in this branch; kept as written
                && typeof(IList).IsAssignableFrom(mc.Method.DeclaringType)
                && mc.Type == typeof(bool)
                && mc.Arguments.Count == 1
                && TypeSystem.GetElementType(mc.Method.DeclaringType).IsAssignableFrom(mc.Arguments[0].Type)
                ) {
                return this.VisitContains(mc.Object, mc.Arguments[0]);
            }
            // default: create sql method call node instead
            SqlExpression obj = VisitExpression(mc.Object);
            SqlExpression[] args = new SqlExpression[mc.Arguments.Count];
            for (int i = 0, n = args.Length; i < n; i++) {
                args[i] = VisitExpression(mc.Arguments[i]);
            }
            return sql.MethodCall(mc.Method, obj, args, dominatingExpression);
        }
  1723. private object GetValue(Expression expression, string operation) {
  1724. SqlExpression exp = this.VisitExpression(expression);
  1725. if (exp.NodeType == SqlNodeType.Value) {
  1726. return ((SqlValue)exp).Value;
  1727. }
  1728. throw Error.NonConstantExpressionsNotSupportedFor(operation);
  1729. }
  1730. private static Expression[] GetArray(Expression array) {
  1731. NewArrayExpression n = array as NewArrayExpression;
  1732. if (n != null) {
  1733. return n.Expressions.ToArray();
  1734. }
  1735. ConstantExpression c = array as ConstantExpression;
  1736. if (c != null) {
  1737. object[] obs = c.Value as object[];
  1738. if (obs != null) {
  1739. Type elemType = TypeSystem.GetElementType(c.Type);
  1740. return obs.Select(o => Expression.Constant(o, elemType)).ToArray();
  1741. }
  1742. }
  1743. return new Expression[] { };
  1744. }
  1745. [SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
  1746. private Expression RemoveQuotes(Expression expression) {
  1747. while (expression.NodeType == ExpressionType.Quote) {
  1748. expression = ((UnaryExpression)expression).Operand;
  1749. }
  1750. return expression;
  1751. }
  1752. private bool IsLambda(Expression expression) {
  1753. return this.RemoveQuotes(expression).NodeType == ExpressionType.Lambda;
  1754. }
  1755. private LambdaExpression GetLambda(Expression expression) {
  1756. return this.RemoveQuotes(expression) as LambdaExpression;
  1757. }
  1758. private bool IsMappedFunctionCall(MethodCallExpression mc) {
  1759. MetaFunction function = services.Model.GetFunction(mc.Method);
  1760. return function != null;
  1761. }
  1762. private SqlNode VisitMappedFunctionCall(MethodCallExpression mc) {
  1763. // See if the method maps to a user defined function
  1764. MetaFunction function = services.Model.GetFunction(mc.Method);
  1765. System.Diagnostics.Debug.Assert(function != null);
  1766. CheckContext(this.VisitExpression(mc.Object));
  1767. if (!function.IsComposable) {
  1768. return this.TranslateStoredProcedureCall(mc, function);
  1769. }
  1770. else if (function.ResultRowTypes.Count > 0) {
  1771. return this.TranslateTableValuedFunction(mc, function);
  1772. }
  1773. else {
  1774. ProviderType sqlType = function.ReturnParameter != null && !string.IsNullOrEmpty(function.ReturnParameter.DbType)
  1775. ? this.typeProvider.Parse(function.ReturnParameter.DbType)
  1776. : this.typeProvider.From(mc.Method.ReturnType);
  1777. List<SqlExpression> sqlParams = this.GetFunctionParameters(mc, function);
  1778. return sql.FunctionCall(mc.Method.ReturnType, sqlType, function.MappedName, sqlParams, mc);
  1779. }
  1780. }
  1781. [SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
  1782. private bool IsSequenceOperatorCall(MethodCallExpression mc) {
  1783. Type declType = mc.Method.DeclaringType;
  1784. if (declType == typeof(System.Linq.Enumerable) ||
  1785. declType == typeof(System.Linq.Queryable)) {
  1786. return true;
  1787. }
  1788. return false;
  1789. }
  1790. private bool IsSequenceOperatorCall(MethodCallExpression mc, string methodName) {
  1791. if (IsSequenceOperatorCall(mc) && mc.Method.Name == methodName) {
  1792. return true;
  1793. }
  1794. return false;
  1795. }
  1796. [SuppressMessage("Microsoft.Maintainability", "CA1505:AvoidUnmaintainableCode", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
  1797. [SuppressMessage("Microsoft.Maintainability", "CA1502:AvoidExcessiveComplexity", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
  1798. private SqlNode VisitSequenceOperatorCall(MethodCallExpression mc) {
  1799. Type declType = mc.Method.DeclaringType;
  1800. bool isSupportedSequenceOperator = false;
  1801. if (IsSequenceOperatorCall(mc)) {
  1802. switch (mc.Method.Name) {
  1803. case "Select":
  1804. isSupportedSequenceOperator = true;
  1805. if (mc.Arguments.Count == 2 &&
  1806. this.IsLambda(mc.Arguments[1]) && this.GetLambda(mc.Arguments[1]).Parameters.Count == 1) {
  1807. return this.VisitSelect(mc.Arguments[0], this.GetLambda(mc.Arguments[1]));
  1808. }
  1809. break;
  1810. case "SelectMany":
  1811. isSupportedSequenceOperator = true;
  1812. if (mc.Arguments.Count == 2 &&
  1813. this.IsLambda(mc.Arguments[1]) && this.GetLambda(mc.Arguments[1]).Parameters.Count == 1) {
  1814. return this.VisitSelectMany(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), null);
  1815. }
  1816. else if (mc.Arguments.Count == 3 &&
  1817. this.IsLambda(mc.Arguments[1]) && this.GetLambda(mc.Arguments[1]).Parameters.Count == 1 &&
  1818. this.IsLambda(mc.Arguments[2]) && this.GetLambda(mc.Arguments[2]).Parameters.Count == 2) {
  1819. return this.VisitSelectMany(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), this.GetLambda(mc.Arguments[2]));
  1820. }
  1821. break;
  1822. case "Join":
  1823. isSupportedSequenceOperator = true;
  1824. if (mc.Arguments.Count == 5 &&
  1825. this.IsLambda(mc.Arguments[2]) && this.GetLambda(mc.Arguments[2]).Parameters.Count == 1 &&
  1826. this.IsLambda(mc.Arguments[3]) && this.GetLambda(mc.Arguments[3]).Parameters.Count == 1 &&
  1827. this.IsLambda(mc.Arguments[4]) && this.GetLambda(mc.Arguments[4]).Parameters.Count == 2) {
  1828. return this.VisitJoin(mc.Arguments[0], mc.Arguments[1], this.GetLambda(mc.Arguments[2]), this.GetLambda(mc.Arguments[3]), this.GetLambda(mc.Arguments[4]));
  1829. }
  1830. break;
  1831. case "GroupJoin":
  1832. isSupportedSequenceOperator = true;
  1833. if (mc.Arguments.Count == 5 &&
  1834. this.IsLambda(mc.Arguments[2]) && this.GetLambda(mc.Arguments[2]).Parameters.Count == 1 &&
  1835. this.IsLambda(mc.Arguments[3]) && this.GetLambda(mc.Arguments[3]).Parameters.Count == 1 &&
  1836. this.IsLambda(mc.Arguments[4]) && this.GetLambda(mc.Arguments[4]).Parameters.Count == 2) {
  1837. return this.VisitGroupJoin(mc.Arguments[0], mc.Arguments[1], this.GetLambda(mc.Arguments[2]), this.GetLambda(mc.Arguments[3]), this.GetLambda(mc.Arguments[4]));
  1838. }
  1839. break;
  1840. case "DefaultIfEmpty":
  1841. isSupportedSequenceOperator = true;
  1842. if (mc.Arguments.Count == 1) {
  1843. return this.VisitDefaultIfEmpty(mc.Arguments[0]);
  1844. }
  1845. break;
  1846. case "OfType":
  1847. isSupportedSequenceOperator = true;
  1848. if (mc.Arguments.Count == 1) {
  1849. Type ofType = mc.Method.GetGenericArguments()[0];
  1850. return this.VisitOfType(mc.Arguments[0], ofType);
  1851. }
  1852. break;
  1853. case "Cast":
  1854. isSupportedSequenceOperator = true;
  1855. if (mc.Arguments.Count == 1) {
  1856. Type type = mc.Method.GetGenericArguments()[0];
  1857. return this.VisitSequenceCast(mc.Arguments[0], type);
  1858. }
  1859. break;
  1860. case "Where":
  1861. isSupportedSequenceOperator = true;
  1862. if (mc.Arguments.Count == 2 &&
  1863. this.IsLambda(mc.Arguments[1]) && this.GetLambda(mc.Arguments[1]).Parameters.Count == 1) {
  1864. return this.VisitWhere(mc.Arguments[0], this.GetLambda(mc.Arguments[1]));
  1865. }
  1866. break;
  1867. case "First":
  1868. case "FirstOrDefault":
  1869. isSupportedSequenceOperator = true;
  1870. if (mc.Arguments.Count == 1) {
  1871. return this.VisitFirst(mc.Arguments[0], null, true);
  1872. }
  1873. else if (mc.Arguments.Count == 2 &&
  1874. this.IsLambda(mc.Arguments[1]) && this.GetLambda(mc.Arguments[1]).Parameters.Count == 1) {
  1875. return this.VisitFirst(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), true);
  1876. }
  1877. break;
  1878. case "Single":
  1879. case "SingleOrDefault":
  1880. isSupportedSequenceOperator = true;
  1881. if (mc.Arguments.Count == 1) {
  1882. return this.VisitFirst(mc.Arguments[0], null, false);
  1883. }
  1884. else if (mc.Arguments.Count == 2 &&
  1885. this.IsLambda(mc.Arguments[1]) && this.GetLambda(mc.Arguments[1]).Parameters.Count == 1) {
  1886. return this.VisitFirst(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), false);
  1887. }
  1888. break;
  1889. case "Distinct":
  1890. isSupportedSequenceOperator = true;
  1891. if (mc.Arguments.Count == 1) {
  1892. return this.VisitDistinct(mc.Arguments[0]);
  1893. }
  1894. break;
  1895. case "Concat":
  1896. isSupportedSequenceOperator = true;
  1897. if (mc.Arguments.Count == 2) {
  1898. return this.VisitConcat(mc.Arguments[0], mc.Arguments[1]);
  1899. }
  1900. break;
  1901. case "Union":
  1902. isSupportedSequenceOperator = true;
  1903. if (mc.Arguments.Count == 2) {
  1904. return this.VisitUnion(mc.Arguments[0], mc.Arguments[1]);
  1905. }
  1906. break;
  1907. case "Intersect":
  1908. isSupportedSequenceOperator = true;
  1909. if (mc.Arguments.Count == 2) {
  1910. return this.VisitIntersect(mc.Arguments[0], mc.Arguments[1]);
  1911. }
  1912. break;
  1913. case "Except":
  1914. isSupportedSequenceOperator = true;
  1915. if (mc.Arguments.Count == 2) {
  1916. return this.VisitExcept(mc.Arguments[0], mc.Arguments[1]);
  1917. }
  1918. break;
  1919. case "Any":
  1920. isSupportedSequenceOperator = true;
  1921. if (mc.Arguments.Count == 1) {
  1922. return this.VisitQuantifier(mc.Arguments[0], null, true);
  1923. }
  1924. else if (mc.Arguments.Count == 2 &&
  1925. this.IsLambda(mc.Arguments[1]) && this.GetLambda(mc.Arguments[1]).Parameters.Count == 1) {
  1926. return this.VisitQuantifier(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), true);
  1927. }
  1928. break;
  1929. case "All":
  1930. isSupportedSequenceOperator = true;
  1931. if (mc.Arguments.Count == 2 &&
  1932. this.IsLambda(mc.Arguments[1]) && this.GetLambda(mc.Arguments[1]).Parameters.Count == 1) {
  1933. return this.VisitQuantifier(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), false);
  1934. }
  1935. break;
  1936. case "Count":
  1937. isSupportedSequenceOperator = true;
  1938. if (mc.Arguments.Count == 1) {
  1939. return this.VisitAggregate(mc.Arguments[0], null, SqlNodeType.Count, mc.Type);
  1940. }
  1941. else if (mc.Arguments.Count == 2 &&
  1942. this.IsLambda(mc.Arguments[1]) && this.GetLambda(mc.Arguments[1]).Parameters.Count == 1) {
  1943. return this.VisitAggregate(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), SqlNodeType.Count, mc.Type);
  1944. }
  1945. break;
  1946. case "LongCount":
  1947. isSupportedSequenceOperator = true;
  1948. if (mc.Arguments.Count == 1) {
  1949. return this.VisitAggregate(mc.Arguments[0], null, SqlNodeType.LongCount, mc.Type);
  1950. }
  1951. else if (mc.Arguments.Count == 2 &&
  1952. this.IsLambda(mc.Arguments[1]) && this.GetLambda(mc.Arguments[1]).Parameters.Count == 1) {
  1953. return this.VisitAggregate(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), SqlNodeType.LongCount, mc.Type);
  1954. }
  1955. break;
  1956. case "Sum":
  1957. isSupportedSequenceOperator = true;
  1958. if (mc.Arguments.Count == 1) {
  1959. return this.VisitAggregate(mc.Arguments[0], null, SqlNodeType.Sum, mc.Type);
  1960. }
  1961. else if (mc.Arguments.Count == 2 &&
  1962. this.IsLambda(mc.Arguments[1]) && this.GetLambda(mc.Arguments[1]).Parameters.Count == 1) {
  1963. return this.VisitAggregate(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), SqlNodeType.Sum, mc.Type);
  1964. }
  1965. break;
  1966. case "Min":
  1967. isSupportedSequenceOperator = true;
  1968. if (mc.Arguments.Count == 1) {
  1969. return this.VisitAggregate(mc.Arguments[0], null, SqlNodeType.Min, mc.Type);
  1970. }
  1971. else if (mc.Arguments.Count == 2 &&
  1972. this.IsLambda(mc.Arguments[1]) && this.GetLambda(mc.Arguments[1]).Parameters.Count == 1) {
  1973. return this.VisitAggregate(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), SqlNodeType.Min, mc.Type);
  1974. }
  1975. break;
  1976. case "Max":
  1977. isSupportedSequenceOperator = true;
  1978. if (mc.Arguments.Count == 1) {
  1979. return this.VisitAggregate(mc.Arguments[0], null, SqlNodeType.Max, mc.Type);
  1980. }
  1981. else if (mc.Arguments.Count == 2 &&
  1982. this.IsLambda(mc.Arguments[1]) && this.GetLambda(mc.Arguments[1]).Parameters.Count == 1) {
  1983. return this.VisitAggregate(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), SqlNodeType.Max, mc.Type);
  1984. }
  1985. break;
  1986. case "Average":
  1987. isSupportedSequenceOperator = true;
  1988. if (mc.Arguments.Count == 1) {
  1989. return this.VisitAggregate(mc.Arguments[0], null, SqlNodeType.Avg, mc.Type);
  1990. }
  1991. else if (mc.Arguments.Count == 2 &&
  1992. this.IsLambda(mc.Arguments[1]) && this.GetLambda(mc.Arguments[1]).Parameters.Count == 1) {
  1993. return this.VisitAggregate(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), SqlNodeType.Avg, mc.Type);
  1994. }
  1995. break;
  1996. case "GroupBy":
  1997. isSupportedSequenceOperator = true;
  1998. if (mc.Arguments.Count == 2 &&
  1999. this.IsLambda(mc.Arguments[1]) && this.GetLambda(mc.Arguments[1]).Parameters.Count == 1) {
  2000. return this.VisitGroupBy(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), null, null);
  2001. }
  2002. else if (mc.Arguments.Count == 3 &&
  2003. this.IsLambda(mc.Arguments[1]) && this.GetLambda(mc.Arguments[1]).Parameters.Count == 1 &&
  2004. this.IsLambda(mc.Arguments[2]) && this.GetLambda(mc.Arguments[2]).Parameters.Count == 1) {
  2005. return this.VisitGroupBy(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), this.GetLambda(mc.Arguments[2]), null);
  2006. }
  2007. else if (mc.Arguments.Count == 3 &&
  2008. this.IsLambda(mc.Arguments[1]) && this.GetLambda(mc.Arguments[1]).Parameters.Count == 1 &&
  2009. this.IsLambda(mc.Arguments[2]) && this.GetLambda(mc.Arguments[2]).Parameters.Count == 2) {
  2010. return this.VisitGroupBy(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), null, this.GetLambda(mc.Arguments[2]));
  2011. }
  2012. else if (mc.Arguments.Count == 4 &&
  2013. this.IsLambda(mc.Arguments[1]) && this.GetLambda(mc.Arguments[1]).Parameters.Count == 1 &&
  2014. this.IsLambda(mc.Arguments[2]) && this.GetLambda(mc.Arguments[2]).Parameters.Count == 1 &&
  2015. this.IsLambda(mc.Arguments[3]) && this.GetLambda(mc.Arguments[3]).Parameters.Count == 2) {
  2016. return this.VisitGroupBy(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), this.GetLambda(mc.Arguments[2]), this.GetLambda(mc.Arguments[3]));
  2017. }
  2018. break;
  2019. case "OrderBy":
  2020. isSupportedSequenceOperator = true;
  2021. if (mc.Arguments.Count == 2 &&
  2022. this.IsLambda(mc.Arguments[1]) && this.GetLambda(mc.Arguments[1]).Parameters.Count == 1) {
  2023. return this.VisitOrderBy(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), SqlOrderType.Ascending);
  2024. }
  2025. break;
  2026. case "OrderByDescending":
  2027. isSupportedSequenceOperator = true;
  2028. if (mc.Arguments.Count == 2 &&
  2029. this.IsLambda(mc.Arguments[1]) && this.GetLambda(mc.Arguments[1]).Parameters.Count == 1) {
  2030. return this.VisitOrderBy(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), SqlOrderType.Descending);
  2031. }
  2032. break;
  2033. case "ThenBy":
  2034. isSupportedSequenceOperator = true;
  2035. if (mc.Arguments.Count == 2 &&
  2036. this.IsLambda(mc.Arguments[1]) && this.GetLambda(mc.Arguments[1]).Parameters.Count == 1) {
  2037. return this.VisitThenBy(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), SqlOrderType.Ascending);
  2038. }
  2039. break;
  2040. case "ThenByDescending":
  2041. isSupportedSequenceOperator = true;
  2042. if (mc.Arguments.Count == 2 &&
  2043. this.IsLambda(mc.Arguments[1]) && this.GetLambda(mc.Arguments[1]).Parameters.Count == 1) {
  2044. return this.VisitThenBy(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), SqlOrderType.Descending);
  2045. }
  2046. break;
  2047. case "Take":
  2048. isSupportedSequenceOperator = true;
  2049. if (mc.Arguments.Count == 2) {
  2050. return this.VisitTake(mc.Arguments[0], mc.Arguments[1]);
  2051. }
  2052. break;
  2053. case "Skip":
  2054. isSupportedSequenceOperator = true;
  2055. if (mc.Arguments.Count == 2) {
  2056. return this.VisitSkip(mc.Arguments[0], mc.Arguments[1]);
  2057. }
  2058. break;
  2059. case "Contains":
  2060. isSupportedSequenceOperator = true;
  2061. if (mc.Arguments.Count == 2) {
  2062. return this.VisitContains(mc.Arguments[0], mc.Arguments[1]);
  2063. }
  2064. break;
  2065. case "ToList":
  2066. case "AsEnumerable":
  2067. case "ToArray":
  2068. isSupportedSequenceOperator = true;
  2069. if (mc.Arguments.Count == 1) {
  2070. return this.Visit(mc.Arguments[0]);
  2071. }
  2072. break;
  2073. }
  2074. // If the operator is supported, but the particular overload is not,
  2075. // give an appropriate error message
  2076. if (isSupportedSequenceOperator) {
  2077. throw Error.QueryOperatorOverloadNotSupported(mc.Method.Name);
  2078. }
  2079. throw Error.QueryOperatorNotSupported(mc.Method.Name);
  2080. }
  2081. else {
  2082. throw Error.InvalidSequenceOperatorCall(declType);
  2083. }
  2084. }
  2085. private static bool IsDataManipulationCall(MethodCallExpression mc) {
  2086. return mc.Method.IsStatic && mc.Method.DeclaringType == typeof(DataManipulation);
  2087. }
  2088. private SqlNode VisitDataManipulationCall(MethodCallExpression mc) {
  2089. if (IsDataManipulationCall(mc)) {
  2090. bool isSupportedDML = false;
  2091. switch (mc.Method.Name) {
  2092. case "Insert":
  2093. isSupportedDML = true;
  2094. if (mc.Arguments.Count == 2) {
  2095. return this.VisitInsert(mc.Arguments[0], this.GetLambda(mc.Arguments[1]));
  2096. }
  2097. else if (mc.Arguments.Count == 1) {
  2098. return this.VisitInsert(mc.Arguments[0], null);
  2099. }
  2100. break;
  2101. case "Update":
  2102. isSupportedDML = true;
  2103. if (mc.Arguments.Count == 3) {
  2104. return this.VisitUpdate(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), this.GetLambda(mc.Arguments[2]));
  2105. }
  2106. else if (mc.Arguments.Count == 2) {
  2107. if (mc.Method.GetGenericArguments().Length == 1) {
  2108. return this.VisitUpdate(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), null);
  2109. }
  2110. else {
  2111. return this.VisitUpdate(mc.Arguments[0], null, this.GetLambda(mc.Arguments[1]));
  2112. }
  2113. }
  2114. else if (mc.Arguments.Count == 1) {
  2115. return this.VisitUpdate(mc.Arguments[0], null, null);
  2116. }
  2117. break;
  2118. case "Delete":
  2119. isSupportedDML = true;
  2120. if (mc.Arguments.Count == 2) {
  2121. return this.VisitDelete(mc.Arguments[0], this.GetLambda(mc.Arguments[1]));
  2122. }
  2123. else if (mc.Arguments.Count == 1) {
  2124. return this.VisitDelete(mc.Arguments[0], null);
  2125. }
  2126. break;
  2127. }
  2128. if (isSupportedDML) {
  2129. throw Error.QueryOperatorOverloadNotSupported(mc.Method.Name);
  2130. }
  2131. throw Error.QueryOperatorNotSupported(mc.Method.Name);
  2132. }
  2133. throw Error.InvalidSequenceOperatorCall(mc.Method.Name);
  2134. }
  2135. private SqlNode VisitFirst(Expression sequence, LambdaExpression lambda, bool isFirst) {
  2136. SqlSelect select = this.LockSelect(this.VisitSequence(sequence));
  2137. if (lambda != null) {
  2138. this.map[lambda.Parameters[0]] = (SqlAliasRef)select.Selection;
  2139. select.Where = this.VisitExpression(lambda.Body);
  2140. }
  2141. if (isFirst) {
  2142. select.Top = this.sql.ValueFromObject(1, false, this.dominatingExpression);
  2143. }
  2144. if (this.outerNode) {
  2145. return select;
  2146. }
  2147. SqlNodeType subType = (this.typeProvider.From(select.Selection.ClrType).CanBeColumn) ? SqlNodeType.ScalarSubSelect : SqlNodeType.Element;
  2148. SqlSubSelect elem = sql.SubSelect(subType, select, sequence.Type);
  2149. return elem;
  2150. }
  2151. [SuppressMessage("Microsoft.Maintainability", "CA1506:AvoidExcessiveClassCoupling", Justification="These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
  2152. private SqlStatement VisitInsert(Expression item, LambdaExpression resultSelector) {
  2153. if (item == null) {
  2154. throw Error.ArgumentNull("item");
  2155. }
  2156. this.dominatingExpression = item;
  2157. MetaTable metaTable = this.services.Model.GetTable(item.Type);
  2158. Expression source = this.services.Context.GetTable(metaTable.RowType.Type).Expression;
  2159. MetaType itemMetaType = null;
  2160. SqlNew sqlItem = null;
  2161. // construct insert assignments from 'item' info
  2162. ConstantExpression conItem = item as ConstantExpression;
  2163. if (conItem == null) {
  2164. throw Error.InsertItemMustBeConstant();
  2165. }
  2166. if (conItem.Value == null) {
  2167. throw Error.ArgumentNull("item");
  2168. }
  2169. // construct insert based on constant value
  2170. List<SqlMemberAssign> bindings = new List<SqlMemberAssign>();
  2171. itemMetaType = metaTable.RowType.GetInheritanceType(conItem.Value.GetType());
  2172. SqlExpression sqlExprItem = sql.ValueFromObject(conItem.Value, true, source);
  2173. foreach (MetaDataMember mm in itemMetaType.PersistentDataMembers) {
  2174. if (!mm.IsAssociation && !mm.IsDbGenerated && !mm.IsVersion) {
  2175. bindings.Add(new SqlMemberAssign(mm.Member, sql.Member(sqlExprItem, mm.Member)));
  2176. }
  2177. }
  2178. ConstructorInfo cons = itemMetaType.Type.GetConstructor(Type.EmptyTypes);
  2179. System.Diagnostics.Debug.Assert(cons != null);
  2180. sqlItem = sql.New(itemMetaType, cons, null, null, bindings, item);
  2181. SqlTable tab = sql.Table(metaTable, metaTable.RowType, this.dominatingExpression);
  2182. SqlInsert sin = new SqlInsert(tab, sqlItem, item);
  2183. if (resultSelector == null) {
  2184. return sin;
  2185. }
  2186. else {
  2187. MetaDataMember id = itemMetaType.DBGeneratedIdentityMember;
  2188. bool isDbGenOnly = false;
  2189. if (id != null) {
  2190. isDbGenOnly = this.IsDbGeneratedKeyProjectionOnly(resultSelector.Body, id);
  2191. if (id.Type == typeof(Guid) && (this.converterStrategy & ConverterStrategy.CanOutputFromInsert) != 0) {
  2192. sin.OutputKey = new SqlColumn(id.Type, sql.Default(id), id.Name, id, null, this.dominatingExpression);
  2193. if (!isDbGenOnly) {
  2194. sin.OutputToLocal = true;
  2195. }
  2196. }
  2197. }
  2198. SqlSelect result = null;
  2199. SqlSelect preResult = null;
  2200. SqlAlias tableAlias = new SqlAlias(tab);
  2201. SqlAliasRef tableAliasRef = new SqlAliasRef(tableAlias);
  2202. System.Diagnostics.Debug.Assert(resultSelector.Parameters.Count == 1);
  2203. this.map.Add(resultSelector.Parameters[0], tableAliasRef);
  2204. SqlExpression projection = this.VisitExpression(resultSelector.Body);
  2205. // build select to return result
  2206. SqlExpression pred = null;
  2207. if (id != null) {
  2208. pred = sql.Binary(
  2209. SqlNodeType.EQ,
  2210. sql.Member(tableAliasRef, id.Member),
  2211. this.GetIdentityExpression(id, sin.OutputKey != null)
  2212. );
  2213. }
  2214. else {
  2215. SqlExpression itemExpression = this.VisitExpression(item);
  2216. pred = sql.Binary(SqlNodeType.EQ2V, tableAliasRef, itemExpression);
  2217. }
  2218. result = new SqlSelect(projection, tableAlias, resultSelector);
  2219. result.Where = pred;
  2220. // Since we're only projecting back a single generated key, we can
  2221. // optimize the query to a simple selection (e.g. SELECT @@IDENTITY)
  2222. // rather than selecting back from the table.
  2223. if (id != null && isDbGenOnly) {
  2224. if (sin.OutputKey == null) {
  2225. SqlExpression exp = this.GetIdentityExpression(id, false);
  2226. if (exp.ClrType != id.Type) {
  2227. ProviderType sqlType = sql.Default(id);
  2228. exp = sql.ConvertTo(id.Type, sqlType, exp);
  2229. }
  2230. // The result selector passed in was bound to the table -
  2231. // we need to rebind to the single result as an array projection
  2232. ParameterExpression p = Expression.Parameter(id.Type, "p");
  2233. Expression[] init = new Expression[1] { Expression.Convert(p, typeof(object)) };
  2234. NewArrayExpression arrExp = Expression.NewArrayInit(typeof(object), init);
  2235. LambdaExpression rs = Expression.Lambda(arrExp, p);
  2236. this.map.Add(p, exp);
  2237. SqlExpression proj = this.VisitExpression(rs.Body);
  2238. preResult = new SqlSelect(proj, null, rs);
  2239. }
  2240. else {
  2241. // case handled in formatter automatically
  2242. }
  2243. result.DoNotOutput = true;
  2244. }
  2245. // combine insert & result into block
  2246. SqlBlock block = new SqlBlock(this.dominatingExpression);
  2247. block.Statements.Add(sin);
  2248. if (preResult != null) {
  2249. block.Statements.Add(preResult);
  2250. }
  2251. block.Statements.Add(result);
  2252. return block;
  2253. }
  2254. }
  2255. private bool IsDbGeneratedKeyProjectionOnly(Expression projection, MetaDataMember keyMember) {
  2256. NewArrayExpression array = projection as NewArrayExpression;
  2257. if (array != null && array.Expressions.Count == 1) {
  2258. Expression exp = array.Expressions[0];
  2259. while (exp.NodeType == ExpressionType.Convert || exp.NodeType == ExpressionType.ConvertChecked) {
  2260. exp = ((UnaryExpression)exp).Operand;
  2261. }
  2262. MemberExpression mex = exp as MemberExpression;
  2263. if (mex != null && mex.Member == keyMember.Member) {
  2264. return true;
  2265. }
  2266. }
  2267. return false;
  2268. }
  2269. private SqlExpression GetIdentityExpression(MetaDataMember id, bool isOutputFromInsert) {
  2270. if (isOutputFromInsert) {
  2271. return new SqlVariable(id.Type, sql.Default(id), "@id", this.dominatingExpression);
  2272. }
  2273. else {
  2274. ProviderType sqlType = sql.Default(id);
  2275. if (!IsLegalIdentityType(sqlType.GetClosestRuntimeType())) {
  2276. throw Error.InvalidDbGeneratedType(sqlType.ToQueryString());
  2277. }
  2278. if ((this.converterStrategy & ConverterStrategy.CanUseScopeIdentity) != 0) {
  2279. return new SqlVariable(typeof(decimal), typeProvider.From(typeof(decimal)), "SCOPE_IDENTITY()", this.dominatingExpression);
  2280. }
  2281. else {
  2282. return new SqlVariable(typeof(decimal), typeProvider.From(typeof(decimal)), "@@IDENTITY", this.dominatingExpression);
  2283. }
  2284. }
  2285. }
  2286. private static bool IsLegalIdentityType(Type type) {
  2287. switch (Type.GetTypeCode(type)) {
  2288. case TypeCode.SByte:
  2289. case TypeCode.Int16:
  2290. case TypeCode.Int32:
  2291. case TypeCode.Int64:
  2292. case TypeCode.Decimal:
  2293. return true;
  2294. }
  2295. return false;
  2296. }
  2297. private SqlExpression GetRowCountExpression() {
  2298. if ((this.converterStrategy & ConverterStrategy.CanUseRowStatus) != 0) {
  2299. return new SqlVariable(typeof(decimal), typeProvider.From(typeof(decimal)), "@@ROWCOUNT", this.dominatingExpression);
  2300. }
  2301. else {
  2302. return new SqlVariable(typeof(decimal), typeProvider.From(typeof(decimal)), "@ROWCOUNT", this.dominatingExpression);
  2303. }
  2304. }
  2305. private SqlStatement VisitUpdate(Expression item, LambdaExpression check, LambdaExpression resultSelector) {
  2306. if (item == null) {
  2307. throw Error.ArgumentNull("item");
  2308. }
  2309. MetaTable metaTable = this.services.Model.GetTable(item.Type);
  2310. Expression source = this.services.Context.GetTable(metaTable.RowType.Type).Expression;
  2311. Type rowType = metaTable.RowType.Type;
  2312. bool saveAllowDeferred = this.allowDeferred;
  2313. this.allowDeferred = false;
  2314. try {
  2315. Expression seq = source;
  2316. // construct identity predicate based on supplied item
  2317. ParameterExpression p = Expression.Parameter(rowType, "p");
  2318. LambdaExpression idPredicate = Expression.Lambda(Expression.Equal(p, item), p);
  2319. // combine predicate and check expression into single find predicate
  2320. LambdaExpression findPredicate = idPredicate;
  2321. if (check != null) {
  2322. findPredicate = Expression.Lambda(Expression.And(Expression.Invoke(findPredicate, p), Expression.Invoke(check, p)), p);
  2323. }
  2324. seq = Expression.Call(typeof(Enumerable), "Where", new Type[] { rowType }, seq, findPredicate);
  2325. // source 'query' is based on table + find predicate
  2326. SqlSelect ss = new RetypeCheckClause().VisitSelect(this.VisitSequence(seq));
  2327. // construct update assignments from 'item' info
  2328. List<SqlAssign> assignments = new List<SqlAssign>();
  2329. ConstantExpression conItem = item as ConstantExpression;
  2330. if (conItem == null) {
  2331. throw Error.UpdateItemMustBeConstant();
  2332. }
  2333. if (conItem.Value == null) {
  2334. throw Error.ArgumentNull("item");
  2335. }
  2336. // get changes from data services to construct update command
  2337. Type entityType = conItem.Value.GetType();
  2338. MetaType metaType = this.services.Model.GetMetaType(entityType);
  2339. ITable table = this.services.Context.GetTable(metaType.InheritanceRoot.Type);
  2340. foreach (ModifiedMemberInfo mmi in table.GetModifiedMembers(conItem.Value)) {
  2341. MetaDataMember mdm = metaType.GetDataMember(mmi.Member);
  2342. assignments.Add(
  2343. new SqlAssign(
  2344. sql.Member(ss.Selection, mmi.Member),
  2345. new SqlValue(mdm.Type, this.typeProvider.From(mdm.Type), mmi.CurrentValue, true, source),
  2346. source
  2347. ));
  2348. }
  2349. SqlUpdate upd = new SqlUpdate(ss, assignments, source);
  2350. if (resultSelector == null) {
  2351. return upd;
  2352. }
  2353. SqlSelect select = null;
  2354. // build select to return result
  2355. seq = source;
  2356. seq = Expression.Call(typeof(Enumerable), "Where", new Type[] { rowType }, seq, idPredicate);
  2357. seq = Expression.Call(typeof(Enumerable), "Select", new Type[] { rowType, resultSelector.Body.Type }, seq, resultSelector);
  2358. select = this.VisitSequence(seq);
  2359. select.Where = sql.AndAccumulate(
  2360. sql.Binary(SqlNodeType.GT, this.GetRowCountExpression(), sql.ValueFromObject(0, false, this.dominatingExpression)),
  2361. select.Where
  2362. );
  2363. // combine update & select into statement block
  2364. SqlBlock block = new SqlBlock(source);
  2365. block.Statements.Add(upd);
  2366. block.Statements.Add(select);
  2367. return block;
  2368. }
  2369. finally {
  2370. this.allowDeferred = saveAllowDeferred;
  2371. }
  2372. }
  2373. private SqlStatement VisitDelete(Expression item, LambdaExpression check) {
  2374. if (item == null) {
  2375. throw Error.ArgumentNull("item");
  2376. }
  2377. bool saveAllowDeferred = this.allowDeferred;
  2378. this.allowDeferred = false;
  2379. try {
  2380. MetaTable metaTable = this.services.Model.GetTable(item.Type);
  2381. Expression source = this.services.Context.GetTable(metaTable.RowType.Type).Expression;
  2382. Type rowType = metaTable.RowType.Type;
  2383. // construct identity predicate based on supplied item
  2384. ParameterExpression p = Expression.Parameter(rowType, "p");
  2385. LambdaExpression idPredicate = Expression.Lambda(Expression.Equal(p, item), p);
  2386. // combine predicate and check expression into single find predicate
  2387. LambdaExpression findPredicate = idPredicate;
  2388. if (check != null) {
  2389. findPredicate = Expression.Lambda(Expression.And(Expression.Invoke(findPredicate, p), Expression.Invoke(check, p)), p);
  2390. }
  2391. Expression seq = Expression.Call(typeof(Enumerable), "Where", new Type[] { rowType }, source, findPredicate);
  2392. SqlSelect ss = new RetypeCheckClause().VisitSelect(this.VisitSequence(seq));
  2393. this.allowDeferred = saveAllowDeferred;
  2394. SqlDelete sd = new SqlDelete(ss, source);
  2395. return sd;
  2396. }
  2397. finally {
  2398. this.allowDeferred = saveAllowDeferred;
  2399. }
  2400. }
  2401. private class RetypeCheckClause : SqlVisitor {
  2402. internal override SqlExpression VisitMethodCall(SqlMethodCall mc) {
  2403. if (mc.Arguments.Count==2 && mc.Method.Name=="op_Equality") {
  2404. var r = mc.Arguments[1];
  2405. if (r.NodeType == SqlNodeType.Value) {
  2406. var v = (SqlValue)r;
  2407. v.SetSqlType(mc.Arguments[0].SqlType);
  2408. }
  2409. }
  2410. return base.VisitMethodCall(mc);
  2411. }
  2412. }
  2413. private SqlExpression VisitNewArrayInit(NewArrayExpression arr) {
  2414. SqlExpression[] exprs = new SqlExpression[arr.Expressions.Count];
  2415. for (int i = 0, n = exprs.Length; i < n; i++) {
  2416. exprs[i] = this.VisitExpression(arr.Expressions[i]);
  2417. }
  2418. return new SqlClientArray(arr.Type, this.typeProvider.From(arr.Type), exprs, this.dominatingExpression);
  2419. }
  2420. private SqlExpression VisitListInit(ListInitExpression list) {
  2421. if (null != list.NewExpression.Constructor && 0 != list.NewExpression.Arguments.Count) {
  2422. // Throw existing exception for unrecognized expressions if list
  2423. // init does not use a default constructor.
  2424. throw Error.UnrecognizedExpressionNode(list.NodeType);
  2425. }
  2426. SqlExpression[] exprs = new SqlExpression[list.Initializers.Count];
  2427. for (int i = 0, n = exprs.Length; i < n; i++) {
  2428. if (1 != list.Initializers[i].Arguments.Count) {
  2429. // Throw existing exception for unrecognized expressions if element
  2430. // init is not adding a single element.
  2431. throw Error.UnrecognizedExpressionNode(list.NodeType);
  2432. }
  2433. exprs[i] = this.VisitExpression(list.Initializers[i].Arguments.Single());
  2434. }
  2435. return new SqlClientArray(list.Type, this.typeProvider.From(list.Type), exprs, this.dominatingExpression);
  2436. }
  2437. }
  2438. class SingleTableQueryVisitor : SqlVisitor {
  2439. public bool IsValid;
  2440. bool IsDistinct;
  2441. List<MemberInfo> IdentityMembers;
  2442. void AddIdentityMembers(IEnumerable<MemberInfo> members) {
  2443. System.Diagnostics.Debug.Assert(this.IdentityMembers == null, "We already have a set of keys -- why are we adding more?");
  2444. this.IdentityMembers = new List<MemberInfo>(members);
  2445. }
  2446. internal SingleTableQueryVisitor(): base() {
  2447. this.IsValid = true;
  2448. }
  2449. internal override SqlNode Visit(SqlNode node) {
  2450. // recurse until we know we're invalid
  2451. if (this.IsValid && node != null) {
  2452. return base.Visit(node);
  2453. }
  2454. return node;
  2455. }
  2456. internal override SqlTable VisitTable(SqlTable tab) {
  2457. // if we're distinct, we don't care about joins
  2458. if (this.IsDistinct) {
  2459. return tab;
  2460. }
  2461. if (this.IdentityMembers != null) {
  2462. this.IsValid = false;
  2463. } else {
  2464. this.AddIdentityMembers(tab.MetaTable.RowType.IdentityMembers.Select(m => m.Member));
  2465. }
  2466. return tab;
  2467. }
  2468. internal override SqlSource VisitSource(SqlSource source) {
  2469. return base.VisitSource(source);
  2470. }
  2471. internal override SqlSelect VisitSelect(SqlSelect select) {
  2472. if (select.IsDistinct) {
  2473. this.IsDistinct = true;
  2474. // get all members from selection
  2475. this.AddIdentityMembers(select.Selection.ClrType.GetProperties());
  2476. return select;
  2477. }
  2478. //
  2479. //
  2480. //
  2481. //
  2482. //
  2483. // We're not distinct, but let's check our sources...
  2484. select.From = (SqlSource)base.Visit(select.From);
  2485. if (this.IdentityMembers == null || this.IdentityMembers.Count == 0) {
  2486. throw Error.SkipRequiresSingleTableQueryWithPKs();
  2487. }
  2488. else {
  2489. switch (select.Selection.NodeType) {
  2490. case SqlNodeType.Column:
  2491. case SqlNodeType.ColumnRef:
  2492. case SqlNodeType.Member: {
  2493. // we've got a bare member/column node, eg "select c.CustomerId"
  2494. // find out if it refers to the table's PK, of which there must be only 1
  2495. if (this.IdentityMembers.Count == 1) {
  2496. MemberInfo column = this.IdentityMembers[0];
  2497. this.IsValid &= IsColumnMatch(column, select.Selection);
  2498. }
  2499. else {
  2500. this.IsValid = false;
  2501. }
  2502. break;
  2503. }
  2504. case SqlNodeType.New:
  2505. case SqlNodeType.AliasRef: {
  2506. select.Selection = this.VisitExpression(select.Selection);
  2507. break;
  2508. }
  2509. case SqlNodeType.Treat:
  2510. case SqlNodeType.TypeCase: {
  2511. break;
  2512. }
  2513. default: {
  2514. this.IsValid = false;
  2515. break;
  2516. }
  2517. }
  2518. }
  2519. return select;
  2520. }
  2521. //
  2522. //
  2523. //
  2524. //
  2525. //
  2526. internal override SqlExpression VisitNew(SqlNew sox) {
  2527. // check the args for the PKs
  2528. foreach (MemberInfo column in this.IdentityMembers) {
  2529. // assume we're invalid unless we find a matching argument which is
  2530. // a bare column/columnRef to the PK
  2531. bool isMatch = false;
  2532. // find a matching arg
  2533. foreach (SqlExpression expr in sox.Args) {
  2534. isMatch = IsColumnMatch(column, expr);
  2535. if (isMatch) {
  2536. break;
  2537. }
  2538. }
  2539. if (!isMatch) {
  2540. foreach (SqlMemberAssign ma in sox.Members) {
  2541. SqlExpression expr = ma.Expression;
  2542. isMatch = IsColumnMatch(column, expr);
  2543. if (isMatch) {
  2544. break;
  2545. }
  2546. }
  2547. }
  2548. this.IsValid &= isMatch;
  2549. if (!this.IsValid) {
  2550. break;
  2551. }
  2552. }
  2553. return sox;
  2554. }
  2555. internal override SqlNode VisitUnion(SqlUnion su) {
  2556. // we don't want to descend inward
  2557. // just check that it's not a UNION ALL
  2558. if (su.All) {
  2559. this.IsValid = false;
  2560. }
  2561. // UNIONs are distinct
  2562. this.IsDistinct = true;
  2563. // get all members from selection
  2564. this.AddIdentityMembers(su.GetClrType().GetProperties());
  2565. return su;
  2566. }
  2567. private static bool IsColumnMatch(MemberInfo column, SqlExpression expr) {
  2568. MemberInfo memberInfo = null;
  2569. switch (expr.NodeType) {
  2570. case SqlNodeType.Column: {
  2571. memberInfo = ((SqlColumn)expr).MetaMember.Member;
  2572. break;
  2573. }
  2574. case SqlNodeType.ColumnRef: {
  2575. memberInfo = (((SqlColumnRef)expr).Column).MetaMember.Member;
  2576. break;
  2577. }
  2578. case SqlNodeType.Member: {
  2579. memberInfo = ((SqlMember)expr).Member;
  2580. break;
  2581. }
  2582. }
  2583. return (memberInfo != null && memberInfo == column);
  2584. }
  2585. }
  2586. }