/mcs/class/referencesource/System.Data.Linq/SqlClient/Query/QueryConverter.cs
C# | 2887 lines | 2342 code | 301 blank | 244 comment | 771 complexity | 74b75f6e2c6ad9ddcfb9453582b373da MD5 | raw file
Possible License(s): GPL-2.0, CC-BY-SA-3.0, LGPL-2.0, MPL-2.0-no-copyleft-exception, LGPL-2.1, Unlicense, Apache-2.0
- using System;
- using System.Globalization;
- using System.Collections;
- using System.Collections.Generic;
- using System.Data;
- using System.Reflection;
- using System.Text;
- using System.Linq;
- using System.Linq.Expressions;
- using System.Data.Linq;
- using System.Data.Linq.Mapping;
- using System.Data.Linq.Provider;
- using System.Collections.ObjectModel;
- using System.Diagnostics.CodeAnalysis;
- namespace System.Data.Linq.SqlClient {
/// <summary>
/// Placeholder types standing in for intermediate shapes (a row, a table) while a query
/// tree is still being assembled.
/// </summary>
enum ConverterSpecialTypes {
    Row = 0,
    Table = 1
}
/// <summary>
/// Bit flags describing which SQL constructs the converter is allowed to emit.
/// Values are combined with bitwise OR; <see cref="Default"/> enables none of them.
/// </summary>
[Flags]
internal enum ConverterStrategy {
    Default = 0,
    SkipWithRowNumber = 1 << 0,
    CanUseScopeIdentity = 1 << 1,
    CanUseOuterApply = 1 << 2,
    CanUseRowStatus = 1 << 3,
    // When set, joins are emitted as JOIN ... ON; otherwise as a CROSS JOIN filtered by WHERE.
    CanUseJoinOn = 1 << 4,
    CanOutputFromInsert = 1 << 5
}
- [SuppressMessage("Microsoft.Maintainability", "CA1506:AvoidExcessiveClassCoupling", Justification="Unknown reason.")]
- internal class QueryConverter {
// Collaborators supplied to (and validated by) the constructor.
private IDataServices services;
private Translator translator;
private SqlFactory sql;
private TypeSystemProvider typeProvider;
// True while converting the root of the tree (set by ConvertOuter); Visit/ConvertInner clear it
// temporarily for nested nodes.
private bool outerNode;
// Lambda parameter -> SQL expression bound to it (e.g. an alias reference to the source sequence).
private Dictionary<ParameterExpression, SqlExpression> map;
// Lambda parameter -> argument expression substituted for it when an invoked lambda is inlined.
private Dictionary<ParameterExpression, Expression> exprMap;
// Lambda parameter -> SQL node bound to it (client parameters, group-join multisets).
private Dictionary<ParameterExpression, SqlNode> dupMap;
// SqlNode -> GroupInfo; NOTE(review): not populated in the visible part of the file — confirm usage.
private Dictionary<SqlNode, GroupInfo> gmap;
// Expression used as the source context of created nodes and for exception text.
private Expression dominatingExpression;
// Forwarded to Translator.BuildDefaultQuery when a table is translated.
private bool allowDeferred;
private ConverterStrategy converterStrategy = ConverterStrategy.Default;
/// <summary>
/// Value type of the <c>gmap</c> table: pairs a select with an element expression.
/// Presumably <see cref="SelectWithGroup"/> is the grouping select and
/// <see cref="ElementOnGroupSource"/> yields group elements from its source — confirm against usage.
/// </summary>
class GroupInfo {
    internal SqlSelect SelectWithGroup;
    internal SqlExpression ElementOnGroupSource;
}
/// <summary>
/// Gets or sets the flags controlling which SQL constructs this converter may generate.
/// </summary>
internal ConverterStrategy ConverterStrategy {
    get {
        return this.converterStrategy;
    }
    set {
        this.converterStrategy = value;
    }
}
/// <summary>
/// Returns true when every flag set in <paramref name="strategy"/> is enabled on this converter
/// (trivially true for <see cref="ConverterStrategy.Default"/>).
/// </summary>
private bool UseConverterStrategy(ConverterStrategy strategy) {
    ConverterStrategy enabled = this.converterStrategy & strategy;
    return enabled == strategy;
}
/// <summary>
/// Creates a converter bound to the given data services, type provider, translator and SQL node factory.
/// </summary>
/// <remarks>
/// Every argument is required; the first null one (checked in the order services, sql,
/// translator, typeProvider) is reported through <c>Error.ArgumentNull</c>.
/// Deferred loading starts enabled and the lookup tables start empty.
/// </remarks>
internal QueryConverter(IDataServices services, TypeSystemProvider typeProvider, Translator translator, SqlFactory sql) {
    if (services == null) { throw Error.ArgumentNull("services"); }
    if (sql == null) { throw Error.ArgumentNull("sql"); }
    if (translator == null) { throw Error.ArgumentNull("translator"); }
    if (typeProvider == null) { throw Error.ArgumentNull("typeProvider"); }

    this.services = services;
    this.typeProvider = typeProvider;
    this.translator = translator;
    this.sql = sql;

    // Parameter-binding tables used while visiting lambdas.
    this.map = new Dictionary<ParameterExpression, SqlExpression>();
    this.exprMap = new Dictionary<ParameterExpression, Expression>();
    this.dupMap = new Dictionary<ParameterExpression, SqlNode>();
    this.gmap = new Dictionary<SqlNode, GroupInfo>();

    this.allowDeferred = true;
}
/// <summary>
/// Converts the root expression of a query into a SQL tree.
/// </summary>
/// <param name="node">The root expression to convert; must not be null.</param>
/// <returns>
/// The converted tree wrapped in a <see cref="SqlIncludeScope"/>. When conversion yields a bare
/// <see cref="SqlExpression"/>, it is first wrapped in a <see cref="SqlSelect"/>.
/// </returns>
/// <remarks>
/// Sets <c>outerNode</c> to true and leaves it set; nested visits clear it temporarily.
/// Throws when <paramref name="node"/> is null, and when the whole tree converts to a single
/// method call (which must be a mapped stored procedure or function).
/// </remarks>
internal SqlNode ConvertOuter(Expression node) {
    // Validate up front (consistent with the constructor) instead of failing with a
    // NullReferenceException on node.Type below.
    if (node == null) {
        throw Error.ArgumentNull("node");
    }
    this.dominatingExpression = node;
    this.outerNode = true;
    SqlNode retNode;
    if (typeof(ITable).IsAssignableFrom(node.Type)) {
        retNode = this.VisitSequence(node);
    }
    else {
        retNode = this.VisitInner(node);
    }
    if (retNode.NodeType == SqlNodeType.MethodCall) {
        // if a tree consists of a single method call expression only, that method
        // must be either a mapped stored procedure or a mapped function
        throw Error.InvalidMethodExecution(((SqlMethodCall)retNode).Method.Name);
    }
    // if after conversion the node is an expression, we must
    // wrap it in a select
    SqlExpression sqlExpression = retNode as SqlExpression;
    if (sqlExpression != null) {
        retNode = new SqlSelect(sqlExpression, null, this.dominatingExpression);
    }
    retNode = new SqlIncludeScope(retNode, this.dominatingExpression);
    return retNode;
}
/// <summary>
/// Converts a nested (non-root) expression.
/// </summary>
/// <param name="node">The expression to convert; null converts to null.</param>
/// <returns>The converted SQL node.</returns>
/// <remarks>
/// <c>outerNode</c> is cleared for the duration of the visit and restored afterwards,
/// including when conversion throws.
/// </remarks>
internal SqlNode Visit(Expression node) {
    bool tempOuterNode = this.outerNode;
    this.outerNode = false;
    try {
        return this.VisitInner(node);
    }
    finally {
        // Restore even on failure so converter state stays consistent for any caller that recovers.
        this.outerNode = tempOuterNode;
    }
}
/// <summary>
/// Converts a nested expression into a SQL node.
/// </summary>
/// <param name="node">The expression to convert.</param>
/// <param name="dominantExpression">
/// Dominating expression used for exception text. It is assigned before conversion and
/// remains the dominating expression after this call returns.
/// </param>
/// <returns>The converted SQL node.</returns>
/// <remarks>
/// <c>outerNode</c> is cleared during conversion and restored afterwards, including on failure.
/// </remarks>
internal SqlNode ConvertInner(Expression node, Expression dominantExpression) {
    this.dominatingExpression = dominantExpression;
    bool tempOuterNode = this.outerNode;
    this.outerNode = false;
    try {
        return this.VisitInner(node);
    }
    finally {
        this.outerNode = tempOuterNode;
    }
}
/// <summary>
/// Converts one LINQ expression node into the corresponding SQL node.
/// </summary>
/// <remarks>
/// Null converts to null. While the node is converted, <c>dominatingExpression</c> is replaced by
/// the result of <c>ChooseBestDominatingExpression</c>; the previous value is restored afterwards,
/// even when conversion throws.
/// </remarks>
private SqlNode VisitInner(Expression node) {
    if (node == null) {
        return null;
    }
    Expression previousDominating = this.dominatingExpression;
    this.dominatingExpression = ChooseBestDominatingExpression(previousDominating, node);
    try {
        return this.DispatchOnNodeType(node);
    }
    finally {
        this.dominatingExpression = previousDominating;
    }
}

/// <summary>
/// Routes a non-null node to the visitor for its <see cref="ExpressionType"/>; unsupported
/// node types throw.
/// </summary>
[SuppressMessage("Microsoft.Performance", "CA1800:DoNotCastUnnecessarily", Justification = "Microsoft: Cast is dependent on node type and casts do not happen unecessarily in a single code path.")]
[SuppressMessage("Microsoft.Maintainability", "CA1502:AvoidExcessiveComplexity", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
private SqlNode DispatchOnNodeType(Expression node) {
    switch (node.NodeType) {
        // --- construction ---
        case ExpressionType.New:
            return this.VisitNew((NewExpression)node);
        case ExpressionType.MemberInit:
            return this.VisitMemberInit((MemberInitExpression)node);
        case ExpressionType.NewArrayInit:
            return this.VisitNewArrayInit((NewArrayExpression)node);
        case ExpressionType.ListInit:
            return this.VisitListInit((ListInitExpression)node);

        // --- unary ---
        case ExpressionType.Negate:
        case ExpressionType.NegateChecked:
        case ExpressionType.Not:
            return this.VisitUnary((UnaryExpression)node);
        case ExpressionType.UnaryPlus:
            // Unary plus is only accepted on TimeSpan.
            if (node.Type != typeof(TimeSpan)) {
                throw Error.UnrecognizedExpressionNode(node.NodeType);
            }
            return this.VisitUnary((UnaryExpression)node);
        case ExpressionType.Convert:
        case ExpressionType.ConvertChecked:
            return this.VisitCast((UnaryExpression)node);
        case ExpressionType.TypeAs:
            return this.VisitAs((UnaryExpression)node);
        case ExpressionType.ArrayLength:
            return this.VisitArrayLength((UnaryExpression)node);
        case ExpressionType.Quote:
            return this.Visit(((UnaryExpression)node).Operand);

        // --- binary ---
        case ExpressionType.Add:
        case ExpressionType.AddChecked:
        case ExpressionType.Subtract:
        case ExpressionType.SubtractChecked:
        case ExpressionType.Multiply:
        case ExpressionType.MultiplyChecked:
        case ExpressionType.Divide:
        case ExpressionType.Modulo:
        case ExpressionType.Power:
        case ExpressionType.And:
        case ExpressionType.AndAlso:
        case ExpressionType.Or:
        case ExpressionType.OrElse:
        case ExpressionType.ExclusiveOr:
        case ExpressionType.LessThan:
        case ExpressionType.LessThanOrEqual:
        case ExpressionType.GreaterThan:
        case ExpressionType.GreaterThanOrEqual:
        case ExpressionType.Equal:
        case ExpressionType.NotEqual:
        case ExpressionType.Coalesce:
            return this.VisitBinary((BinaryExpression)node);
        case ExpressionType.ArrayIndex:
            return this.VisitArrayIndex((BinaryExpression)node);
        case ExpressionType.LeftShift:
        case ExpressionType.RightShift:
            throw Error.UnsupportedNodeType(node.NodeType);
        case ExpressionType.TypeIs:
            return this.VisitTypeBinary((TypeBinaryExpression)node);

        // --- leaves, access, calls ---
        case ExpressionType.Constant:
            return this.VisitConstant((ConstantExpression)node);
        case ExpressionType.Parameter:
            return this.VisitParameter((ParameterExpression)node);
        case ExpressionType.MemberAccess:
            return this.VisitMemberAccess((MemberExpression)node);
        case ExpressionType.Call:
            return this.VisitMethodCall((MethodCallExpression)node);
        case ExpressionType.Conditional:
            return this.VisitConditional((ConditionalExpression)node);
        case ExpressionType.Invoke:
            return this.VisitInvocation((InvocationExpression)node);
        case ExpressionType.Lambda:
            return this.VisitLambda((LambdaExpression)node);

        // --- converter-internal node kinds ---
        case (ExpressionType)InternalExpressionType.Known:
            return ((KnownExpression)node).Node;
        case (ExpressionType)InternalExpressionType.LinkedTable:
            return this.VisitLinkedTable((LinkedTableExpression)node);

        default:
            throw Error.UnrecognizedExpressionNode(node.NodeType);
    }
}
/// <summary>
/// Picks the expression to treat as "dominating" (used for user messages and exception text).
/// </summary>
/// <remarks>
/// If either argument is null, the other is returned. Otherwise <paramref name="next"/> wins,
/// unless only <paramref name="last"/> is a method call, in which case <paramref name="last"/> is kept.
/// </remarks>
private static Expression ChooseBestDominatingExpression(Expression last, Expression next) {
    if (last == null || next == null) {
        return next ?? last;
    }
    bool lastIsCall = last is MethodCallExpression;
    bool nextIsCall = next is MethodCallExpression;
    return (lastIsCall && !nextIsCall) ? last : next;
}
/// <summary>
/// Returns <paramref name="sel"/> unchanged if it is a bare projection of an alias reference with
/// no WHERE, ORDER BY, GROUP BY, HAVING, TOP, non-default ordering or DISTINCT. Otherwise it wraps
/// <paramref name="sel"/> in an aliased subquery, so the caller can add clauses (e.g. WHERE or
/// DISTINCT) without changing what the inner select means.
/// </summary>
private SqlSelect LockSelect(SqlSelect sel) {
    bool isPlainAliasProjection =
        sel.Selection.NodeType == SqlNodeType.AliasRef
        && sel.Where == null
        && sel.OrderBy.Count == 0
        && sel.GroupBy.Count == 0
        && sel.Having == null
        && sel.Top == null
        && sel.OrderingType == SqlOrderingType.Default
        && !sel.IsDistinct;
    if (isPlainAliasProjection) {
        return sel;
    }
    SqlAlias wrapper = new SqlAlias(sel);
    return new SqlSelect(new SqlAliasRef(wrapper), wrapper, this.dominatingExpression);
}
/// <summary>Converts <paramref name="exp"/> and coerces the result into a <see cref="SqlSelect"/>.</summary>
private SqlSelect VisitSequence(Expression exp) {
    SqlNode converted = this.Visit(exp);
    return this.CoerceToSequence(converted);
}
/// <summary>
/// Coerces a converted node into a <see cref="SqlSelect"/>.
/// </summary>
/// <remarks>
/// Selects are returned as-is. A captured <see cref="ITable"/> is translated into its default
/// query. A captured <see cref="IQueryable"/> is funcletized and visited again, unless its
/// expression is a constant holding the queryable itself (which would recurse forever). Other
/// captured values, client arrays and client parameters cannot be sequences and throw.
/// Multiset/element subqueries yield their inner select. Any other expression is wrapped
/// in a select over an alias of itself.
/// </remarks>
private SqlSelect CoerceToSequence(SqlNode node) {
    SqlSelect existing = node as SqlSelect;
    if (existing != null) {
        return existing;
    }
    switch (node.NodeType) {
        case SqlNodeType.Value: {
            object captured = ((SqlValue)node).Value;
            ITable table = captured as ITable;
            if (table != null) {
                return this.CoerceToSequence(this.TranslateConstantTable(table, null));
            }
            IQueryable queryable = captured as IQueryable;
            if (queryable == null) {
                throw Error.CapturedValuesCannotBeSequences();
            }
            Expression funcletized = Funcletizer.Funcletize(queryable.Expression);
            bool selfReferencing = funcletized.NodeType == ExpressionType.Constant
                && ((ConstantExpression)funcletized).Value == queryable;
            if (selfReferencing) {
                throw Error.IQueryableCannotReturnSelfReferencingConstantExpression();
            }
            return this.VisitSequence(funcletized);
        }
        case SqlNodeType.Multiset:
        case SqlNodeType.Element:
            return ((SqlSubSelect)node).Select;
        case SqlNodeType.ClientArray:
            throw Error.ConstructedArraysNotSupported();
        case SqlNodeType.ClientParameter:
            throw Error.ParametersCannotBeSequences();
        default: {
            // Any remaining node must be a sequence-valued expression.
            SqlAlias alias = new SqlAlias((SqlExpression)node);
            return new SqlSelect(new SqlAliasRef(alias), alias, this.dominatingExpression);
        }
    }
}
/// <summary>
/// Converts a delegate invocation.
/// </summary>
/// <remarks>
/// An invoked lambda (quoted or not) is inlined: each argument expression is recorded in
/// <c>exprMap</c> against the matching lambda parameter, and the body is visited.
/// Otherwise the invoked expression is converted:
/// - a captured delegate whose target is a <see cref="CompiledQuery"/> is expanded by invoking
///   the compiled query's expression (recursively);
/// - a captured parameterless delegate is invoked on the client, and its result becomes a value.
///   An exception thrown by the delegate is rethrown as-is with its original stack trace.
/// Anything else becomes a client-side call to <c>Delegate.DynamicInvoke</c>, with the converted
/// arguments packed into an <c>object[]</c>.
/// The method is recursive; inlining is disabled as in the original implementation.
/// </remarks>
[System.Runtime.CompilerServices.MethodImpl(System.Runtime.CompilerServices.MethodImplOptions.NoInlining)]
private SqlNode VisitInvocation(InvocationExpression invoke) {
    LambdaExpression lambda =
        (invoke.Expression.NodeType == ExpressionType.Quote)
        ? (LambdaExpression)((UnaryExpression)invoke.Expression).Operand
        : (invoke.Expression as LambdaExpression);
    if (lambda != null) {
        // just map arg values into lambda's parameters and evaluate lambda's body
        for (int i = 0, n = invoke.Arguments.Count; i < n; i++) {
            this.exprMap[lambda.Parameters[i]] = invoke.Arguments[i];
        }
        return this.VisitInner(lambda.Body);
    }

    // check for compiled query invocation
    SqlExpression expr = this.VisitExpression(invoke.Expression);
    if (expr.NodeType == SqlNodeType.Value) {
        Delegate d = ((SqlValue)expr).Value as Delegate;
        if (d != null) {
            CompiledQuery cq = d.Target as CompiledQuery;
            if (cq != null) {
                return this.VisitInvocation(Expression.Invoke(cq.Expression, invoke.Arguments));
            }
            if (invoke.Arguments.Count == 0) {
                object invokeResult;
                try {
                    invokeResult = d.DynamicInvoke(null);
                }
                catch (System.Reflection.TargetInvocationException e) {
                    if (e.InnerException == null) {
                        throw;
                    }
                    // Unwrap the reflection wrapper but keep the delegate's original stack trace
                    // ('throw e.InnerException' would reset it).
                    System.Runtime.ExceptionServices.ExceptionDispatchInfo.Capture(e.InnerException).Throw();
                    throw; // unreachable: Throw() never returns, but the compiler requires an exit here
                }
                return this.sql.ValueFromObject(invokeResult, invoke.Type, true, this.dominatingExpression);
            }
        }
    }

    // Fall back to invoking the delegate on the client with the converted arguments.
    SqlExpression[] args = new SqlExpression[invoke.Arguments.Count];
    for (int i = 0; i < args.Length; ++i) {
        args[i] = (SqlExpression)this.Visit(invoke.Arguments[i]);
    }
    var sca = new SqlClientArray(typeof(object[]), this.typeProvider.From(typeof(object[])), args, this.dominatingExpression);
    return sql.MethodCall(invoke.Type, typeof(Delegate).GetMethod("DynamicInvoke"), expr, new SqlExpression[] { sca }, this.dominatingExpression);
}
/// <summary>
/// Converts a lambda that is not being invoked, which denotes a parameterized query.
/// </summary>
/// <remarks>
/// Each lambda parameter becomes a client parameter, recorded in <c>dupMap</c>. The parameter's
/// accessor reads element <c>i</c> of an <c>object[] args</c> array and converts it to the
/// parameter type. The body is then visited. Parameters of type <see cref="Type"/> are rejected.
/// </remarks>
private SqlNode VisitLambda(LambdaExpression lambda) {
    int position = 0;
    foreach (ParameterExpression param in lambda.Parameters) {
        if (param.Type == typeof(Type)) {
            throw Error.BadParameterType(param.Type);
        }
        ParameterExpression argsArray = Expression.Parameter(typeof(object[]), "args");
#pragma warning disable 618 // Disable the 'obsolete' warning
        Expression element = Expression.ArrayIndex(argsArray, Expression.Constant(position));
#pragma warning restore 618
        Type accessorType = typeof(Func<,>).MakeGenericType(typeof(object[]), param.Type);
        LambdaExpression accessor = Expression.Lambda(accessorType, Expression.Convert(element, param.Type), argsArray);
        this.dupMap[param] = new SqlClientParameter(param.Type, this.typeProvider.From(param.Type), accessor, this.dominatingExpression);
        position++;
    }
    // VisitInner (not Visit) so the current 'outerNode' setting is left untouched.
    return this.VisitInner(lambda.Body);
}
/// <summary>
/// Converts <paramref name="exp"/> to a SQL expression. A null result stays null, and a select
/// result is wrapped as a multiset subquery of type <c>exp.Type</c>. Any other kind of node throws.
/// </summary>
private SqlExpression VisitExpression(Expression exp) {
    SqlNode converted = this.Visit(exp);
    if (converted == null) {
        return null;
    }
    SqlExpression asExpression = converted as SqlExpression;
    if (asExpression != null) {
        return asExpression;
    }
    SqlSelect asSelect = converted as SqlSelect;
    if (asSelect == null) {
        throw Error.UnrecognizedExpressionNode(converted);
    }
    return sql.SubSelect(SqlNodeType.Multiset, asSelect, exp.Type);
}
/// <summary>
/// Converts <c>sequence.Select(selector)</c>.
/// </summary>
/// <remarks>
/// The selector's parameter is bound (in <c>map</c>) to an alias over the source. Then:
/// - a projection that is itself a query becomes a multiset-valued column;
/// - an element or scalar subquery is OUTER APPLYed to the source when
///   <see cref="ConverterStrategy.CanUseOuterApply"/> is enabled. For element subqueries the
///   selection is wrapped in a SqlOptionalValue together with a constant "test" column
///   (presumably so a missing row can be told apart from a null value — confirm);
/// - any other expression is selected directly; non-expressions throw.
/// </remarks>
private SqlSelect VisitSelect(Expression sequence, LambdaExpression selector) {
    SqlSelect source = this.VisitSequence(sequence);
    SqlAlias alias = new SqlAlias(source);
    SqlAliasRef aref = new SqlAliasRef(alias);
    this.map[selector.Parameters[0]] = aref;
    SqlNode project = this.Visit(selector.Body);

    SqlSelect pselect = project as SqlSelect;
    if (pselect != null) {
        return new SqlSelect(sql.SubSelect(SqlNodeType.Multiset, pselect, selector.Body.Type), alias, this.dominatingExpression);
    }

    bool isSingletonSubquery = project.NodeType == SqlNodeType.Element || project.NodeType == SqlNodeType.ScalarSubSelect;
    // Use the shared flag helper, like the rest of the class, instead of a hand-rolled mask test.
    if (isSingletonSubquery && this.UseConverterStrategy(ConverterStrategy.CanUseOuterApply)) {
        SqlSubSelect sub = (SqlSubSelect)project;
        SqlSelect inner = sub.Select;
        SqlAlias innerAlias = new SqlAlias(inner);
        SqlAliasRef innerRef = new SqlAliasRef(innerAlias);
        if (project.NodeType == SqlNodeType.Element) {
            inner.Selection = new SqlOptionalValue(
                new SqlColumn(
                    "test",
                    sql.Unary(
                        SqlNodeType.OuterJoinedValue,
                        sql.Value(typeof(int?), this.typeProvider.From(typeof(int)), 1, false, this.dominatingExpression)
                    )
                ),
                sql.Unary(SqlNodeType.OuterJoinedValue, inner.Selection)
            );
        }
        else {
            inner.Selection = sql.Unary(SqlNodeType.OuterJoinedValue, inner.Selection);
        }
        SqlJoin join = new SqlJoin(SqlJoinType.OuterApply, alias, innerAlias, null, this.dominatingExpression);
        return new SqlSelect(innerRef, join, this.dominatingExpression);
    }

    SqlExpression expr = project as SqlExpression;
    if (expr == null) {
        throw Error.BadProjectionInSelect();
    }
    return new SqlSelect(expr, alias, this.dominatingExpression);
}
/// <summary>
/// Converts <c>sequence.SelectMany(colSelector[, resultSelector])</c> into a CROSS APPLY of the
/// collection selector over the source. Without a result selector, the collection elements
/// themselves are projected.
/// </summary>
private SqlSelect VisitSelectMany(Expression sequence, LambdaExpression colSelector, LambdaExpression resultSelector) {
    SqlAlias sourceAlias = new SqlAlias(this.VisitSequence(sequence));
    SqlAliasRef sourceRef = new SqlAliasRef(sourceAlias);
    this.map[colSelector.Parameters[0]] = sourceRef;

    SqlAlias collectionAlias = new SqlAlias(this.VisitSequence(colSelector.Body));
    SqlAliasRef collectionRef = new SqlAliasRef(collectionAlias);
    SqlJoin apply = new SqlJoin(SqlJoinType.CrossApply, sourceAlias, collectionAlias, null, this.dominatingExpression);

    SqlExpression projection;
    if (resultSelector == null) {
        projection = collectionRef;
    }
    else {
        this.map[resultSelector.Parameters[0]] = sourceRef;
        this.map[resultSelector.Parameters[1]] = collectionRef;
        projection = this.VisitExpression(resultSelector.Body);
    }
    return new SqlSelect(projection, apply, this.dominatingExpression);
}
/// <summary>
/// Converts <c>outer.Join(inner, outerKeySelector, innerKeySelector, resultSelector)</c>.
/// </summary>
/// <remarks>
/// Both sequences are aliased; the key selectors and the result selector are bound to those
/// aliases. The join condition is <c>outerKey == innerKey</c>. When
/// <see cref="ConverterStrategy.CanUseJoinOn"/> is enabled, it is emitted as an INNER JOIN ... ON.
/// Otherwise it is emitted as a CROSS JOIN whose WHERE clause holds the condition.
/// </remarks>
private SqlSelect VisitJoin(Expression outerSequence, Expression innerSequence, LambdaExpression outerKeySelector, LambdaExpression innerKeySelector, LambdaExpression resultSelector) {
    SqlSelect outerSelect = this.VisitSequence(outerSequence);
    SqlSelect innerSelect = this.VisitSequence(innerSequence);
    SqlAlias outerAlias = new SqlAlias(outerSelect);
    SqlAliasRef outerRef = new SqlAliasRef(outerAlias);
    SqlAlias innerAlias = new SqlAlias(innerSelect);
    SqlAliasRef innerRef = new SqlAliasRef(innerAlias);

    this.map[outerKeySelector.Parameters[0]] = outerRef;
    SqlExpression outerKey = this.VisitExpression(outerKeySelector.Body);
    this.map[innerKeySelector.Parameters[0]] = innerRef;
    SqlExpression innerKey = this.VisitExpression(innerKeySelector.Body);
    this.map[resultSelector.Parameters[0]] = outerRef;
    this.map[resultSelector.Parameters[1]] = innerRef;
    SqlExpression result = this.VisitExpression(resultSelector.Body);

    SqlExpression condition = sql.Binary(SqlNodeType.EQ, outerKey, innerKey);
    if (this.UseConverterStrategy(ConverterStrategy.CanUseJoinOn)) {
        SqlJoin innerJoin = new SqlJoin(SqlJoinType.Inner, outerAlias, innerAlias, condition, this.dominatingExpression);
        return new SqlSelect(result, innerJoin, this.dominatingExpression);
    }
    // No JOIN ... ON support: cross join and filter in WHERE.
    SqlJoin crossJoin = new SqlJoin(SqlJoinType.Cross, outerAlias, innerAlias, null, this.dominatingExpression);
    SqlSelect select = new SqlSelect(result, crossJoin, this.dominatingExpression);
    select.Where = condition;
    return select;
}
/// <summary>
/// Converts <c>outer.GroupJoin(inner, outerKeySelector, innerKeySelector, resultSelector)</c>.
/// </summary>
/// <remarks>
/// For each outer row, the inner rows whose key equals the outer key are gathered into a multiset
/// subquery. The result selector's first parameter is bound to the outer row (in <c>map</c>) and
/// its second to that multiset (in <c>dupMap</c>); the projection is selected from the outer source.
/// </remarks>
private SqlSelect VisitGroupJoin(Expression outerSequence, Expression innerSequence, LambdaExpression outerKeySelector, LambdaExpression innerKeySelector, LambdaExpression resultSelector) {
    SqlSelect outerSource = this.VisitSequence(outerSequence);
    SqlSelect innerSource = this.VisitSequence(innerSequence);
    SqlAlias outerAlias = new SqlAlias(outerSource);
    SqlAliasRef outerRow = new SqlAliasRef(outerAlias);
    SqlAlias innerAlias = new SqlAlias(innerSource);
    SqlAliasRef innerRow = new SqlAliasRef(innerAlias);

    this.map[outerKeySelector.Parameters[0]] = outerRow;
    SqlExpression outerKey = this.VisitExpression(outerKeySelector.Body);
    this.map[innerKeySelector.Parameters[0]] = innerRow;
    SqlExpression innerKey = this.VisitExpression(innerKeySelector.Body);

    // The matching inner rows, as a multiset correlated on the key equality.
    SqlSelect matches = new SqlSelect(innerRow, innerAlias, this.dominatingExpression);
    matches.Where = sql.Binary(SqlNodeType.EQ, outerKey, innerKey);
    SqlSubSelect group = sql.SubSelect(SqlNodeType.Multiset, matches);

    this.map[resultSelector.Parameters[0]] = outerRow;
    this.dupMap[resultSelector.Parameters[1]] = group;
    SqlExpression projection = this.VisitExpression(resultSelector.Body);
    return new SqlSelect(projection, outerAlias, this.dominatingExpression);
}
/// <summary>
/// Converts <c>sequence.DefaultIfEmpty()</c>.
/// </summary>
/// <remarks>
/// Each source row is wrapped in a SqlOptionalValue, paired with an outer-joined constant-1
/// column named "test". That select is then OUTER APPLYed to a one-row <c>SELECT NULL</c>, so
/// the result still has a row when the source is empty.
/// </remarks>
private SqlSelect VisitDefaultIfEmpty(Expression sequence) {
    SqlAlias sourceAlias = new SqlAlias(this.VisitSequence(sequence));
    SqlAliasRef sourceRef = new SqlAliasRef(sourceAlias);

    SqlExpression one = sql.Value(typeof(int?), this.typeProvider.From(typeof(int)), 1, false, this.dominatingExpression);
    SqlColumn testColumn = new SqlColumn("test", sql.Unary(SqlNodeType.OuterJoinedValue, one));
    SqlExpression optional = new SqlOptionalValue(testColumn, sql.Unary(SqlNodeType.OuterJoinedValue, sourceRef));
    SqlSelect optionalSelect = new SqlSelect(optional, sourceAlias, this.dominatingExpression);
    SqlAlias optionalAlias = new SqlAlias(optionalSelect);
    SqlAliasRef optionalRef = new SqlAliasRef(optionalAlias);

    SqlExpression nullLiteral = sql.TypedLiteralNull(typeof(string), this.dominatingExpression);
    SqlSelect singleRow = new SqlSelect(nullLiteral, null, this.dominatingExpression);
    SqlAlias singleRowAlias = new SqlAlias(singleRow);
    SqlJoin apply = new SqlJoin(SqlJoinType.OuterApply, singleRowAlias, optionalAlias, null, this.dominatingExpression);
    return new SqlSelect(optionalRef, apply, this.dominatingExpression);
}
/// <summary>
/// Converts <c>seq.OfType&lt;T&gt;()</c> as <c>seq.Select(s =&gt; s as T).Where(p =&gt; p != null)</c>.
/// </summary>
private SqlSelect VisitOfType(Expression sequence, Type ofType) {
    SqlSelect source = this.LockSelect(this.VisitSequence(sequence));
    SqlAliasRef element = (SqlAliasRef)source.Selection;
    source.Selection = new SqlUnary(SqlNodeType.Treat, ofType, typeProvider.From(ofType), element, this.dominatingExpression);

    // Lock again so the null filter applies to the treated projection.
    SqlSelect filtered = this.LockSelect(source);
    SqlAliasRef treated = (SqlAliasRef)filtered.Selection;
    SqlExpression isNotNull = sql.Unary(SqlNodeType.IsNotNull, treated, this.dominatingExpression);
    filtered.Where = sql.AndAccumulate(filtered.Where, isNotNull);
    return filtered;
}
/// <summary>
/// Converts <c>seq.Cast&lt;T&gt;()</c> by rewriting it as <c>seq.Select(s =&gt; (T)s)</c> and visiting that.
/// </summary>
private SqlNode VisitSequenceCast(Expression sequence, Type type) {
    Type elementType = TypeSystem.GetElementType(sequence.Type);
    ParameterExpression item = Expression.Parameter(elementType, "pc");
    LambdaExpression castSelector = Expression.Lambda(Expression.Convert(item, type), new ParameterExpression[] { item });
    // Generic arguments of Enumerable.Select: TSource, TResult.
    Type[] selectTypeArgs = new Type[] { elementType, type };
    MethodCallExpression selectCall = Expression.Call(typeof(Enumerable), "Select", selectTypeArgs, sequence, castSelector);
    return this.Visit(selectCall);
}
/// <summary>
/// Converts the <c>is</c> operator (<see cref="ExpressionType.TypeIs"/>) to
/// <c>(expr treated as T) IS NOT NULL</c>. Other type-binary node types throw.
/// </summary>
private SqlNode VisitTypeBinary(TypeBinaryExpression b) {
    SqlExpression operand = this.VisitExpression(b.Expression);
    if (b.NodeType != ExpressionType.TypeIs) {
        throw Error.TypeBinaryOperatorNotRecognized();
    }
    Type target = b.TypeOperand;
    SqlExpression treated = new SqlUnary(SqlNodeType.Treat, target, typeProvider.From(target), operand, this.dominatingExpression);
    return sql.Unary(SqlNodeType.IsNotNull, treated, this.dominatingExpression);
}
/// <summary>
/// Converts <c>sequence.Where(predicate)</c>. The predicate's parameter is bound to the locked
/// select's row, and the predicate becomes the select's WHERE clause.
/// </summary>
private SqlSelect VisitWhere(Expression sequence, LambdaExpression predicate) {
    SqlSelect target = this.LockSelect(this.VisitSequence(sequence));
    ParameterExpression row = predicate.Parameters[0];
    this.map[row] = (SqlAliasRef)target.Selection;
    SqlExpression condition = this.VisitExpression(predicate.Body);
    target.Where = condition;
    return target;
}
/// <summary>
/// Converts <c>x as T</c> into a Treat node. A select operand is first wrapped as a multiset
/// subquery; any other non-expression operand throws.
/// </summary>
private SqlNode VisitAs(UnaryExpression a) {
    SqlNode operand = this.Visit(a.Operand);
    SqlExpression target = operand as SqlExpression;
    if (target == null) {
        SqlSelect select = operand as SqlSelect;
        if (select == null) {
            throw Error.DidNotExpectAs(a);
        }
        target = sql.SubSelect(SqlNodeType.Multiset, select);
    }
    return new SqlUnary(SqlNodeType.Treat, a.Type, typeProvider.From(a.Type), target, a);
}
/// <summary>
/// Converts an array-length expression: CLRLENGTH for string/char SQL types, DATALENGTH otherwise.
/// </summary>
private SqlNode VisitArrayLength(UnaryExpression c) {
    SqlExpression operand = this.VisitExpression(c.Operand);
    bool isText = operand.SqlType.IsString || operand.SqlType.IsChar;
    if (!isText) {
        return sql.DATALENGTH(operand);
    }
    return sql.CLRLENGTH(operand);
}
/// <summary>
/// Converts <c>array[index]</c>. Only indexing a client parameter with a constant index is
/// supported; it is folded into a new client parameter whose accessor indexes the original
/// accessor's result. Anything else throws.
/// </summary>
private SqlNode VisitArrayIndex(BinaryExpression b) {
    SqlExpression array = this.VisitExpression(b.Left);
    SqlExpression index = this.VisitExpression(b.Right);
    bool foldable = array.NodeType == SqlNodeType.ClientParameter && index.NodeType == SqlNodeType.Value;
    if (!foldable) {
        throw Error.UnrecognizedExpressionNode(b.NodeType);
    }
    SqlClientParameter arrayParam = (SqlClientParameter)array;
    SqlValue indexValue = (SqlValue)index;
    var accessor = arrayParam.Accessor;
#pragma warning disable 618 // Disable the 'obsolete' warning
    Expression element = Expression.ArrayIndex(accessor.Body, Expression.Constant(indexValue.Value, indexValue.ClrType));
#pragma warning restore 618
    return new SqlClientParameter(
        b.Type, sql.TypeProvider.From(b.Type),
        Expression.Lambda(element, accessor.Parameters.ToArray()),
        this.dominatingExpression
    );
}
/// <summary>
/// Converts a Convert/ConvertChecked node. A user-defined conversion operator becomes a method
/// call; otherwise the operand is converted with <c>VisitChangeType</c>.
/// </summary>
private SqlNode VisitCast(UnaryExpression c) {
    if (c.Method == null) {
        return this.VisitChangeType(c.Operand, c.Type);
    }
    SqlExpression operand = this.VisitExpression(c.Operand);
    return sql.MethodCall(c.Type, c.Method, null, new SqlExpression[] { operand }, dominatingExpression);
}
/// <summary>Converts <paramref name="expression"/>, then changes its type to <paramref name="type"/>.</summary>
private SqlNode VisitChangeType(Expression expression, Type type) {
    return this.ChangeType(this.VisitExpression(expression), type);
}
/// <summary>
/// Wraps <paramref name="expr"/> in <c>CONVERT(DATETIME2, expr)</c>, typed as <see cref="DateTime"/>.
/// </summary>
private SqlNode ConvertDateToDateTime2(SqlExpression expr) {
    var source = expr.SourceExpression;
    SqlExpression targetType = new SqlVariable(expr.ClrType, expr.SqlType, "DATETIME2", source);
    SqlExpression[] convertArgs = new SqlExpression[] { targetType, expr };
    return sql.FunctionCall(typeof(DateTime), "CONVERT", convertArgs, source);
}
/// <summary>
/// Changes the CLR type of an already-converted expression.
/// </summary>
/// <remarks>
/// Conversion to <see cref="object"/> is a no-op. A null literal becomes a typed null. A client
/// parameter gets a converting accessor. Everything else follows <c>ChooseConversionMethod</c>:
/// Convert emits a SQL conversion, Treat emits a Treat node, and Lift wraps in SqlLift. Ignore
/// returns the expression as-is. For Lift and Ignore, SQL date values are first converted to
/// DATETIME2.
/// </remarks>
private SqlNode ChangeType(SqlExpression expr, Type type) {
    if (type == typeof(object)) {
        return expr; // Boxing conversion?
    }
    if (expr.NodeType == SqlNodeType.Value && ((SqlValue)expr).Value == null) {
        return sql.TypedLiteralNull(type, expr.SourceExpression);
    }
    if (expr.NodeType == SqlNodeType.ClientParameter) {
        SqlClientParameter param = (SqlClientParameter)expr;
        Expression convertedBody = Expression.Convert(param.Accessor.Body, type);
        return new SqlClientParameter(
            type, sql.TypeProvider.From(type),
            Expression.Lambda(convertedBody, param.Accessor.Parameters.ToArray()),
            param.SourceExpression
        );
    }

    ConversionMethod method = ChooseConversionMethod(expr.ClrType, type);
    if (method == ConversionMethod.Convert) {
        return sql.UnaryConvert(type, typeProvider.From(type), expr, expr.SourceExpression);
    }
    if (method == ConversionMethod.Treat) {
        return new SqlUnary(SqlNodeType.Treat, type, typeProvider.From(type), expr, expr.SourceExpression);
    }
    if (method == ConversionMethod.Lift) {
        SqlExpression liftSource = SqlFactory.IsSqlDateType(expr) ? (SqlExpression)ConvertDateToDateTime2(expr) : expr;
        return new SqlLift(type, liftSource, this.dominatingExpression);
    }
    if (method == ConversionMethod.Ignore) {
        return SqlFactory.IsSqlDateType(expr) ? ConvertDateToDateTime2(expr) : expr;
    }
    throw Error.UnhandledExpressionType(method);
}
/// <summary>
/// How <c>ChangeType</c> realises a CLR type change in SQL.
/// </summary>
enum ConversionMethod {
    /// <summary>Wrap the value in a Treat node.</summary>
    Treat = 0,
    /// <summary>Keep the value as-is (except that SQL dates are converted to DATETIME2).</summary>
    Ignore = 1,
    /// <summary>Emit a SQL conversion.</summary>
    Convert = 2,
    /// <summary>Wrap in a SqlLift (the types differ only in nullability).</summary>
    Lift = 3
}
/// <summary>
/// Decides how a value of <paramref name="fromType"/> is turned into <paramref name="toType"/>.
/// </summary>
/// <returns>
/// Lift when the types differ only in nullability. Ignore when either side is a sequence type.
/// Treat when either provider type is runtime-only. Ignore when the non-nullable CLR types match,
/// the SQL string types are equal, or either side is an enum. Otherwise Convert.
/// </returns>
private ConversionMethod ChooseConversionMethod(Type fromType, Type toType) {
    Type sourceCore = TypeSystem.GetNonNullableType(fromType);
    Type targetCore = TypeSystem.GetNonNullableType(toType);

    if (fromType != toType && sourceCore == targetCore) {
        return ConversionMethod.Lift;
    }
    if (TypeSystem.IsSequenceType(sourceCore) || TypeSystem.IsSequenceType(targetCore)) {
        return ConversionMethod.Ignore;
    }

    ProviderType sourceSql = typeProvider.From(sourceCore);
    ProviderType targetSql = typeProvider.From(targetCore);
    if (sourceSql.IsRuntimeOnlyType || targetSql.IsRuntimeOnlyType) {
        return ConversionMethod.Treat;
    }

    bool representationUnchanged =
        sourceCore == targetCore
        || (sourceSql.IsString && sourceSql.Equals(targetSql))
        || sourceCore.IsEnum
        || targetCore.IsEnum;
    return representationUnchanged ? ConversionMethod.Ignore : ConversionMethod.Convert;
}
/// <summary>
/// Converts an <see cref="ITable"/> into SQL nodes by building the default query for its
/// mapped row type. If the hierarchy involves inheritance, a type case is built: a CASE in
/// which each WHEN is a type binding that may be instantiated.
/// </summary>
/// <remarks>Tables that belong to a different DataContext than this converter's are rejected.</remarks>
private SqlNode TranslateConstantTable(ITable table, SqlLink link) {
    bool sameContext = table.Context == this.services.Context;
    if (!sameContext) {
        throw Error.WrongDataContext();
    }
    MetaType rowType = this.services.Model.GetTable(table.ElementType).RowType;
    return this.translator.BuildDefaultQuery(rowType, this.allowDeferred, link, this.dominatingExpression);
}
/// <summary>Translates the table carried by a linked-table expression, passing its link along.</summary>
private SqlNode VisitLinkedTable(LinkedTableExpression linkedTable) {
    var table = linkedTable.Table;
    var link = linkedTable.Link;
    return this.TranslateConstantTable(table, link);
}
/// <summary>
/// Converts a constant. A null value becomes a typed null literal. A constant statically typed
/// as <see cref="object"/> is converted using the runtime type of its value.
/// </summary>
private SqlNode VisitConstant(ConstantExpression cons) {
    object value = cons.Value;
    if (value == null) {
        return sql.TypedLiteralNull(cons.Type, this.dominatingExpression);
    }
    Type valueType = (cons.Type == typeof(object)) ? value.GetType() : cons.Type;
    return sql.ValueFromObject(value, valueType, true, this.dominatingExpression);
}
/// <summary>
/// Converts <c>test ? a : b</c> into a searched CASE. Searched cases that appear in the else
/// branch are flattened into the same CASE.
/// </summary>
private SqlExpression VisitConditional(ConditionalExpression cond) {
    SqlExpression test = this.VisitExpression(cond.Test);
    SqlExpression whenTrue = this.VisitExpression(cond.IfTrue);
    List<SqlWhen> arms = new List<SqlWhen>(1) { new SqlWhen(test, whenTrue) };
    SqlExpression fallback = this.VisitExpression(cond.IfFalse);
    while (fallback.NodeType == SqlNodeType.SearchedCase) {
        SqlSearchedCase nested = (SqlSearchedCase)fallback;
        arms.AddRange(nested.Whens);
        fallback = nested.Else;
    }
    return sql.SearchedCase(arms.ToArray(), fallback, this.dominatingExpression);
}
/// <summary>
/// Converts a constructor call. <c>new Nullable&lt;T&gt;(x)</c> and <c>new decimal(x)</c> are
/// treated as conversions. Entity types cannot be constructed in a query and throw. Any other
/// type becomes a SqlNew over its converted arguments and the members they map to.
/// </summary>
private SqlExpression VisitNew(NewExpression qn) {
    if (qn.Arguments.Count == 1) {
        bool wrapsNullable = TypeSystem.IsNullableType(qn.Type)
            && TypeSystem.GetNonNullableType(qn.Type) == qn.Arguments[0].Type;
        if (wrapsNullable) {
            return this.VisitCast(Expression.Convert(qn.Arguments[0], qn.Type)) as SqlExpression;
        }
        if (qn.Type == typeof(decimal)) {
            return this.VisitCast(Expression.Convert(qn.Arguments[0], typeof(decimal))) as SqlExpression;
        }
    }

    MetaType metaType = this.services.Model.GetMetaType(qn.Type);
    if (metaType.IsEntity) {
        throw Error.CannotMaterializeEntityType(qn.Type);
    }

    SqlExpression[] args = null;
    int argCount = qn.Arguments.Count;
    if (argCount > 0) {
        args = new SqlExpression[argCount];
        int slot = 0;
        foreach (Expression argument in qn.Arguments) {
            args[slot++] = this.VisitExpression(argument);
        }
    }
    return sql.New(metaType, qn.Constructor, args, PropertyOrFieldOf(qn.Members), null, this.dominatingExpression);
}
/// <summary>
/// Converts <c>new T(args) { Member = value, ... }</c>.
/// </summary>
/// <remarks>
/// Entity types throw. <c>new decimal(x) { }</c> is treated as a conversion. Only plain member
/// assignments are supported. Assignments are reordered into the type's data-member declaration
/// order before the SqlNew is built.
/// </remarks>
private SqlExpression VisitMemberInit(MemberInitExpression init) {
    MetaType metaType = this.services.Model.GetMetaType(init.Type);
    if (metaType.IsEntity) {
        throw Error.CannotMaterializeEntityType(init.Type);
    }
    NewExpression ctorCall = init.NewExpression;
    if (ctorCall.Type == typeof(decimal) && ctorCall.Arguments.Count == 1) {
        return this.VisitCast(Expression.Convert(ctorCall.Arguments[0], typeof(decimal))) as SqlExpression;
    }

    SqlExpression[] args = null;
    if (ctorCall.Arguments.Count > 0) {
        args = new SqlExpression[ctorCall.Arguments.Count];
        for (int a = 0; a < args.Length; a++) {
            args[a] = this.VisitExpression(ctorCall.Arguments[a]);
        }
    }

    int bindingCount = init.Bindings.Count;
    SqlMemberAssign[] assignments = new SqlMemberAssign[bindingCount];
    int[] ordinals = new int[bindingCount];
    for (int k = 0; k < bindingCount; k++) {
        MemberBinding binding = init.Bindings[k];
        MemberAssignment assignment = binding as MemberAssignment;
        if (assignment == null) {
            throw Error.UnhandledBindingType(binding.BindingType);
        }
        SqlExpression value = this.VisitExpression(assignment.Expression);
        assignments[k] = new SqlMemberAssign(assignment.Member, value);
        ordinals[k] = metaType.GetDataMember(assignment.Member).Ordinal;
    }
    // Sorting the ordinals carries the assignments along into declaration order.
    Array.Sort(ordinals, assignments, 0, assignments.Length);
    return sql.New(metaType, ctorCall.Constructor, args, PropertyOrFieldOf(ctorCall.Members), assignments, this.dominatingExpression);
}
/// <summary>
/// Normalizes the members reported by a NewExpression to properties or fields.
/// </summary>
/// <param name="members">Members from <c>NewExpression.Members</c>; may be null.</param>
/// <returns>
/// null when <paramref name="members"/> is null. Otherwise a list in which fields and properties
/// pass through, and each method member is replaced by the instance property (public or
/// non-public) whose get accessor it is.
/// </returns>
/// <remarks>
/// Member kinds other than method, field or property throw CouldNotConvertToPropertyOrField.
/// NOTE(review): a method that is not the getter of any property of its declaring type is
/// silently skipped (pre-existing behavior), which shortens the list relative to the
/// constructor arguments — confirm whether it should throw instead.
/// </remarks>
private static IEnumerable<MemberInfo> PropertyOrFieldOf(IEnumerable<MemberInfo> members) {
    if (members == null) {
        return null;
    }
    List<MemberInfo> result = new List<MemberInfo>();
    foreach (MemberInfo mi in members) {
        switch (mi.MemberType) {
            case MemberTypes.Method: {
                // Loop-invariant: cast once rather than per candidate property.
                MethodInfo method = mi as MethodInfo;
                foreach (PropertyInfo pi in mi.DeclaringType.GetProperties(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic)) {
                    // GetGetMethod(true): non-public properties are enumerated above, but the
                    // parameterless GetGetMethod() returns null for non-public getters, so they
                    // could never match.
                    if (pi.CanRead && pi.GetGetMethod(true) == method) {
                        result.Add(pi);
                        break;
                    }
                }
                break;
            }
            case MemberTypes.Field:
            case MemberTypes.Property: {
                result.Add(mi);
                break;
            }
            default: {
                throw Error.CouldNotConvertToPropertyOrField(mi);
            }
        }
    }
    return result;
}
/// <summary>
/// Translates <c>Distinct()</c>: the translated source is locked, marked DISTINCT, and its
/// ordering type is set to Blocked.
/// </summary>
private SqlSelect VisitDistinct(Expression sequence) {
    SqlSelect source = this.VisitSequence(sequence);
    SqlSelect distinctSelect = this.LockSelect(source);
    distinctSelect.IsDistinct = true;
    distinctSelect.OrderingType = SqlOrderingType.Blocked;
    return distinctSelect;
}
/// <summary>
/// Translates <c>Take(count)</c>. When the source is itself a two-argument <c>Skip</c> call,
/// the two operators are combined into a single skip/take translation.
/// </summary>
/// <remarks>Constant negative take or skip counts are rejected with Error.ArgumentOutOfRange.</remarks>
private SqlSelect VisitTake(Expression sequence, Expression count) {
    SqlExpression takeExp = this.VisitExpression(count);
    // Reject a constant negative take count up front.
    if (takeExp.NodeType == SqlNodeType.Value) {
        object takeValue = ((SqlValue)takeExp).Value;
        if (typeof(int).IsAssignableFrom(takeValue.GetType()) && (int)takeValue < 0) {
            throw Error.ArgumentOutOfRange("takeCount");
        }
    }
    MethodCallExpression skipCall = sequence as MethodCallExpression;
    bool sourceIsSkip = skipCall != null
        && IsSequenceOperatorCall(skipCall)
        && skipCall.Method.Name == "Skip"
        && skipCall.Arguments.Count == 2;
    if (!sourceIsSkip) {
        return this.GenerateSkipTake(this.VisitSequence(sequence), null, takeExp);
    }
    SqlExpression skipExp = this.VisitExpression(skipCall.Arguments[1]);
    // Likewise reject a constant negative skip count.
    if (skipExp.NodeType == SqlNodeType.Value) {
        object skipValue = ((SqlValue)skipExp).Value;
        if (typeof(int).IsAssignableFrom(skipValue.GetType()) && (int)skipValue < 0) {
            throw Error.ArgumentOutOfRange("skipCount");
        }
    }
    SqlSelect skipSource = this.VisitSequence(skipCall.Arguments[0]);
    return this.GenerateSkipTake(skipSource, skipExp, takeExp);
}
/// <summary>
/// Decides whether the elements of a selection have an identity that can be compared, which
/// is required for them to be skipped. Groupings (compared by key) and mapped entities
/// (compared by primary key) qualify. Sequences that cannot be stored in a column do not, and
/// neither do aliases or projections that contain them.
/// </summary>
private bool CanSkipOnSelection(SqlExpression selection) {
    Type clrType = selection.ClrType;
    // Groupings can be compared by key.
    if (IsGrouping(clrType)) {
        return true;
    }
    // Entities can be compared by primary key.
    if (this.services.Model.GetTable(clrType) != null) {
        return true;
    }
    // Sequences that are not primitives (cannot be a column) are not skippable.
    if (TypeSystem.IsSequenceType(clrType) && !selection.SqlType.CanBeColumn) {
        return false;
    }
    switch (selection.NodeType) {
        case SqlNodeType.AliasRef: {
            SqlNode aliased = ((SqlAliasRef)selection).Alias.Node;
            SqlSelect innerSelect = aliased as SqlSelect;
            if (innerSelect != null) {
                return CanSkipOnSelection(innerSelect.Selection);
            }
            SqlUnion innerUnion = aliased as SqlUnion;
            if (innerUnion == null) {
                return CanSkipOnSelection((SqlExpression)aliased);
            }
            // Both arms of a union must be skippable; an arm that is not a select counts as not skippable.
            SqlSelect leftArm = innerUnion.Left as SqlSelect;
            SqlSelect rightArm = innerUnion.Right as SqlSelect;
            bool leftOk = leftArm != null && CanSkipOnSelection(leftArm.Selection);
            bool rightOk = rightArm != null && CanSkipOnSelection(rightArm.Selection);
            return leftOk && rightOk;
        }
        case SqlNodeType.New: {
            SqlNew construction = (SqlNew)selection;
            // Every member assignment of the projection must be skippable...
            foreach (SqlMemberAssign assign in construction.Members) {
                if (!CanSkipOnSelection(assign.Expression)) {
                    return false;
                }
            }
            // ...and so must the first ArgMembers.Count constructor arguments.
            if (construction.ArgMembers != null) {
                int memberBoundArgs = construction.ArgMembers.Count;
                for (int k = 0; k < memberBoundArgs; k++) {
                    if (!CanSkipOnSelection(construction.Args[k])) {
                        return false;
                    }
                }
            }
            return true;
        }
        default:
            return true;
    }
}
/// <summary>
/// Translates <c>Skip(count)</c> by delegating to GenerateSkipTake with no take count.
/// Depending on the converter strategy the result has one of two shapes:
///
/// SQL2000 (NOT EXISTS):
///     SELECT *
///     FROM sequence
///     WHERE NOT EXISTS (
///         SELECT TOP count *
///         FROM sequence)
///
/// SQL2005 (SkipWithRowNumber):
///     SELECT *
///     FROM (SELECT sequence.*,
///           ROW_NUMBER() OVER (ORDER BY order) AS ROW_NUMBER
///           FROM sequence)
///     WHERE ROW_NUMBER > count
/// </summary>
/// <param name="sequence">Sequence containing elements to skip.</param>
/// <param name="skipCount">Number of elements to skip; a constant negative value is rejected.</param>
/// <returns>SELECT node.</returns>
private SqlSelect VisitSkip(Expression sequence, Expression skipCount) {
    SqlExpression skipExp = this.VisitExpression(skipCount);
    bool isConstant = skipExp.NodeType == SqlNodeType.Value;
    if (isConstant) {
        object constant = ((SqlValue)skipExp).Value;
        if (typeof(int).IsAssignableFrom(constant.GetType()) && (int)constant < 0) {
            throw Error.ArgumentOutOfRange("skipCount");
        }
    }
    return this.GenerateSkipTake(this.VisitSequence(sequence), skipExp, null);
}
/// <summary>
/// Applies optional skip and take counts to a sequence; the sequence is locked first.
/// - No skip: the locked select is returned, with TOP set to the take count if one is given.
/// - Skip with the SkipWithRowNumber strategy: a ROW_NUMBER column is added to the locked
///   select and an outer select filters on it (BETWEEN skip + 1 AND skip + take when a take
///   count is given, otherwise ROW_NUMBER > skip).
/// - Skip otherwise: the selection must pass CanSkipOnSelection (else
///   Error.SkipNotSupportedForSequenceTypes) and SingleTableQueryVisitor (else
///   Error.SkipRequiresSingleTableQueryWithPKs). An outer select then keeps the rows for which
///   NOT EXISTS an EQ2V-equal row in a TOP(skip) copy of the sequence, with TOP set to the take count.
/// </summary>
/// <param name="sequence">The select to apply skip/take to.</param>
/// <param name="skipExp">Number of elements to skip, or null for no skip.</param>
/// <param name="takeExp">Number of elements to take, or null for no limit.</param>
/// <returns>The resulting select node.</returns>
- private SqlSelect GenerateSkipTake(SqlSelect sequence, SqlExpression skipExp, SqlExpression takeExp) {
- SqlSelect select = this.LockSelect(sequence);
- // no skip?
- if (skipExp == null) {
- if (takeExp != null) {
- select.Top = takeExp;
- }
- return select;
- }
// A skip was requested: alias the locked select so an outer select can filter its rows.
- SqlAlias alias = new SqlAlias(select);
- SqlAliasRef aref = new SqlAliasRef(alias);
- if (this.UseConverterStrategy(ConverterStrategy.SkipWithRowNumber)) {
- // use ROW_NUMBER() (preferred)
// NOTE(review): the ROW_NUMBER node is created with an empty ordering list here;
// presumably a later pass supplies the ORDER BY -- confirm.
- SqlColumn rowNumber = new SqlColumn("ROW_NUMBER", sql.RowNumber(new List<SqlOrderExpression>(), this.dominatingExpression));
- SqlColumnRef rowNumberRef = new SqlColumnRef(rowNumber);
- select.Row.Columns.Add(rowNumber);
- SqlSelect final = new SqlSelect(aref, alias, this.dominatingExpression);
- if (takeExp != null) {
- // use BETWEEN for skip+take combo (much faster)
// Lower bound is skip + 1. The upper bound (skip + take) uses a copy of skipExp because
// skipExp itself is already used in the lower bound.
- final.Where = sql.Between(
- rowNumberRef,
- sql.Add(skipExp, 1),
- sql.Binary(SqlNodeType.Add, (SqlExpression)SqlDuplicator.Copy(skipExp), takeExp),
- this.dominatingExpression
- );
- }
- else {
// Skip only: keep every row numbered above the skip count.
- final.Where = sql.Binary(SqlNodeType.GT, rowNumberRef, skipExp);
- }
- return final;
- }
- else {
- // Ensure that the sequence contains elements that can be skipped
- if (!CanSkipOnSelection(select.Selection)) {
- throw Error.SkipNotSupportedForSequenceTypes();
- }
- // use NOT EXISTS
- // Supported cases:
- // - Entities
- // - Projections that contain all PK columns
- //
- // .. where there sequence can be traced back to a:
- // - Single-table query
- // - Distinct
- // - Except
- // - Intersect
- // - Union, where union.All == false
- // Not supported: joins
- // Sequence should also be ordered, but we can't test for it at this
- // point in processing, and we won't know that we need to test it, later.
- SingleTableQueryVisitor stqv = new SingleTableQueryVisitor();
- stqv.Visit(select);
- if (!stqv.IsValid) {
- throw Error.SkipRequiresSingleTableQueryWithPKs();
- }
// A TOP(skip) copy of the sequence holds the rows to exclude.
- SqlSelect dupsel = (SqlSelect)SqlDuplicator.Copy(select);
- dupsel.Top = skipExp;
- SqlAlias dupAlias = new SqlAlias(dupsel);
- SqlAliasRef dupRef = new SqlAliasRef(dupAlias);
- SqlSelect eqsel = new SqlSelect(dupRef, dupAlias, this.dominatingExpression);
// NOTE(review): the same 'aref' node instance is used here and as the selection of 'final' below.
- eqsel.Where = sql.Binary(SqlNodeType.EQ2V, aref, dupRef);
- SqlSubSelect ss = sql.SubSelect(SqlNodeType.Exists, eqsel);
- SqlSelect final = new SqlSelect(aref, alias, this.dominatingExpression);
- final.Where = sql.Unary(SqlNodeType.Not, ss, this.dominatingExpression);
// takeExp is null when no take was requested.
- final.Top = takeExp;
- return final;
- }
- }
/// <summary>
/// Resolves a reference to a lambda parameter. Three lookups are tried in order:
/// 1. the direct SQL mapping in <c>map</c>;
/// 2. a LINQ expression in <c>exprMap</c>, which is translated on each use;
/// 3. a SQL node in <c>dupMap</c>, which is duplicated (new SqlDuplicator(true)) on each use.
/// </summary>
/// <remarks>Throws Error.ParameterNotInScope when the parameter is in none of the maps.</remarks>
private SqlNode VisitParameter(ParameterExpression p) {
    SqlExpression mapped;
    if (this.map.TryGetValue(p, out mapped)) {
        return mapped;
    }
    Expression boundExpression;
    if (this.exprMap.TryGetValue(p, out boundExpression)) {
        return this.Visit(boundExpression);
    }
    SqlNode template;
    if (!this.dupMap.TryGetValue(p, out template)) {
        throw Error.ParameterNotInScope(p.Name);
    }
    return new SqlDuplicator(true).Duplicate(template);
}
/// <summary>
/// Translates a call to a table-valued function into a select over that function call.
/// The select uses the default projection of the function's first result row type
/// (its inheritance root).
/// </summary>
private SqlNode TranslateTableValuedFunction(MethodCallExpression mce, MetaFunction function) {
    List<SqlExpression> arguments = GetFunctionParameters(mce, function);
    MetaType rowType = function.ResultRowTypes[0].InheritanceRoot;
    SqlTableValuedFunctionCall call = sql.TableValuedFunctionCall(rowType, mce.Method.ReturnType, function.MappedName, arguments, mce);
    SqlAlias callAlias = new SqlAlias(call);
    SqlAliasRef callRef = new SqlAliasRef(callAlias);
    SqlExpression defaultProjection = this.translator.BuildProjection(callRef, rowType, this.allowDeferred, null, mce);
    return new SqlSelect(defaultProjection, callAlias, mce);
}
/// <summary>
/// Translates a call to a stored procedure. This is only allowed as the outermost node of a
/// query (otherwise Error.SprocsCannotBeComposed).
/// - A generic IEnumerable or ISingleResult return (a single rowset) gets a projection over a
///   user row of the function's root row type.
/// - IMultipleResults, int and int? returns get no projection.
/// - Any other return type is rejected (Error.InvalidReturnFromSproc).
/// </summary>
private SqlNode TranslateStoredProcedureCall(MethodCallExpression mce, MetaFunction function) {
    if (!this.outerNode) {
        throw Error.SprocsCannotBeComposed();
    }
    List<SqlExpression> arguments = GetFunctionParameters(mce, function);
    SqlStoredProcedureCall call = new SqlStoredProcedureCall(function, null, arguments, mce);
    Type returnType = mce.Method.ReturnType;
    bool singleRowset = false;
    if (returnType.IsGenericType) {
        Type definition = returnType.GetGenericTypeDefinition();
        singleRowset = definition == typeof(IEnumerable<>) || definition == typeof(ISingleResult<>);
    }
    if (singleRowset) {
        // A single rowset: use the one and only root row type.
        MetaType rowType = function.ResultRowTypes[0].InheritanceRoot;
        SqlUserRow row = new SqlUserRow(rowType, this.typeProvider.GetApplicationType((int)ConverterSpecialTypes.Row), call, mce);
        call.Projection = this.translator.BuildProjection(row, rowType, this.allowDeferred, null, mce);
        return call;
    }
    bool acceptedOtherReturn = typeof(IMultipleResults).IsAssignableFrom(returnType)
        || returnType == typeof(int)
        || returnType == typeof(int?);
    if (!acceptedOtherReturn) {
        throw Error.InvalidReturnFromSproc(returnType);
    }
    return call;
}
/// <summary>
/// Translates each argument of a mapped function call into a SQL expression. Suppose the
/// matching mapped parameter declares an explicit DbType and the translated argument is a
/// SqlSimpleTypeExpression. In that case the DbType is parsed by the type provider and applied
/// as the argument's SQL type.
/// </summary>
/// <returns>One translated expression per call argument, in argument order.</returns>
private List<SqlExpression> GetFunctionParameters(MethodCallExpression mce, MetaFunction function) {
    int argumentCount = mce.Arguments.Count;
    List<SqlExpression> translated = new List<SqlExpression>(argumentCount);
    for (int index = 0; index < argumentCount; index++) {
        SqlExpression argument = this.VisitExpression(mce.Arguments[index]);
        string declaredDbType = function.Parameters[index].DbType;
        SqlSimpleTypeExpression typedArgument = argument as SqlSimpleTypeExpression;
        if (!string.IsNullOrEmpty(declaredDbType) && typedArgument != null) {
            typedArgument.SetSqlType(typeProvider.Parse(declaredDbType));
        }
        translated.Add(argument);
    }
    return translated;
}
/// <summary>
/// Translates a user-supplied query text and its arguments into a SqlUserQuery. If the result
/// type is not void, a projection is attached:
/// - simple element types get a single unnamed column;
/// - other element types get the default projection over a user row of the element type's
///   inheritance root.
/// </summary>
private SqlUserQuery VisitUserQuery(string query, Expression[] arguments, Type resultType) {
    SqlExpression[] translatedArgs = new SqlExpression[arguments.Length];
    for (int index = 0; index < translatedArgs.Length; index++) {
        translatedArgs[index] = this.VisitExpression(arguments[index]);
    }
    SqlUserQuery userQuery = new SqlUserQuery(query, null, translatedArgs, this.dominatingExpression);
    if (resultType == typeof(void)) {
        return userQuery;
    }
    Type elementType = TypeSystem.GetElementType(resultType);
    MetaType elementMetaType = this.services.Model.GetMetaType(elementType);
    if (TypeSystem.IsSimpleType(elementType)) {
        // Simple element types (int, bool, etc.) bind to a single column.
        SqlUserColumn column = new SqlUserColumn(elementType, typeProvider.From(elementType), userQuery, "", false, this.dominatingExpression);
        userQuery.Columns.Add(column);
        userQuery.Projection = column;
    }
    else {
        // Anything else gets the default projection over a row of the element type.
        SqlUserRow row = new SqlUserRow(elementMetaType.InheritanceRoot, this.typeProvider.GetApplicationType((int)ConverterSpecialTypes.Row), userQuery, this.dominatingExpression);
        userQuery.Projection = this.translator.BuildProjection(row, elementMetaType, this.allowDeferred, null, this.dominatingExpression);
    }
    return userQuery;
}
/// <summary>
/// Translates a unary LINQ expression.
/// - An operator with an implementing method becomes a call to that method.
/// - Negate and NegateChecked map to SQL negation.
/// - Not maps to logical NOT for bool/bool? operands and to bitwise NOT otherwise.
/// - TypeAs maps to Treat.
/// </summary>
/// <remarks>
/// Any other node type is rejected with Error.UnrecognizedExpressionNode. Previously such nodes
/// silently produced a null translation.
/// </remarks>
private SqlNode VisitUnary(UnaryExpression u) {
    SqlExpression exp = this.VisitExpression(u.Operand);
    if (u.Method != null) {
        return sql.MethodCall(u.Type, u.Method, null, new SqlExpression[] { exp }, dominatingExpression);
    }
    switch (u.NodeType) {
        case ExpressionType.Negate:
        case ExpressionType.NegateChecked:
            return sql.Unary(SqlNodeType.Negate, exp, this.dominatingExpression);
        case ExpressionType.Not:
            if (u.Operand.Type == typeof(bool) || u.Operand.Type == typeof(bool?)) {
                return sql.Unary(SqlNodeType.Not, exp, this.dominatingExpression);
            }
            return sql.Unary(SqlNodeType.BitNot, exp, this.dominatingExpression);
        case ExpressionType.TypeAs:
            return sql.Unary(SqlNodeType.Treat, exp, this.dominatingExpression);
        default:
            throw Error.UnrecognizedExpressionNode(u.NodeType);
    }
}
/// <summary>
/// Translates a binary LINQ expression. An operator with an implementing method becomes a call
/// to that method. Otherwise the node type maps to the corresponding SQL operator. And/Or are
/// logical for bool/bool? left operands and bitwise otherwise. Coalesce is handled by
/// MakeCoalesce. Unknown operators raise Error.BinaryOperatorNotRecognized.
/// </summary>
[SuppressMessage("Microsoft.Maintainability", "CA1502:AvoidExcessiveComplexity", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
private SqlNode VisitBinary(BinaryExpression b) {
    SqlExpression left = this.VisitExpression(b.Left);
    SqlExpression right = this.VisitExpression(b.Right);
    if (b.Method != null) {
        return sql.MethodCall(b.Type, b.Method, null, new SqlExpression[] { left, right }, dominatingExpression);
    }
    Type resultType = b.Type;
    // And/Or: boolean operands select the logical operator, anything else the bitwise one.
    bool booleanOperands = b.Left.Type == typeof(bool) || b.Left.Type == typeof(bool?);
    switch (b.NodeType) {
        case ExpressionType.Add:
        case ExpressionType.AddChecked:
            return sql.Binary(SqlNodeType.Add, left, right, resultType);
        case ExpressionType.Subtract:
        case ExpressionType.SubtractChecked:
            return sql.Binary(SqlNodeType.Sub, left, right, resultType);
        case ExpressionType.Multiply:
        case ExpressionType.MultiplyChecked:
            return sql.Binary(SqlNodeType.Mul, left, right, resultType);
        case ExpressionType.Divide:
            return sql.Binary(SqlNodeType.Div, left, right, resultType);
        case ExpressionType.Modulo:
            return sql.Binary(SqlNodeType.Mod, left, right, resultType);
        case ExpressionType.And:
            return sql.Binary(booleanOperands ? SqlNodeType.And : SqlNodeType.BitAnd, left, right, resultType);
        case ExpressionType.AndAlso:
            return sql.Binary(SqlNodeType.And, left, right, resultType);
        case ExpressionType.Or:
            return sql.Binary(booleanOperands ? SqlNodeType.Or : SqlNodeType.BitOr, left, right, resultType);
        case ExpressionType.OrElse:
            return sql.Binary(SqlNodeType.Or, left, right, resultType);
        case ExpressionType.LessThan:
            return sql.Binary(SqlNodeType.LT, left, right, resultType);
        case ExpressionType.LessThanOrEqual:
            return sql.Binary(SqlNodeType.LE, left, right, resultType);
        case ExpressionType.GreaterThan:
            return sql.Binary(SqlNodeType.GT, left, right, resultType);
        case ExpressionType.GreaterThanOrEqual:
            return sql.Binary(SqlNodeType.GE, left, right, resultType);
        case ExpressionType.Equal:
            return sql.Binary(SqlNodeType.EQ, left, right, resultType);
        case ExpressionType.NotEqual:
            return sql.Binary(SqlNodeType.NE, left, right, resultType);
        case ExpressionType.ExclusiveOr:
            return sql.Binary(SqlNodeType.BitXor, left, right, resultType);
        case ExpressionType.Coalesce:
            return this.MakeCoalesce(left, right, resultType);
        default:
            throw Error.BinaryOperatorNotRecognized(b.NodeType);
    }
}
/// <summary>
/// Builds the SQL for <c>left ?? right</c>.
/// - Simple result types use a COALESCE node.
/// - Other result types use a searched CASE: WHEN left IS NULL THEN right ELSE (a duplicate of) left.
/// The operands are first passed through CompensateForLowerPrecedenceOfDateType.
/// </summary>
private SqlExpression MakeCoalesce(SqlExpression left, SqlExpression right, Type resultType) {
    CompensateForLowerPrecedenceOfDateType(ref left, ref right); // DevDiv 176874
    if (!TypeSystem.IsSimpleType(resultType)) {
        SqlWhen whenLeftIsNull = new SqlWhen(sql.Unary(SqlNodeType.IsNull, left, left.SourceExpression), right);
        SqlExpression leftCopy = (SqlExpression)new SqlDuplicator(true).Duplicate(left);
        return sql.SearchedCase(new SqlWhen[] { whenLeftIsNull }, leftCopy, this.dominatingExpression);
    }
    return sql.Binary(SqlNodeType.Coalesce, left, right, resultType);
}
// The result *type* of a COALESCE call is that of the operand with the highest type precedence.
// However, the SQL DATE type has a lower precedence than DATETIME or SMALLDATETIME, despite having
// a higher range. To compensate, when one operand is DATE and the other is classified by
// SqlFactory.IsSqlDateTimeType, that other (non-DATE) operand is passed through
// ConvertDateToDateTime2 (presumably widening it to DATETIME2, which outranks DATE).
private void CompensateForLowerPrecedenceOfDateType(ref SqlExpression left, ref SqlExpression right) {
    if (SqlFactory.IsSqlDateType(left) && SqlFactory.IsSqlDateTimeType(right)) {
        right = (SqlExpression)ConvertDateToDateTime2(right);
        return;
    }
    if (SqlFactory.IsSqlDateType(right) && SqlFactory.IsSqlDateTimeType(left)) {
        left = (SqlExpression)ConvertDateToDateTime2(left);
    }
}
/// <summary>
/// Translates <c>Concat</c> as a union that keeps duplicates (SqlUnion with all = true).
/// </summary>
private SqlNode VisitConcat(Expression source1, Expression source2) {
    return this.BuildUnionSelect(source1, source2, true);
}
/// <summary>
/// Translates <c>Union</c> as a union that removes duplicates (SqlUnion with all = false).
/// </summary>
private SqlNode VisitUnion(Expression source1, Expression source2) {
    return this.BuildUnionSelect(source1, source2, false);
}
/// <summary>
/// Shared implementation of Concat and Union. The union of the two translated sequences is
/// aliased, and a select over that alias is returned with its ordering type set to Blocked.
/// </summary>
/// <param name="all">Passed to SqlUnion: true keeps duplicates (Concat), false removes them (Union).</param>
private SqlSelect BuildUnionSelect(Expression source1, Expression source2, bool all) {
    SqlSelect left = this.VisitSequence(source1);
    SqlSelect right = this.VisitSequence(source2);
    SqlUnion union = new SqlUnion(left, right, all);
    SqlAlias alias = new SqlAlias(union);
    SqlAliasRef aref = new SqlAliasRef(alias);
    SqlSelect result = new SqlSelect(aref, alias, this.dominatingExpression);
    result.OrderingType = SqlOrderingType.Blocked;
    return result;
}
/// <summary>
/// Translates <c>Intersect</c>: the distinct elements of the first sequence that compare equal
/// (EQ2V) to some element of the second sequence.
/// </summary>
/// <remarks>Sequences of groupings are rejected (Error.IntersectNotSupportedForHierarchicalTypes).</remarks>
private SqlNode VisitIntersect(Expression source1, Expression source2) {
    if (IsGrouping(TypeSystem.GetElementType(source1.Type))) {
        throw Error.IntersectNotSupportedForHierarchicalTypes();
    }
    return this.BuildSetMembershipSelect(source1, source2, false);
}
/// <summary>
/// Translates <c>Except</c>: the distinct elements of the first sequence that compare equal
/// (EQ2V) to no element of the second sequence.
/// </summary>
/// <remarks>Sequences of groupings are rejected (Error.ExceptNotSupportedForHierarchicalTypes).</remarks>
private SqlNode VisitExcept(Expression source1, Expression source2) {
    if (IsGrouping(TypeSystem.GetElementType(source1.Type))) {
        throw Error.ExceptNotSupportedForHierarchicalTypes();
    }
    return this.BuildSetMembershipSelect(source1, source2, true);
}
/// <summary>
/// Shared implementation of Intersect and Except.
/// 1. The first sequence is locked.
/// 2. GenerateQuantifier(..., true) builds a quantifier over the second sequence, testing EQ2V
///    against the current element of the first.
/// 3. The DISTINCT elements of the first sequence are selected where that quantifier holds,
///    or does not hold when <paramref name="negate"/> is true.
/// The result's ordering type is set to Blocked.
/// </summary>
private SqlSelect BuildSetMembershipSelect(Expression source1, Expression source2, bool negate) {
    SqlSelect select1 = this.LockSelect(this.VisitSequence(source1));
    SqlSelect select2 = this.VisitSequence(source2);
    SqlAlias alias1 = new SqlAlias(select1);
    SqlAliasRef aref1 = new SqlAliasRef(alias1);
    SqlAlias alias2 = new SqlAlias(select2);
    SqlAliasRef aref2 = new SqlAliasRef(alias2);
    SqlExpression any = this.GenerateQuantifier(alias2, sql.Binary(SqlNodeType.EQ2V, aref1, aref2), true);
    SqlSelect result = new SqlSelect(aref1, alias1, select1.SourceExpression);
    if (negate) {
        result.Where = sql.Unary(SqlNodeType.Not, any);
    }
    else {
        result.Where = any;
    }
    result.IsDistinct = true;
    result.OrderingType = SqlOrderingType.Blocked;
    return result;
}
/// <summary>
/// Returns true when the type is generic and its generic type definition is IGrouping{TKey, TElement}.
/// </summary>
[SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
private bool IsGrouping(Type t) {
    return t.IsGenericType && t.GetGenericTypeDefinition() == typeof(IGrouping<,>);
}
/// <summary>
/// Translates an OrderBy-style operator.
/// - Grouping keys are rejected (Error.GroupingNotSupportedAsOrderCriterion).
/// - Key types the provider cannot order are rejected (Error.TypeCannotBeOrdered).
/// The source is locked. It is wrapped in a new select when it does not select an alias
/// reference or already has an ORDER BY. The key lambda is then translated against that
/// select's alias reference and appended to its ORDER BY.
/// </summary>
private SqlSelect VisitOrderBy(Expression sequence, LambdaExpression expression, SqlOrderType orderType) {
    Type keyType = expression.Body.Type;
    if (IsGrouping(keyType)) {
        throw Error.GroupingNotSupportedAsOrderCriterion();
    }
    if (!this.typeProvider.From(keyType).IsOrderable) {
        throw Error.TypeCannotBeOrdered(keyType);
    }
    SqlSelect target = this.LockSelect(this.VisitSequence(sequence));
    bool needsWrapper = target.Selection.NodeType != SqlNodeType.AliasRef || target.OrderBy.Count > 0;
    if (needsWrapper) {
        SqlAlias wrapperAlias = new SqlAlias(target);
        target = new SqlSelect(new SqlAliasRef(wrapperAlias), wrapperAlias, this.dominatingExpression);
    }
    this.map[expression.Parameters[0]] = (SqlAliasRef)target.Selection;
    SqlExpression key = this.VisitExpression(expression.Body);
    target.OrderBy.Add(new SqlOrderExpression(orderType, key));
    return target;
}
/// <summary>
/// Translates a ThenBy-style operator by appending another key to the ORDER BY of the
/// translated source. The source's selection is expected to be an alias reference (asserted
/// in debug builds). Key validation is the same as for OrderBy.
/// </summary>
private SqlSelect VisitThenBy(Expression sequence, LambdaExpression expression, SqlOrderType orderType) {
    Type keyType = expression.Body.Type;
    if (IsGrouping(keyType)) {
        throw Error.GroupingNotSupportedAsOrderCriterion();
    }
    if (!this.typeProvider.From(keyType).IsOrderable) {
        throw Error.TypeCannotBeOrdered(keyType);
    }
    SqlSelect target = this.VisitSequence(sequence);
    System.Diagnostics.Debug.Assert(target.Selection.NodeType == SqlNodeType.AliasRef);
    this.map[expression.Parameters[0]] = (SqlAliasRef)target.Selection;
    SqlExpression key = this.VisitExpression(expression.Body);
    target.OrderBy.Add(new SqlOrderExpression(orderType, key));
    return target;
}
/// <summary>
/// Translates GroupBy. The locked source is grouped by the key selector. The key is wrapped in a
/// SqlSharedExpression so one key definition can be referenced at several points. Each group is a
/// SqlGrouping that pairs the key with a MULTISET sub-select: the elements of a duplicate of the
/// source whose key matches (EQ2V).
/// Without a result selector, the result selects those groupings. With one, the selector body is
/// translated with its first parameter bound to the grouping's Key member and its second to the
/// grouping itself.
/// </summary>
/// <param name="sequence">Source sequence.</param>
/// <param name="keyLambda">Key selector; translated once against the source and once against its duplicate.</param>
/// <param name="elemLambda">Optional element selector; when null, the source rows themselves are the elements.</param>
/// <param name="resultSelector">Optional (key, group) result selector; may be null.</param>
/// <returns>
/// The final select. GroupInfo entries are added to gmap, recording the GROUP BY select and the
/// element expression over the original source. VisitAggregate consults them (via FindGroupInfo)
/// to push aggregates into the GROUP BY select.
/// </returns>
- private SqlNode VisitGroupBy(Expression sequence, LambdaExpression keyLambda, LambdaExpression elemLambda, LambdaExpression resultSelector) {
- // Convert seq.Group(elem, key) into
- //
- // SELECT s.key, MULTISET(select s2.elem from seq AS s2 where s.key == s2.key)
- // FROM seq AS s
- //
- // where key and elem can be either simple scalars or object constructions
- //
- SqlSelect seq = this.VisitSequence(sequence);
- seq = this.LockSelect(seq);
- SqlAlias seqAlias = new SqlAlias(seq);
- SqlAliasRef seqAliasRef = new SqlAliasRef(seqAlias);
- // evaluate the key expression relative to original sequence
- this.map[keyLambda.Parameters[0]] = seqAliasRef;
- SqlExpression keyExpr = this.VisitExpression(keyLambda.Body);
- // make a duplicate of the original sequence to use as a foundation of our group multiset
- SqlDuplicator sd = new SqlDuplicator();
- SqlSelect selDup = (SqlSelect)sd.Duplicate(seq);
- // rebind key in relative to the duplicate sequence
- SqlAlias selDupAlias = new SqlAlias(selDup);
- SqlAliasRef selDupRef = new SqlAliasRef(selDupAlias);
// Note: this overwrites the key parameter's binding from above; keyExpr was already translated.
- this.map[keyLambda.Parameters[0]] = selDupRef;
- SqlExpression keyDup = this.VisitExpression(keyLambda.Body);
- SqlExpression elemExpr = null;
- SqlExpression elemOnGroupSource = null;
- if (elemLambda != null) {
- // evaluate element expression relative to the duplicate sequence
- this.map[elemLambda.Parameters[0]] = selDupRef;
- elemExpr = this.VisitExpression(elemLambda.Body);
- // evaluate element expression relative to original sequence
- this.map[elemLambda.Parameters[0]] = seqAliasRef;
- elemOnGroupSource = this.VisitExpression(elemLambda.Body);
- }
- else {
- // no elem expression supplied, so just use an alias ref to the duplicate sequence.
- // this will resolve to whatever was being produced by the sequence
- elemExpr = selDupRef;
- elemOnGroupSource = seqAliasRef;
- }
- // Make a sub expression out of the key. This will allow a single definition of the
- // expression to be shared at multiple points in the tree (via SqlSharedExpressionRef's)
- SqlSharedExpression keySubExpr = new SqlSharedExpression(keyExpr);
- keyExpr = new SqlSharedExpressionRef(keySubExpr);
- // construct the select clause that picks out the elements (this may be redundant...)
- SqlSelect selElem = new SqlSelect(elemExpr, selDupAlias, this.dominatingExpression);
// Correlate: keep duplicate rows whose key (keyDup) equals the group's shared key.
- selElem.Where = sql.Binary(SqlNodeType.EQ2V, keyExpr, keyDup);
- // Finally, make the MULTISET node. this will be used as part of the final select
- SqlSubSelect ss = sql.SubSelect(SqlNodeType.Multiset, selElem);
- // add a layer to the original sequence before applying the actual group-by clause
- SqlSelect gsel = new SqlSelect(new SqlSharedExpressionRef(keySubExpr), seqAlias, this.dominatingExpression);
- gsel.GroupBy.Add(keySubExpr);
- SqlAlias gselAlias = new SqlAlias(gsel);
- SqlSelect result = null;
- if (resultSelector != null) {
- // Create final select to include construction of group multiset
- // select new Grouping { Key = key, Group = Multiset(select elem from seq where match) } from ...
- Type elementType = typeof(IGrouping<,>).MakeGenericType(keyExpr.ClrType, elemExpr.ClrType);
- SqlExpression keyGroup = new SqlGrouping(elementType, this.typeProvider.From(elementType), keyExpr, ss, this.dominatingExpression);
- SqlSelect keyGroupSel = new SqlSelect(keyGroup, gselAlias, this.dominatingExpression);
- SqlAlias kgAlias = new SqlAlias(keyGroupSel);
- SqlAliasRef kgAliasRef = new SqlAliasRef(kgAlias);
// Bind the selector's parameters: the first to grouping.Key, the second to the grouping itself.
- this.map[resultSelector.Parameters[0]] = sql.Member(kgAliasRef, elementType.GetProperty("Key"));
- this.map[resultSelector.Parameters[1]] = kgAliasRef;
- // remember the select that has the actual group (for optimizing aggregates later)
// (registered under the grouping alias ref so aggregates inside the selector body can find it)
- this.gmap[kgAliasRef] = new GroupInfo { SelectWithGroup = gsel, ElementOnGroupSource = elemOnGroupSource };
- SqlExpression resultExpr = this.VisitExpression(resultSelector.Body);
- result = new SqlSelect(resultExpr, kgAlias, this.dominatingExpression);
- // remember the select that has the actual group (for optimizing aggregates later)
// (also registered under the translated selector body)
- this.gmap[resultExpr] = new GroupInfo { SelectWithGroup = gsel, ElementOnGroupSource = elemOnGroupSource };
- }
- else {
- // Create final select to include construction of group multiset
- // select new Grouping { Key = key, Group = Multiset(select elem from seq where match) } from ...
- Type elementType = typeof(IGrouping<,>).MakeGenericType(keyExpr.ClrType, elemExpr.ClrType);
- SqlExpression resultExpr = new SqlGrouping(elementType, this.typeProvider.From(elementType), keyExpr, ss, this.dominatingExpression);
- result = new SqlSelect(resultExpr, gselAlias, this.dominatingExpression);
- // remember the select that has the actual group (for optimizing aggregates later)
- this.gmap[resultExpr] = new GroupInfo { SelectWithGroup = gsel, ElementOnGroupSource = elemOnGroupSource };
- }
- return result;
- }
/// <summary>
/// Translates an aggregate operator (Count/LongCount or another aggregate node type) over a
/// sequence into one of three forms:
///   1) outermost operator: SELECT Agg(exp) FROM seq. For counts, a selector is a predicate and
///      becomes the WHERE clause.
///   2) not outermost, not a count with a predicate, and the source is a group recorded by
///      VisitGroupBy: the aggregate is added as a column of the GROUP BY select, and a column
///      reference to it is returned.
///   3) otherwise: SCALAR(SELECT Agg(exp) FROM seq).
/// When not outermost and not a count, a source of the form x.Select(e) with no selector (or an
/// identity selector) is first rewritten so that e becomes the aggregate's selector.
/// </summary>
/// <param name="sequence">The aggregated sequence.</param>
/// <param name="lambda">Optional selector; for Count/LongCount it is used as a predicate.</param>
/// <param name="aggType">Aggregate node type (e.g. Count, LongCount).</param>
/// <param name="returnType">CLR result type of the aggregate, passed to GetAggregate.</param>
/// <remarks>
/// Throws Error.CannotAggregateType when the selector body type is not simple. Throws
/// Error.NonCountAggregateFunctionsAreNotValidOnProjections for a selector-less non-count
/// aggregate over a runtime-only (projection) selection.
/// </remarks>
- [SuppressMessage("Microsoft.Maintainability", "CA1502:AvoidExcessiveComplexity", Justification="These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
- private SqlNode VisitAggregate(Expression sequence, LambdaExpression lambda, SqlNodeType aggType, Type returnType) {
- // Convert seq.Agg(exp) into
- //
- // 1) SELECT Agg(exp) FROM seq
- // 2) SELECT Agg1 FROM (SELECT Agg(exp) as Agg1 FROM group-seq GROUP BY ...)
- // 3) SCALAR(SELECT Agg(exp) FROM seq)
- //
- bool isCount = aggType == SqlNodeType.Count || aggType == SqlNodeType.LongCount;
- SqlNode source = this.Visit(sequence);
- SqlSelect select = this.CoerceToSequence(source);
- SqlAlias alias = new SqlAlias(select);
- SqlAliasRef aref = new SqlAliasRef(alias);
- // If the sequence is of the form x.Select(expr).Agg() and the lambda for the aggregate is null,
- // or is a no-op parameter expression (like u=>u), clone the group by selection lambda
- // expression, and use for the aggregate.
- // Final form should be x.Agg(expr)
- MethodCallExpression mce = sequence as MethodCallExpression;
- if (!outerNode && !isCount && (lambda == null || (lambda.Parameters.Count == 1 && lambda.Parameters[0] == lambda.Body)) &&
- (mce != null) && IsSequenceOperatorCall(mce, "Select") && select.From is SqlAlias) {
- LambdaExpression selectionLambda = GetLambda(mce.Arguments[1]);
- lambda = Expression.Lambda(selectionLambda.Type, selectionLambda.Body, selectionLambda.Parameters);
// Aggregate directly over the Select's own source alias instead of the Select's output.
- alias = (SqlAlias)select.From;
- aref = new SqlAliasRef(alias);
- }
- if (lambda != null && !TypeSystem.IsSimpleType(lambda.Body.Type)) {
- throw Error.CannotAggregateType(lambda.Body.Type);
- }
- //Empty parameter aggregates are not allowed on anonymous types
- //i.e. db.Customers.Select(c=>new{c.Age}).Max() instead it should be
- // db.Customers.Select(c=>new{c.Age}).Max(c=>c.Age)
- if (select.Selection.SqlType.IsRuntimeOnlyType && !IsGrouping(sequence.Type) && !isCount && lambda == null) {
- throw Error.NonCountAggregateFunctionsAreNotValidOnProjections(aggType);
- }
// The selector's parameter ranges over the (possibly rewritten) source alias.
- if (lambda != null)
- this.map[lambda.Parameters[0]] = aref;
- if (this.outerNode) {
- // If this aggregate is basically the last/outer-most operator of the query
- //
- // produce SELECT Agg(exp) FROM seq
- //
- SqlExpression exp = (lambda != null) ? this.VisitExpression(lambda.Body) : null;
- SqlExpression where = null;
// Count/LongCount with a lambda: the lambda is a predicate, so it moves into WHERE.
- if (isCount && exp != null) {
- where = exp;
- exp = null;
- }
// Non-count aggregate without a selector: aggregate the source rows themselves.
- else if (exp == null && !isCount) {
- exp = aref;
- }
- if (exp != null) {
- // in case this contains another aggregate
- exp = new SqlSimpleExpression(exp);
- }
- SqlSelect sel = new SqlSelect(
- this.GetAggregate(aggType, returnType, exp),
- alias,
- this.dominatingExpression
- );
- sel.Where = where;
- sel.OrderingType = SqlOrderingType.Never;
- return sel;
- }
- else if (!isCount || lambda == null) {
- // Look to optimize aggregate by pushing its evaluation down to the select node that has the
- // actual group-by operator.
- //
- // Produce: SELECT Agg1 FROM (SELECT Agg(exp) as Agg1 FROM seq GROUP BY ...)
- //
- GroupInfo info = this.FindGroupInfo(source);
- if (info != null) {
- SqlExpression exp = null;
- if (lambda != null) {
- // evaluate expression relative to the group-by select node
// (the parameter is rebound to a copy of the element expression over the grouped source)
- this.map[lambda.Parameters[0]] = (SqlExpression)SqlDuplicator.Copy(info.ElementOnGroupSource);
- exp = this.VisitExpression(lambda.Body);
- } else if (!isCount) {
- // support aggregates w/o an explicit selector specified
- exp = info.ElementOnGroupSource;
- }
- if (exp != null) {
- // in case this contains another aggregate
- exp = new SqlSimpleExpression(exp);
- }
- SqlExpression agg = this.GetAggregate(aggType, returnType, exp);
- SqlColumn c = new SqlColumn(agg.ClrType, agg.SqlType, null, null, agg, this.dominatingExpression);
// Add the aggregate as an extra column of the GROUP BY select and refer to that column.
- info.SelectWithGroup.Row.Columns.Add(c);
- return new SqlColumnRef(c);
- }
- }
- // Otherwise, if we cannot optimize then fall back to generating a nested aggregate in a correlated sub query
- //
- // SCALAR(SELECT Agg(exp) FROM seq)
- {
- SqlExpression exp = (lambda != null) ? this.VisitExpression(lambda.Body) : null;
- if (exp != null) {
- // in case this contains another aggregate
- exp = new SqlSimpleExpression(exp);
- }
// Counts take no argument (their lambda, if any, becomes the WHERE below); a non-count
// aggregate without a selector aggregates the source rows (aref).
- SqlSelect sel = new SqlSelect(
- this.GetAggregate(aggType, returnType, isCount ? null : (lambda == null) ? aref : exp),
- alias,
- this.dominatingExpression
- );
// NOTE(review): unlike the outer-node path, the count predicate here is still wrapped in
// SqlSimpleExpression; confirm this difference is intended.
- sel.Where = isCount ? exp : null;
- return sql.SubSelect(SqlNodeType.ScalarSubSelect, sel);
- }
- }
/// <summary>
/// Looks up the GroupInfo recorded (by VisitGroupBy) for a node. Besides a direct gmap hit, the
/// lookup follows:
/// - aliases, into the aliased select's selection, or into the aliased node itself when it is
///   not a select;
/// - alias references;
/// - member accesses.
/// Returns null when nothing is found.
/// </summary>
private GroupInfo FindGroupInfo(SqlNode source) {
    GroupInfo found;
    if (this.gmap.TryGetValue(source, out found) && found != null) {
        return found;
    }
    SqlNode candidate = source;
    SqlAlias alias = candidate as SqlAlias;
    if (alias != null) {
        SqlSelect aliasedSelect = alias.Node as SqlSelect;
        if (aliasedSelect != null) {
            return this.FindGroupInfo(aliasedSelect.Selection);
        }
        // The aliased node might be an expression that is not yet fully resolved.
        candidate = alias.Node;
    }
    SqlExpression expr = candidate as SqlExpression;
    if (expr == null) {
        return null;
    }
    if (expr.NodeType == SqlNodeType.AliasRef) {
        return this.FindGroupInfo(((SqlAliasRef)expr).Alias);
    }
    if (expr.NodeType == SqlNodeType.Member) {
        return this.FindGroupInfo(((SqlMember)expr).Expression);
    }
    this.gmap.TryGetValue(expr, out found);
    return found;
}
/// <summary>
/// Builds an aggregate node of kind <paramref name="aggType"/> over <paramref name="exp"/>
/// whose provider type is derived from the requested CLR result type.
/// </summary>
private SqlExpression GetAggregate(SqlNodeType aggType, Type clrType, SqlExpression exp) {
    return new SqlUnary(aggType, clrType, this.typeProvider.From(clrType), exp, this.dominatingExpression);
}
/// <summary>
/// Translates <c>sequence.Contains(value)</c>.
/// Client arrays and captured local collections become an IN list.
/// Captured queryables, and anything else, become an EXISTS quantifier over <c>p == value</c>.
/// </summary>
private SqlNode VisitContains(Expression sequence, Expression value) {
    Type elemType = TypeSystem.GetElementType(sequence.Type);
    SqlNode seqNode = this.Visit(sequence);
    if (seqNode.NodeType == SqlNodeType.ClientArray) {
        SqlClientArray clientArray = (SqlClientArray)seqNode;
        return this.GenerateInExpression(this.VisitExpression(value), clientArray.Expressions);
    }
    if (seqNode.NodeType == SqlNodeType.Value) {
        IEnumerable values = ((SqlValue)seqNode).Value as IEnumerable;
        IQueryable query = values as IQueryable;
        if (query != null) {
            // a captured queryable is translated as a nested query
            seqNode = this.Visit(query.Expression);
        }
        else {
            SqlExpression probe = this.VisitExpression(value);
            List<SqlExpression> inList = new List<SqlExpression>();
            // OfType<object> skips null elements
            foreach (object element in values.OfType<object>()) {
                inList.Add(sql.ValueFromObject(element, elemType, true, this.dominatingExpression));
            }
            return this.GenerateInExpression(probe, inList);
        }
    }
    ParameterExpression p = Expression.Parameter(value.Type, "p");
    LambdaExpression lambda = Expression.Lambda(Expression.Equal(p, value), p);
    return this.VisitQuantifier(this.CoerceToSequence(seqNode), lambda, true);
}
/// <summary>
/// Builds <c>expr IN (list)</c>.
/// An empty list yields the constant <c>false</c>. When the list items cannot be columns,
/// the result is the equivalent chain <c>expr = v0 OR expr = v1 ...</c>.
/// </summary>
private SqlExpression GenerateInExpression(SqlExpression expr, List<SqlExpression> list) {
    if (list.Count == 0) {
        return sql.ValueFromObject(false, this.dominatingExpression);
    }
    SqlExpression first = list[0];
    if (first.SqlType.CanBeColumn) {
        return sql.In(expr, list, this.dominatingExpression);
    }
    SqlExpression disjunction = sql.Binary(SqlNodeType.EQ, expr, first);
    int count = list.Count;
    for (int index = 1; index < count; index++) {
        // each later comparison needs its own copy of the probe expression
        SqlExpression probeCopy = (SqlExpression)SqlDuplicator.Copy(expr);
        SqlExpression equality = sql.Binary(SqlNodeType.EQ, probeCopy, list[index]);
        disjunction = sql.Binary(SqlNodeType.Or, disjunction, equality);
    }
    return disjunction;
}
/// <summary>Translates <paramref name="sequence"/> and then builds an Any/All quantifier over it.</summary>
private SqlNode VisitQuantifier(Expression sequence, LambdaExpression lambda, bool isAny) {
    SqlSelect translated = this.VisitSequence(sequence);
    return this.VisitQuantifier(translated, lambda, isAny);
}
/// <summary>
/// Builds an Any/All quantifier over <paramref name="select"/>.
/// When <paramref name="lambda"/> is present, its parameter is bound to the aliased rows
/// and its body becomes the condition.
/// </summary>
private SqlNode VisitQuantifier(SqlSelect select, LambdaExpression lambda, bool isAny) {
    SqlAlias alias = new SqlAlias(select);
    SqlExpression condition = null;
    if (lambda != null) {
        this.map[lambda.Parameters[0]] = new SqlAliasRef(alias);
        condition = this.VisitExpression(lambda.Body);
    }
    return this.GenerateQuantifier(alias, condition, isAny);
}
/// <summary>
/// Produces <c>EXISTS(rows WHERE cond)</c> for Any.
/// Produces <c>NOT EXISTS(rows WHERE NOT cond)</c> for All.
/// </summary>
private SqlExpression GenerateQuantifier(SqlAlias alias, SqlExpression cond, bool isAny) {
    SqlSelect probe = new SqlSelect(new SqlAliasRef(alias), alias, this.dominatingExpression);
    if (!isAny) {
        SqlSubSelect exists = sql.SubSelect(SqlNodeType.Exists, probe);
        probe.Where = sql.Unary(SqlNodeType.Not2V, cond, this.dominatingExpression);
        return sql.Unary(SqlNodeType.Not, exists, this.dominatingExpression);
    }
    probe.Where = cond;
    probe.OrderingType = SqlOrderingType.Never;
    return sql.SubSelect(SqlNodeType.Exists, probe);
}
/// <summary>
/// Throws <c>WrongDataContext</c> when <paramref name="expr"/> is a constant holding a
/// DataContext other than the one this converter serves.
/// </summary>
private void CheckContext(SqlExpression expr) {
    SqlValue constant = expr as SqlValue;
    if (constant == null) {
        return;
    }
    DataContext referenced = constant.Value as DataContext;
    if (referenced != null && referenced != this.services.Context) {
        throw Error.WrongDataContext();
    }
}
/// <summary>
/// Translates a member access.
/// <list type="bullet">
/// <item><c>Table&lt;T&gt;</c> members become table references of the current context.</item>
/// <item><c>Count</c> on a sequence becomes a COUNT aggregate.</item>
/// <item>Everything else becomes a SQL member node.</item>
/// </list>
/// </summary>
private SqlNode VisitMemberAccess(MemberExpression ma) {
    Type memberType = TypeSystem.GetMemberType(ma.Member);
    if (memberType.IsGenericType && memberType.GetGenericTypeDefinition() == typeof(Table<>)) {
        Type rowType = memberType.GetGenericArguments()[0];
        CheckContext(this.VisitExpression(ma.Expression));
        ITable table = this.services.Context.GetTable(rowType);
        if (table != null) {
            return this.Visit(Expression.Constant(table));
        }
    }
    // ma.Expression is null for static members, so guard it before inspecting its type
    if (ma.Member.Name == "Count" && ma.Expression != null && TypeSystem.IsSequenceType(ma.Expression.Type)) {
        return this.VisitAggregate(ma.Expression, null, SqlNodeType.Count, typeof(int));
    }
    return sql.Member(VisitExpression(ma.Expression), ma.Member);
}
/// <summary>
/// Translates a method call. The handling depends on the kind of method:
/// <list type="bullet">
/// <item>Static calls: sequence operators, DML operations and Convert.ChangeType get dedicated translations.</item>
/// <item>DataContext instance calls: GetTable, ExecuteCommand/ExecuteQuery and mapped functions get dedicated translations.</item>
/// <item>List instance Contains: handled like Enumerable.Contains.</item>
/// <item>Anything else: becomes a generic SQL method-call node.</item>
/// </list>
/// </summary>
[SuppressMessage("Microsoft.Maintainability", "CA1502:AvoidExcessiveComplexity", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
private SqlNode VisitMethodCall(MethodCallExpression mc) {
    MethodInfo method = mc.Method;
    Type declType = method.DeclaringType;
    if (method.IsStatic) {
        if (this.IsSequenceOperatorCall(mc)) {
            return this.VisitSequenceOperatorCall(mc);
        }
        if (IsDataManipulationCall(mc)) {
            return this.VisitDataManipulationCall(mc);
        }
        // why is this handled here and not in SqlMethodCallConverter?
        bool isConversionClass = declType == typeof(DBConvert) || declType == typeof(Convert);
        if (isConversionClass && method.Name == "ChangeType") {
            SqlNode converted = null;
            if (mc.Arguments.Count == 2) {
                Type targetType = GetValue(mc.Arguments[1], "ChangeType") as Type;
                if (targetType != null) {
                    converted = this.VisitChangeType(mc.Arguments[0], targetType);
                }
            }
            if (converted == null) {
                throw Error.MethodFormHasNoSupportConversionToSql(mc.Method.Name, mc.Method);
            }
            return converted;
        }
    }
    else if (typeof(DataContext).IsAssignableFrom(declType)) {
        string methodName = method.Name;
        if (methodName == "ExecuteCommand" || methodName == "ExecuteQuery") {
            return this.VisitUserQuery((string)GetValue(mc.Arguments[0], methodName), GetArray(mc.Arguments[1]), mc.Type);
        }
        if (methodName == "GetTable" && method.IsGenericMethod) {
            // calls to GetTable<T>() can be translated directly as table references
            Type[] typeArgs = method.GetGenericArguments();
            if (typeArgs.Length == 1 && method.GetParameters().Length == 0) {
                CheckContext(this.VisitExpression(mc.Object));
                ITable table = this.services.Context.GetTable(typeArgs[0]);
                if (table != null) {
                    return this.Visit(Expression.Constant(table));
                }
            }
        }
        if (this.IsMappedFunctionCall(mc)) {
            return this.VisitMappedFunctionCall(mc);
        }
    }
    else if (declType != typeof(string)
        && method.Name == "Contains"
        && typeof(IList).IsAssignableFrom(declType)
        && mc.Type == typeof(bool)
        && mc.Arguments.Count == 1
        && TypeSystem.GetElementType(declType).IsAssignableFrom(mc.Arguments[0].Type)) {
        // instance Contains on a list (e.g. List<T>.Contains) is treated like Enumerable.Contains
        return this.VisitContains(mc.Object, mc.Arguments[0]);
    }
    // default: create a sql method call node
    SqlExpression target = VisitExpression(mc.Object);
    SqlExpression[] arguments = new SqlExpression[mc.Arguments.Count];
    for (int index = 0; index < arguments.Length; index++) {
        arguments[index] = VisitExpression(mc.Arguments[index]);
    }
    return sql.MethodCall(mc.Method, target, arguments, dominatingExpression);
}
/// <summary>
/// Translates <paramref name="expression"/> and returns its constant value.
/// Throws <c>NonConstantExpressionsNotSupportedFor(operation)</c> when the translation is not a value node.
/// </summary>
private object GetValue(Expression expression, string operation) {
    SqlExpression translated = this.VisitExpression(expression);
    if (translated.NodeType != SqlNodeType.Value) {
        throw Error.NonConstantExpressionsNotSupportedFor(operation);
    }
    return ((SqlValue)translated).Value;
}
/// <summary>
/// Returns the element expressions of an array argument.
/// For a <c>new[] { ... }</c> expression these are its initializers.
/// For a constant <c>object[]</c> each element is wrapped in a constant of the array's element type.
/// Any other input gives an empty array.
/// </summary>
private static Expression[] GetArray(Expression array) {
    NewArrayExpression newArray = array as NewArrayExpression;
    if (newArray != null) {
        return newArray.Expressions.ToArray();
    }
    ConstantExpression constant = array as ConstantExpression;
    object[] items = (constant != null) ? constant.Value as object[] : null;
    if (items == null) {
        return new Expression[] { };
    }
    Type elemType = TypeSystem.GetElementType(constant.Type);
    ConstantExpression[] wrapped = new ConstantExpression[items.Length];
    for (int index = 0; index < items.Length; index++) {
        wrapped[index] = Expression.Constant(items[index], elemType);
    }
    return wrapped;
}
/// <summary>Strips any <c>Quote</c> wrappers and returns the innermost expression.</summary>
[SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
private Expression RemoveQuotes(Expression expression) {
    if (expression.NodeType != ExpressionType.Quote) {
        return expression;
    }
    return RemoveQuotes(((UnaryExpression)expression).Operand);
}
/// <summary>True when <paramref name="expression"/>, once unquoted, is a lambda.</summary>
private bool IsLambda(Expression expression) {
    Expression unquoted = this.RemoveQuotes(expression);
    return unquoted.NodeType == ExpressionType.Lambda;
}
/// <summary>Returns the unquoted lambda, or null when <paramref name="expression"/> is not one.</summary>
private LambdaExpression GetLambda(Expression expression) {
    Expression unquoted = this.RemoveQuotes(expression);
    return unquoted as LambdaExpression;
}
/// <summary>True when the mapping model has a function mapped to the called method.</summary>
private bool IsMappedFunctionCall(MethodCallExpression mc) {
    return services.Model.GetFunction(mc.Method) != null;
}
/// <summary>
/// Translates a call to a method mapped to a database function.
/// <list type="bullet">
/// <item>Non-composable functions are translated as stored procedure calls.</item>
/// <item>Functions with result row types are translated as table-valued functions.</item>
/// <item>Everything else is translated as a scalar function call.</item>
/// </list>
/// </summary>
private SqlNode VisitMappedFunctionCall(MethodCallExpression mc) {
    MetaFunction function = services.Model.GetFunction(mc.Method);
    System.Diagnostics.Debug.Assert(function != null);
    CheckContext(this.VisitExpression(mc.Object));
    if (!function.IsComposable) {
        return this.TranslateStoredProcedureCall(mc, function);
    }
    if (function.ResultRowTypes.Count > 0) {
        return this.TranslateTableValuedFunction(mc, function);
    }
    // scalar function: prefer an explicit DbType on the return parameter, else infer from the CLR return type
    ProviderType sqlType;
    if (function.ReturnParameter != null && !string.IsNullOrEmpty(function.ReturnParameter.DbType)) {
        sqlType = this.typeProvider.Parse(function.ReturnParameter.DbType);
    }
    else {
        sqlType = this.typeProvider.From(mc.Method.ReturnType);
    }
    List<SqlExpression> arguments = this.GetFunctionParameters(mc, function);
    return sql.FunctionCall(mc.Method.ReturnType, sqlType, function.MappedName, arguments, mc);
}
/// <summary>True when the called method is declared on <c>Enumerable</c> or <c>Queryable</c>.</summary>
[SuppressMessage("Microsoft.Performance", "CA1822:MarkMembersAsStatic", Justification="Unknown reason.")]
private bool IsSequenceOperatorCall(MethodCallExpression mc) {
    Type owner = mc.Method.DeclaringType;
    return owner == typeof(System.Linq.Enumerable) || owner == typeof(System.Linq.Queryable);
}
/// <summary>True when <paramref name="mc"/> calls the Enumerable/Queryable operator named <paramref name="methodName"/>.</summary>
private bool IsSequenceOperatorCall(MethodCallExpression mc, string methodName) {
    return mc.Method.Name == methodName && IsSequenceOperatorCall(mc);
}
/// <summary>
/// Dispatches an Enumerable/Queryable operator call to its dedicated translation.
/// Throws:
/// <list type="bullet">
/// <item><c>InvalidSequenceOperatorCall</c> when the method is not declared on Enumerable/Queryable.</item>
/// <item><c>QueryOperatorNotSupported</c> for unknown operators.</item>
/// <item><c>QueryOperatorOverloadNotSupported</c> when the operator is known but this overload is not.</item>
/// </list>
/// </summary>
[SuppressMessage("Microsoft.Maintainability", "CA1505:AvoidUnmaintainableCode", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
[SuppressMessage("Microsoft.Maintainability", "CA1502:AvoidExcessiveComplexity", Justification = "These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
private SqlNode VisitSequenceOperatorCall(MethodCallExpression mc) {
    Type declType = mc.Method.DeclaringType;
    if (!IsSequenceOperatorCall(mc)) {
        throw Error.InvalidSequenceOperatorCall(declType);
    }
    string operatorName = mc.Method.Name;
    int argCount = mc.Arguments.Count;
    switch (operatorName) {
        case "Select":
            if (argCount == 2 && this.IsLambdaArgumentOfArity(mc, 1, 1)) {
                return this.VisitSelect(mc.Arguments[0], this.GetLambda(mc.Arguments[1]));
            }
            break;
        case "SelectMany":
            if (argCount == 2 && this.IsLambdaArgumentOfArity(mc, 1, 1)) {
                return this.VisitSelectMany(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), null);
            }
            if (argCount == 3 && this.IsLambdaArgumentOfArity(mc, 1, 1) && this.IsLambdaArgumentOfArity(mc, 2, 2)) {
                return this.VisitSelectMany(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), this.GetLambda(mc.Arguments[2]));
            }
            break;
        case "Join":
            if (argCount == 5 && this.IsLambdaArgumentOfArity(mc, 2, 1) && this.IsLambdaArgumentOfArity(mc, 3, 1) && this.IsLambdaArgumentOfArity(mc, 4, 2)) {
                return this.VisitJoin(mc.Arguments[0], mc.Arguments[1], this.GetLambda(mc.Arguments[2]), this.GetLambda(mc.Arguments[3]), this.GetLambda(mc.Arguments[4]));
            }
            break;
        case "GroupJoin":
            if (argCount == 5 && this.IsLambdaArgumentOfArity(mc, 2, 1) && this.IsLambdaArgumentOfArity(mc, 3, 1) && this.IsLambdaArgumentOfArity(mc, 4, 2)) {
                return this.VisitGroupJoin(mc.Arguments[0], mc.Arguments[1], this.GetLambda(mc.Arguments[2]), this.GetLambda(mc.Arguments[3]), this.GetLambda(mc.Arguments[4]));
            }
            break;
        case "DefaultIfEmpty":
            if (argCount == 1) {
                return this.VisitDefaultIfEmpty(mc.Arguments[0]);
            }
            break;
        case "OfType":
            if (argCount == 1) {
                return this.VisitOfType(mc.Arguments[0], mc.Method.GetGenericArguments()[0]);
            }
            break;
        case "Cast":
            if (argCount == 1) {
                return this.VisitSequenceCast(mc.Arguments[0], mc.Method.GetGenericArguments()[0]);
            }
            break;
        case "Where":
            if (argCount == 2 && this.IsLambdaArgumentOfArity(mc, 1, 1)) {
                return this.VisitWhere(mc.Arguments[0], this.GetLambda(mc.Arguments[1]));
            }
            break;
        case "First":
        case "FirstOrDefault":
            if (argCount == 1) {
                return this.VisitFirst(mc.Arguments[0], null, true);
            }
            if (argCount == 2 && this.IsLambdaArgumentOfArity(mc, 1, 1)) {
                return this.VisitFirst(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), true);
            }
            break;
        case "Single":
        case "SingleOrDefault":
            if (argCount == 1) {
                return this.VisitFirst(mc.Arguments[0], null, false);
            }
            if (argCount == 2 && this.IsLambdaArgumentOfArity(mc, 1, 1)) {
                return this.VisitFirst(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), false);
            }
            break;
        case "Distinct":
            if (argCount == 1) {
                return this.VisitDistinct(mc.Arguments[0]);
            }
            break;
        case "Concat":
            if (argCount == 2) {
                return this.VisitConcat(mc.Arguments[0], mc.Arguments[1]);
            }
            break;
        case "Union":
            if (argCount == 2) {
                return this.VisitUnion(mc.Arguments[0], mc.Arguments[1]);
            }
            break;
        case "Intersect":
            if (argCount == 2) {
                return this.VisitIntersect(mc.Arguments[0], mc.Arguments[1]);
            }
            break;
        case "Except":
            if (argCount == 2) {
                return this.VisitExcept(mc.Arguments[0], mc.Arguments[1]);
            }
            break;
        case "Any":
            if (argCount == 1) {
                return this.VisitQuantifier(mc.Arguments[0], null, true);
            }
            if (argCount == 2 && this.IsLambdaArgumentOfArity(mc, 1, 1)) {
                return this.VisitQuantifier(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), true);
            }
            break;
        case "All":
            if (argCount == 2 && this.IsLambdaArgumentOfArity(mc, 1, 1)) {
                return this.VisitQuantifier(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), false);
            }
            break;
        case "Count":
            if (argCount == 1) {
                return this.VisitAggregate(mc.Arguments[0], null, SqlNodeType.Count, mc.Type);
            }
            if (argCount == 2 && this.IsLambdaArgumentOfArity(mc, 1, 1)) {
                return this.VisitAggregate(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), SqlNodeType.Count, mc.Type);
            }
            break;
        case "LongCount":
            if (argCount == 1) {
                return this.VisitAggregate(mc.Arguments[0], null, SqlNodeType.LongCount, mc.Type);
            }
            if (argCount == 2 && this.IsLambdaArgumentOfArity(mc, 1, 1)) {
                return this.VisitAggregate(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), SqlNodeType.LongCount, mc.Type);
            }
            break;
        case "Sum":
            if (argCount == 1) {
                return this.VisitAggregate(mc.Arguments[0], null, SqlNodeType.Sum, mc.Type);
            }
            if (argCount == 2 && this.IsLambdaArgumentOfArity(mc, 1, 1)) {
                return this.VisitAggregate(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), SqlNodeType.Sum, mc.Type);
            }
            break;
        case "Min":
            if (argCount == 1) {
                return this.VisitAggregate(mc.Arguments[0], null, SqlNodeType.Min, mc.Type);
            }
            if (argCount == 2 && this.IsLambdaArgumentOfArity(mc, 1, 1)) {
                return this.VisitAggregate(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), SqlNodeType.Min, mc.Type);
            }
            break;
        case "Max":
            if (argCount == 1) {
                return this.VisitAggregate(mc.Arguments[0], null, SqlNodeType.Max, mc.Type);
            }
            if (argCount == 2 && this.IsLambdaArgumentOfArity(mc, 1, 1)) {
                return this.VisitAggregate(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), SqlNodeType.Max, mc.Type);
            }
            break;
        case "Average":
            if (argCount == 1) {
                return this.VisitAggregate(mc.Arguments[0], null, SqlNodeType.Avg, mc.Type);
            }
            if (argCount == 2 && this.IsLambdaArgumentOfArity(mc, 1, 1)) {
                return this.VisitAggregate(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), SqlNodeType.Avg, mc.Type);
            }
            break;
        case "GroupBy":
            if (argCount == 2 && this.IsLambdaArgumentOfArity(mc, 1, 1)) {
                return this.VisitGroupBy(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), null, null);
            }
            if (argCount == 3 && this.IsLambdaArgumentOfArity(mc, 1, 1) && this.IsLambdaArgumentOfArity(mc, 2, 1)) {
                // GroupBy(source, keySelector, elementSelector)
                return this.VisitGroupBy(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), this.GetLambda(mc.Arguments[2]), null);
            }
            if (argCount == 3 && this.IsLambdaArgumentOfArity(mc, 1, 1) && this.IsLambdaArgumentOfArity(mc, 2, 2)) {
                // GroupBy(source, keySelector, resultSelector)
                return this.VisitGroupBy(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), null, this.GetLambda(mc.Arguments[2]));
            }
            if (argCount == 4 && this.IsLambdaArgumentOfArity(mc, 1, 1) && this.IsLambdaArgumentOfArity(mc, 2, 1) && this.IsLambdaArgumentOfArity(mc, 3, 2)) {
                return this.VisitGroupBy(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), this.GetLambda(mc.Arguments[2]), this.GetLambda(mc.Arguments[3]));
            }
            break;
        case "OrderBy":
            if (argCount == 2 && this.IsLambdaArgumentOfArity(mc, 1, 1)) {
                return this.VisitOrderBy(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), SqlOrderType.Ascending);
            }
            break;
        case "OrderByDescending":
            if (argCount == 2 && this.IsLambdaArgumentOfArity(mc, 1, 1)) {
                return this.VisitOrderBy(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), SqlOrderType.Descending);
            }
            break;
        case "ThenBy":
            if (argCount == 2 && this.IsLambdaArgumentOfArity(mc, 1, 1)) {
                return this.VisitThenBy(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), SqlOrderType.Ascending);
            }
            break;
        case "ThenByDescending":
            if (argCount == 2 && this.IsLambdaArgumentOfArity(mc, 1, 1)) {
                return this.VisitThenBy(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), SqlOrderType.Descending);
            }
            break;
        case "Take":
            if (argCount == 2) {
                return this.VisitTake(mc.Arguments[0], mc.Arguments[1]);
            }
            break;
        case "Skip":
            if (argCount == 2) {
                return this.VisitSkip(mc.Arguments[0], mc.Arguments[1]);
            }
            break;
        case "Contains":
            if (argCount == 2) {
                return this.VisitContains(mc.Arguments[0], mc.Arguments[1]);
            }
            break;
        case "ToList":
        case "AsEnumerable":
        case "ToArray":
            if (argCount == 1) {
                return this.Visit(mc.Arguments[0]);
            }
            break;
        default:
            // an operator this converter does not know at all
            throw Error.QueryOperatorNotSupported(operatorName);
    }
    // the operator is known, but this particular overload is not
    throw Error.QueryOperatorOverloadNotSupported(operatorName);
}

/// <summary>
/// True when argument <paramref name="index"/> of <paramref name="mc"/> is a (possibly quoted)
/// lambda declaring exactly <paramref name="arity"/> parameters.
/// </summary>
private bool IsLambdaArgumentOfArity(MethodCallExpression mc, int index, int arity) {
    Expression argument = mc.Arguments[index];
    return this.IsLambda(argument) && this.GetLambda(argument).Parameters.Count == arity;
}
/// <summary>True when <paramref name="mc"/> is a static call on the <c>DataManipulation</c> class.</summary>
private static bool IsDataManipulationCall(MethodCallExpression mc) {
    MethodInfo method = mc.Method;
    return method.IsStatic && method.DeclaringType == typeof(DataManipulation);
}
/// <summary>
/// Dispatches <c>DataManipulation.Insert/Update/Delete</c> calls to their translations.
/// Throws:
/// <list type="bullet">
/// <item><c>InvalidSequenceOperatorCall</c> for non-DML calls.</item>
/// <item><c>QueryOperatorNotSupported</c> for unknown operations.</item>
/// <item><c>QueryOperatorOverloadNotSupported</c> for unsupported overloads of a known operation.</item>
/// </list>
/// </summary>
private SqlNode VisitDataManipulationCall(MethodCallExpression mc) {
    if (!IsDataManipulationCall(mc)) {
        throw Error.InvalidSequenceOperatorCall(mc.Method.Name);
    }
    int argCount = mc.Arguments.Count;
    switch (mc.Method.Name) {
        case "Insert":
            if (argCount == 2) {
                return this.VisitInsert(mc.Arguments[0], this.GetLambda(mc.Arguments[1]));
            }
            if (argCount == 1) {
                return this.VisitInsert(mc.Arguments[0], null);
            }
            break;
        case "Update":
            if (argCount == 3) {
                return this.VisitUpdate(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), this.GetLambda(mc.Arguments[2]));
            }
            if (argCount == 2) {
                // one generic argument: the lambda is the check predicate; otherwise it is the result selector
                bool lambdaIsCheck = mc.Method.GetGenericArguments().Length == 1;
                return lambdaIsCheck
                    ? this.VisitUpdate(mc.Arguments[0], this.GetLambda(mc.Arguments[1]), null)
                    : this.VisitUpdate(mc.Arguments[0], null, this.GetLambda(mc.Arguments[1]));
            }
            if (argCount == 1) {
                return this.VisitUpdate(mc.Arguments[0], null, null);
            }
            break;
        case "Delete":
            if (argCount == 2) {
                return this.VisitDelete(mc.Arguments[0], this.GetLambda(mc.Arguments[1]));
            }
            if (argCount == 1) {
                return this.VisitDelete(mc.Arguments[0], null);
            }
            break;
        default:
            throw Error.QueryOperatorNotSupported(mc.Method.Name);
    }
    throw Error.QueryOperatorOverloadNotSupported(mc.Method.Name);
}
/// <summary>
/// Translates First/FirstOrDefault (<paramref name="isFirst"/> = true, adds TOP 1) and
/// Single/SingleOrDefault (<paramref name="isFirst"/> = false).
/// An optional predicate becomes the WHERE clause.
/// At the outermost node the select is returned directly. Otherwise it is wrapped as a
/// scalar subselect when the selection can be a column, and as an element subselect when it cannot.
/// </summary>
private SqlNode VisitFirst(Expression sequence, LambdaExpression lambda, bool isFirst) {
    SqlSelect locked = this.LockSelect(this.VisitSequence(sequence));
    if (lambda != null) {
        this.map[lambda.Parameters[0]] = (SqlAliasRef)locked.Selection;
        locked.Where = this.VisitExpression(lambda.Body);
    }
    if (isFirst) {
        locked.Top = this.sql.ValueFromObject(1, false, this.dominatingExpression);
    }
    if (this.outerNode) {
        return locked;
    }
    ProviderType selectionType = this.typeProvider.From(locked.Selection.ClrType);
    SqlNodeType subType = selectionType.CanBeColumn ? SqlNodeType.ScalarSubSelect : SqlNodeType.Element;
    return sql.SubSelect(subType, locked, sequence.Type);
}
/// <summary>
/// Translates an insert of a constant entity <paramref name="item"/>.
/// Without a <paramref name="resultSelector"/> the bare INSERT is returned.
/// Otherwise the INSERT is followed by a SELECT that reads the inserted row back, located by its
/// DB-generated identity (or by the whole row when there is none), and the pair is returned as a block.
/// </summary>
[SuppressMessage("Microsoft.Maintainability", "CA1506:AvoidExcessiveClassCoupling", Justification="These issues are related to our use of if-then and case statements for node types, which adds to the complexity count however when reviewed they are easy to navigate and understand.")]
private SqlStatement VisitInsert(Expression item, LambdaExpression resultSelector) {
    if (item == null) {
        throw Error.ArgumentNull("item");
    }
    this.dominatingExpression = item;
    MetaTable metaTable = this.services.Model.GetTable(item.Type);
    Expression source = this.services.Context.GetTable(metaTable.RowType.Type).Expression;

    // only a constant, non-null entity instance can be inserted
    ConstantExpression constantItem = item as ConstantExpression;
    if (constantItem == null) {
        throw Error.InsertItemMustBeConstant();
    }
    object entity = constantItem.Value;
    if (entity == null) {
        throw Error.ArgumentNull("item");
    }

    // assign every persistent member that is not an association, DB-generated, or a version column
    MetaType entityMetaType = metaTable.RowType.GetInheritanceType(entity.GetType());
    SqlExpression entityValue = sql.ValueFromObject(entity, true, source);
    List<SqlMemberAssign> bindings = new List<SqlMemberAssign>();
    foreach (MetaDataMember member in entityMetaType.PersistentDataMembers) {
        bool isAssignable = !member.IsAssociation && !member.IsDbGenerated && !member.IsVersion;
        if (isAssignable) {
            bindings.Add(new SqlMemberAssign(member.Member, sql.Member(entityValue, member.Member)));
        }
    }
    ConstructorInfo defaultCtor = entityMetaType.Type.GetConstructor(Type.EmptyTypes);
    System.Diagnostics.Debug.Assert(defaultCtor != null);
    SqlNew newEntity = sql.New(entityMetaType, defaultCtor, null, null, bindings, item);
    SqlTable targetTable = sql.Table(metaTable, metaTable.RowType, this.dominatingExpression);
    SqlInsert insert = new SqlInsert(targetTable, newEntity, item);
    if (resultSelector == null) {
        return insert;
    }

    // decide how a DB-generated identity will be read back
    MetaDataMember identity = entityMetaType.DBGeneratedIdentityMember;
    bool projectsKeyOnly = false;
    if (identity != null) {
        projectsKeyOnly = this.IsDbGeneratedKeyProjectionOnly(resultSelector.Body, identity);
        bool canOutputGuidKey = identity.Type == typeof(Guid) && (this.converterStrategy & ConverterStrategy.CanOutputFromInsert) != 0;
        if (canOutputGuidKey) {
            insert.OutputKey = new SqlColumn(identity.Type, sql.Default(identity), identity.Name, identity, null, this.dominatingExpression);
            if (!projectsKeyOnly) {
                insert.OutputToLocal = true;
            }
        }
    }

    SqlAlias tableAlias = new SqlAlias(targetTable);
    SqlAliasRef tableRef = new SqlAliasRef(tableAlias);
    System.Diagnostics.Debug.Assert(resultSelector.Parameters.Count == 1);
    this.map.Add(resultSelector.Parameters[0], tableRef);
    SqlExpression projection = this.VisitExpression(resultSelector.Body);

    // locate the inserted row: by identity when there is one, otherwise by comparing the whole row
    SqlExpression rowFilter;
    if (identity != null) {
        SqlExpression keyColumn = sql.Member(tableRef, identity.Member);
        SqlExpression generatedKey = this.GetIdentityExpression(identity, insert.OutputKey != null);
        rowFilter = sql.Binary(SqlNodeType.EQ, keyColumn, generatedKey);
    }
    else {
        SqlExpression itemExpression = this.VisitExpression(item);
        rowFilter = sql.Binary(SqlNodeType.EQ2V, tableRef, itemExpression);
    }
    SqlSelect result = new SqlSelect(projection, tableAlias, resultSelector);
    result.Where = rowFilter;

    // When only the generated key is projected, select it directly (e.g. SELECT @@IDENTITY)
    // instead of reading the row back from the table.
    SqlSelect keyOnlyResult = null;
    if (identity != null && projectsKeyOnly) {
        if (insert.OutputKey == null) {
            SqlExpression generated = this.GetIdentityExpression(identity, false);
            if (generated.ClrType != identity.Type) {
                generated = sql.ConvertTo(identity.Type, sql.Default(identity), generated);
            }
            // rebind the result selector (which was bound to the table) to the single value as an array projection
            ParameterExpression p = Expression.Parameter(identity.Type, "p");
            NewArrayExpression boxedRow = Expression.NewArrayInit(typeof(object), new Expression[1] { Expression.Convert(p, typeof(object)) });
            LambdaExpression rebound = Expression.Lambda(boxedRow, p);
            this.map.Add(p, generated);
            SqlExpression keyProjection = this.VisitExpression(rebound.Body);
            keyOnlyResult = new SqlSelect(keyProjection, null, rebound);
        }
        // when OutputKey is set the formatter emits the key automatically
        result.DoNotOutput = true;
    }

    // combine insert and result selection into one block
    SqlBlock block = new SqlBlock(this.dominatingExpression);
    block.Statements.Add(insert);
    if (keyOnlyResult != null) {
        block.Statements.Add(keyOnlyResult);
    }
    block.Statements.Add(result);
    return block;
}
/// <summary>
/// True when <paramref name="projection"/> is a one-element array whose element, once
/// conversions are stripped, is an access of <paramref name="keyMember"/>.
/// </summary>
private bool IsDbGeneratedKeyProjectionOnly(Expression projection, MetaDataMember keyMember) {
    NewArrayExpression array = projection as NewArrayExpression;
    if (array == null || array.Expressions.Count != 1) {
        return false;
    }
    Expression element = array.Expressions[0];
    while (element.NodeType == ExpressionType.Convert || element.NodeType == ExpressionType.ConvertChecked) {
        element = ((UnaryExpression)element).Operand;
    }
    MemberExpression access = element as MemberExpression;
    return access != null && access.Member == keyMember.Member;
}
/// <summary>
/// Returns the expression that yields the identity value of the row just inserted.
/// When the key is captured via OUTPUT, this is the <c>@id</c> variable.
/// Otherwise it is <c>SCOPE_IDENTITY()</c> or <c>@@IDENTITY</c>, depending on the converter strategy.
/// The second case throws <c>InvalidDbGeneratedType</c> when the key's provider type is not a legal identity type.
/// </summary>
private SqlExpression GetIdentityExpression(MetaDataMember id, bool isOutputFromInsert) {
    if (isOutputFromInsert) {
        return new SqlVariable(id.Type, sql.Default(id), "@id", this.dominatingExpression);
    }
    ProviderType sqlType = sql.Default(id);
    if (!IsLegalIdentityType(sqlType.GetClosestRuntimeType())) {
        throw Error.InvalidDbGeneratedType(sqlType.ToQueryString());
    }
    bool useScopeIdentity = (this.converterStrategy & ConverterStrategy.CanUseScopeIdentity) != 0;
    string identitySource = useScopeIdentity ? "SCOPE_IDENTITY()" : "@@IDENTITY";
    return new SqlVariable(typeof(decimal), typeProvider.From(typeof(decimal)), identitySource, this.dominatingExpression);
}
/// <summary>
/// Returns true when <paramref name="type"/> (the closest CLR type of a key's provider type)
/// is one SQL Server accepts for an IDENTITY column: tinyint, smallint, int, bigint, or
/// decimal/numeric with scale 0.
/// </summary>
/// <remarks>
/// tinyint maps to <see cref="byte"/>, which the original list omitted. As a result,
/// tinyint identity keys were rejected with InvalidDbGeneratedType. <see cref="sbyte"/> has
/// no SQL Server counterpart but is still accepted, for backward compatibility.
/// </remarks>
private static bool IsLegalIdentityType(Type type) {
    switch (Type.GetTypeCode(type)) {
        case TypeCode.Byte:
        case TypeCode.SByte:
        case TypeCode.Int16:
        case TypeCode.Int32:
        case TypeCode.Int64:
        case TypeCode.Decimal:
            return true;
        default:
            return false;
    }
}
/// <summary>
/// Returns the variable holding the affected-row count of the previous statement.
/// This is <c>@@ROWCOUNT</c> when the strategy allows row status, and the <c>@ROWCOUNT</c> variable otherwise.
/// </summary>
private SqlExpression GetRowCountExpression() {
    bool useRowStatus = (this.converterStrategy & ConverterStrategy.CanUseRowStatus) != 0;
    string rowCountName = useRowStatus ? "@@ROWCOUNT" : "@ROWCOUNT";
    return new SqlVariable(typeof(decimal), typeProvider.From(typeof(decimal)), rowCountName, this.dominatingExpression);
}
/// <summary>
/// Translates an update of the constant entity <paramref name="item"/>.
/// The target row is found by identity, AND-ed with the optional <paramref name="check"/> predicate.
/// Only the members reported as modified by the entity's table are assigned.
/// With a <paramref name="resultSelector"/>, the update is followed by a select that re-reads the
/// row, guarded by the row count being greater than zero.
/// Deferred loading is disabled during the translation and restored afterwards.
/// </summary>
private SqlStatement VisitUpdate(Expression item, LambdaExpression check, LambdaExpression resultSelector) {
    if (item == null) {
        throw Error.ArgumentNull("item");
    }
    MetaTable metaTable = this.services.Model.GetTable(item.Type);
    Expression tableSource = this.services.Context.GetTable(metaTable.RowType.Type).Expression;
    Type rowType = metaTable.RowType.Type;
    bool previousAllowDeferred = this.allowDeferred;
    this.allowDeferred = false;
    try {
        // p => p == item, optionally combined with the caller's check predicate
        ParameterExpression p = Expression.Parameter(rowType, "p");
        LambdaExpression identityPredicate = Expression.Lambda(Expression.Equal(p, item), p);
        LambdaExpression findPredicate = (check == null)
            ? identityPredicate
            : Expression.Lambda(Expression.And(Expression.Invoke(identityPredicate, p), Expression.Invoke(check, p)), p);
        Expression targetQuery = Expression.Call(typeof(Enumerable), "Where", new Type[] { rowType }, tableSource, findPredicate);
        SqlSelect targetRows = new RetypeCheckClause().VisitSelect(this.VisitSequence(targetQuery));

        ConstantExpression constantItem = item as ConstantExpression;
        if (constantItem == null) {
            throw Error.UpdateItemMustBeConstant();
        }
        object entity = constantItem.Value;
        if (entity == null) {
            throw Error.ArgumentNull("item");
        }

        // ask the data services which members changed and assign their current values
        MetaType metaType = this.services.Model.GetMetaType(entity.GetType());
        ITable table = this.services.Context.GetTable(metaType.InheritanceRoot.Type);
        List<SqlAssign> assignments = new List<SqlAssign>();
        foreach (ModifiedMemberInfo change in table.GetModifiedMembers(entity)) {
            MetaDataMember dataMember = metaType.GetDataMember(change.Member);
            SqlExpression column = sql.Member(targetRows.Selection, change.Member);
            SqlValue newValue = new SqlValue(dataMember.Type, this.typeProvider.From(dataMember.Type), change.CurrentValue, true, tableSource);
            assignments.Add(new SqlAssign(column, newValue, tableSource));
        }
        SqlUpdate update = new SqlUpdate(targetRows, assignments, tableSource);
        if (resultSelector == null) {
            return update;
        }

        // re-read the row, but only when the UPDATE actually touched something
        Expression resultQuery = Expression.Call(typeof(Enumerable), "Where", new Type[] { rowType }, tableSource, identityPredicate);
        resultQuery = Expression.Call(typeof(Enumerable), "Select", new Type[] { rowType, resultSelector.Body.Type }, resultQuery, resultSelector);
        SqlSelect select = this.VisitSequence(resultQuery);
        SqlExpression rowsWereAffected = sql.Binary(SqlNodeType.GT, this.GetRowCountExpression(), sql.ValueFromObject(0, false, this.dominatingExpression));
        select.Where = sql.AndAccumulate(rowsWereAffected, select.Where);

        SqlBlock block = new SqlBlock(tableSource);
        block.Statements.Add(update);
        block.Statements.Add(select);
        return block;
    }
    finally {
        this.allowDeferred = previousAllowDeferred;
    }
}
/// <summary>
/// Converts a delete request for a single entity into a <see cref="SqlDelete"/> statement.
/// </summary>
/// <param name="item">
/// Expression for the entity to delete; each row of the entity's table is compared to it
/// with an equality predicate.
/// </param>
/// <param name="check">
/// Optional extra predicate the row must also satisfy; may be null.
/// </param>
/// <returns>A <see cref="SqlDelete"/> over the rows matched by the identity (and check) predicate.</returns>
private SqlStatement VisitDelete(Expression item, LambdaExpression check) {
    if (item == null) {
        throw Error.ArgumentNull("item");
    }
    // Deferred loading is switched off during translation; the finally block is the single
    // place that restores it (previously it was also restored redundantly inside the try).
    bool saveAllowDeferred = this.allowDeferred;
    this.allowDeferred = false;
    try {
        MetaTable metaTable = this.services.Model.GetTable(item.Type);
        Expression source = this.services.Context.GetTable(metaTable.RowType.Type).Expression;
        Type rowType = metaTable.RowType.Type;

        // construct identity predicate based on supplied item
        ParameterExpression p = Expression.Parameter(rowType, "p");
        LambdaExpression idPredicate = Expression.Lambda(Expression.Equal(p, item), p);

        // combine predicate and check expression into single find predicate
        LambdaExpression findPredicate = idPredicate;
        if (check != null) {
            findPredicate = Expression.Lambda(Expression.And(Expression.Invoke(findPredicate, p), Expression.Invoke(check, p)), p);
        }

        Expression seq = Expression.Call(typeof(Enumerable), "Where", new Type[] { rowType }, source, findPredicate);
        SqlSelect ss = new RetypeCheckClause().VisitSelect(this.VisitSequence(seq));
        return new SqlDelete(ss, source);
    }
    finally {
        this.allowDeferred = saveAllowDeferred;
    }
}
/// <summary>
/// Visitor run over the filtered select built by VisitUpdate/VisitDelete. For every
/// two-argument <c>op_Equality</c> method call whose right-hand argument is a
/// <see cref="SqlValue"/> node, the value is given the SQL type of the left-hand argument.
/// The call is then visited normally.
/// </summary>
private class RetypeCheckClause : SqlVisitor {
    internal override SqlExpression VisitMethodCall(SqlMethodCall mc) {
        bool isBinaryEquality = mc.Arguments.Count == 2 && mc.Method.Name == "op_Equality";
        if (isBinaryEquality) {
            SqlExpression right = mc.Arguments[1];
            if (right.NodeType == SqlNodeType.Value) {
                // Align the literal's SQL type with the operand it is compared to.
                ((SqlValue)right).SetSqlType(mc.Arguments[0].SqlType);
            }
        }
        return base.VisitMethodCall(mc);
    }
}
/// <summary>
/// Translates a <see cref="NewArrayExpression"/> into a <see cref="SqlClientArray"/>.
/// Each expression in <c>arr.Expressions</c> is converted in order.
/// </summary>
/// <param name="arr">The array expression to translate.</param>
/// <returns>A <see cref="SqlClientArray"/> typed as <c>arr.Type</c>.</returns>
private SqlExpression VisitNewArrayInit(NewArrayExpression arr) {
    SqlExpression[] elements = arr.Expressions
        .Select(element => this.VisitExpression(element))
        .ToArray();
    return new SqlClientArray(arr.Type, this.typeProvider.From(arr.Type), elements, this.dominatingExpression);
}
/// <summary>
/// Translates a collection-initializer expression into a <see cref="SqlClientArray"/>.
/// </summary>
/// <param name="list">The list-init expression to translate.</param>
/// <returns>A <see cref="SqlClientArray"/> typed as <c>list.Type</c> holding the converted elements.</returns>
/// <remarks>
/// Only two shapes are supported. The collection must be created without constructor
/// arguments, and every element initializer must pass exactly one argument. Anything else
/// throws the same "unrecognized expression node" error used for other unsupported expressions.
/// </remarks>
private SqlExpression VisitListInit(ListInitExpression list) {
    NewExpression construction = list.NewExpression;
    bool usesArgumentConstructor = construction.Constructor != null && construction.Arguments.Count != 0;
    if (usesArgumentConstructor) {
        throw Error.UnrecognizedExpressionNode(list.NodeType);
    }

    SqlExpression[] elements = new SqlExpression[list.Initializers.Count];
    int position = 0;
    foreach (ElementInit initializer in list.Initializers) {
        if (initializer.Arguments.Count != 1) {
            // Multi-argument initializers (e.g. dictionary-style) are not supported.
            throw Error.UnrecognizedExpressionNode(list.NodeType);
        }
        elements[position++] = this.VisitExpression(initializer.Arguments[0]);
    }
    return new SqlClientArray(list.Type, this.typeProvider.From(list.Type), elements, this.dominatingExpression);
}
- }
/// <summary>
/// Inspects a select to decide whether it reads from a single table (or is DISTINCT/UNION)
/// and whether its projection exposes the members that identify a row.
/// </summary>
/// <remarks>
/// The result is reported through <see cref="IsValid"/>. When a non-distinct select ends up
/// with no identity members, <see cref="VisitSelect"/> throws
/// SkipRequiresSingleTableQueryWithPKs.
/// </remarks>
class SingleTableQueryVisitor : SqlVisitor {
    // True until some part of the tree disqualifies the query; once false, Visit stops descending.
    public bool IsValid;
    // Set by a DISTINCT select or a UNION. Identity then comes from all projected properties,
    // and tables are no longer inspected.
    bool IsDistinct;
    // Members that must appear in the projection for the query to qualify.
    List<MemberInfo> IdentityMembers;

    void AddIdentityMembers(IEnumerable<MemberInfo> members) {
        System.Diagnostics.Debug.Assert(this.IdentityMembers == null, "We already have a set of keys -- why are we adding more?");
        this.IdentityMembers = new List<MemberInfo>(members);
    }

    internal SingleTableQueryVisitor(): base() {
        this.IsValid = true;
    }

    internal override SqlNode Visit(SqlNode node) {
        // Short-circuit the walk once the query is known to be invalid.
        bool keepWalking = this.IsValid && node != null;
        return keepWalking ? base.Visit(node) : node;
    }

    internal override SqlTable VisitTable(SqlTable tab) {
        if (!this.IsDistinct) {
            if (this.IdentityMembers == null) {
                // First table seen: its row type's identity members become the required keys.
                this.AddIdentityMembers(tab.MetaTable.RowType.IdentityMembers.Select(m => m.Member));
            }
            else {
                // A second table means more than one source table, so the query is not single-table.
                this.IsValid = false;
            }
        }
        return tab;
    }

    internal override SqlSource VisitSource(SqlSource source) {
        return base.VisitSource(source);
    }

    internal override SqlSelect VisitSelect(SqlSelect select) {
        if (select.IsDistinct) {
            // Every projected property takes part in identity; sources are not inspected.
            this.IsDistinct = true;
            this.AddIdentityMembers(select.Selection.ClrType.GetProperties());
            return select;
        }

        // Walk the sources with the base dispatcher so tables contribute identity members.
        select.From = (SqlSource)base.Visit(select.From);

        bool hasIdentity = this.IdentityMembers != null && this.IdentityMembers.Count != 0;
        if (!hasIdentity) {
            throw Error.SkipRequiresSingleTableQueryWithPKs();
        }

        SqlNodeType selectionKind = select.Selection.NodeType;
        if (selectionKind == SqlNodeType.Column
            || selectionKind == SqlNodeType.ColumnRef
            || selectionKind == SqlNodeType.Member) {
            // A bare column/member projection qualifies only if the table has exactly one
            // identity member and the projection is that member.
            if (this.IdentityMembers.Count == 1) {
                this.IsValid &= IsColumnMatch(this.IdentityMembers[0], select.Selection);
            }
            else {
                this.IsValid = false;
            }
        }
        else if (selectionKind == SqlNodeType.New || selectionKind == SqlNodeType.AliasRef) {
            select.Selection = this.VisitExpression(select.Selection);
        }
        else if (selectionKind != SqlNodeType.Treat && selectionKind != SqlNodeType.TypeCase) {
            // Treat/TypeCase projections are accepted as-is; anything else disqualifies.
            this.IsValid = false;
        }
        return select;
    }

    internal override SqlExpression VisitNew(SqlNew sox) {
        // Every identity member must appear as a bare column, either as a constructor
        // argument or as a member-assignment value.
        foreach (MemberInfo key in this.IdentityMembers) {
            bool found = sox.Args.Any(arg => IsColumnMatch(key, arg))
                || sox.Members.Any(assign => IsColumnMatch(key, assign.Expression));
            this.IsValid &= found;
            if (!this.IsValid) {
                break;
            }
        }
        return sox;
    }

    internal override SqlNode VisitUnion(SqlUnion su) {
        // The union's inputs are not descended into. UNION ALL is rejected; a plain UNION is
        // treated as distinct, with all projected properties forming the identity.
        if (su.All) {
            this.IsValid = false;
        }
        this.IsDistinct = true;
        this.AddIdentityMembers(su.GetClrType().GetProperties());
        return su;
    }

    /// <summary>
    /// True when <paramref name="expr"/> is a bare Column, ColumnRef or Member node whose
    /// underlying member is <paramref name="column"/>. Any other node kind returns false.
    /// </summary>
    private static bool IsColumnMatch(MemberInfo column, SqlExpression expr) {
        MemberInfo candidate;
        switch (expr.NodeType) {
            case SqlNodeType.Column:
                candidate = ((SqlColumn)expr).MetaMember.Member;
                break;
            case SqlNodeType.ColumnRef:
                candidate = ((SqlColumnRef)expr).Column.MetaMember.Member;
                break;
            case SqlNodeType.Member:
                candidate = ((SqlMember)expr).Member;
                break;
            default:
                return false;
        }
        return candidate != null && candidate == column;
    }
}
- }