PageRenderTime 57ms CodeModel.GetById 36ms app.highlight 15ms RepoModel.GetById 1ms app.codeStats 0ms

/src/LinqToExcel/Query/WhereClauseExpressionTreeVisitor.cs

https://github.com/jwcarroll/LinqToExcel
C# | 269 lines | 221 code | 25 blank | 23 comment | 22 complexity | 49992dabcdefd4ff4fbeda17930b4318 MD5 | raw file
  1using System;
  2using System.Collections.Generic;
  3using System.Linq;
  4using System.Text;
  5using Remotion.Data.Linq.Parsing;
  6using System.Data.OleDb;
  7using System.Linq.Expressions;
  8using LinqToExcel.Extensions;
  9using Remotion.Data.Linq.Clauses.Expressions;
 10
 11namespace LinqToExcel.Query
 12{
 13    public class WhereClauseExpressionTreeVisitor : ThrowingExpressionTreeVisitor
 14    {
 15        private readonly StringBuilder _whereClause = new StringBuilder();
 16        private readonly List<OleDbParameter> _params = new List<OleDbParameter>();
 17        private readonly Dictionary<string, string> _columnMapping;
 18        private readonly List<string> _columnNamesUsed = new List<string>();
 19        private readonly Type _sheetType;
 20        private readonly List<string> _validStringMethods;
 21
 22        public WhereClauseExpressionTreeVisitor(Type sheetType, Dictionary<string, string> columnMapping)
 23        {
 24            _sheetType = sheetType;
 25            _columnMapping = columnMapping;
 26            _validStringMethods = new List<string>() {
 27                "Equals",
 28                "Contains",
 29                "StartsWith",
 30                "EndsWith" };
 31        }
 32
 33        public string WhereClause
 34        {
 35            get { return _whereClause.ToString(); }
 36        }
 37
 38        public IEnumerable<OleDbParameter> Params
 39        {
 40            get { return _params; }
 41        }
 42
 43        public IEnumerable<string> ColumnNamesUsed
 44        {
 45            get { return _columnNamesUsed.Select(x => x.Replace("[", "").Replace("]", "")); }
 46        }
 47
 48        public void Visit(Expression expression)
 49        {
 50            base.VisitExpression(expression);
 51        }
 52
 53        protected override Exception CreateUnhandledItemException<T>(T unhandledItem, string visitMethod)
 54        {
 55            throw new NotImplementedException(visitMethod + " method is not implemented");
 56        }
 57
 58        protected override Expression VisitBinaryExpression(BinaryExpression bExp)
 59        {
 60            _whereClause.Append("(");
 61
 62            // Patch for vb.net expression that are always considered a MethodCallExpression even if they are not.
 63            // see http://www.re-motion.org/blogs/mix/archive/2009/10/16/vb.net-specific-text-comparison-in-linq-queries.aspx
 64            bExp = ConvertVbStringCompare(bExp);
 65
 66            //We always want the MemberAccess (ColumnName) to be on the left side of the statement
 67            var bLeft = bExp.Left;
 68            var bRight = bExp.Right;
 69            if ((bExp.Right.NodeType == ExpressionType.MemberAccess) &&
 70                (((MemberExpression)bExp.Right).Member.DeclaringType == _sheetType))
 71            {
 72                bLeft = bExp.Right;
 73                bRight = bExp.Left;
 74            }
 75
 76            VisitExpression(bLeft);
 77            switch (bExp.NodeType)
 78            {
 79                case ExpressionType.AndAlso:
 80                    _whereClause.Append(" AND ");
 81                    break;
 82                case ExpressionType.Equal:
 83                    if (bRight.IsNullValue())
 84                        _whereClause.Append(" IS NULL");
 85                    else
 86                        _whereClause.Append(" = ");
 87                    break;
 88                case ExpressionType.GreaterThan:
 89                    _whereClause.Append(" > ");
 90                    break;
 91                case ExpressionType.GreaterThanOrEqual:
 92                    _whereClause.Append(" >= ");
 93                    break;
 94                case ExpressionType.LessThan:
 95                    _whereClause.Append(" < ");
 96                    break;
 97                case ExpressionType.LessThanOrEqual:
 98                    _whereClause.Append(" <= ");
 99                    break;
100                case ExpressionType.NotEqual:
101                    if (bRight.IsNullValue())
102                        _whereClause.Append(" IS NOT NULL");
103                    else
104                        _whereClause.Append(" <> ");
105                    break;
106                case ExpressionType.OrElse:
107                    _whereClause.Append(" OR ");
108                    break;
109                default:
110                    throw new NotSupportedException(string.Format("{0} statement is not supported", bExp.NodeType.ToString()));
111            }
112            if (!bRight.IsNullValue())
113                VisitExpression(bRight);
114            _whereClause.Append(")");
115            return bExp;
116        }
117
118        protected BinaryExpression ConvertVbStringCompare(BinaryExpression exp)
119        {
120            if (exp.Left.NodeType == ExpressionType.Call)
121            {
122                var compareStringCall = (MethodCallExpression)exp.Left;
123                if (compareStringCall.Method.DeclaringType.FullName == "Microsoft.VisualBasic.CompilerServices.Operators" && compareStringCall.Method.Name == "CompareString")
124                {
125                    var arg1 = compareStringCall.Arguments[0];
126                    var arg2 = compareStringCall.Arguments[1];
127
128                    switch (exp.NodeType)
129                    {
130                        case ExpressionType.LessThan:
131                            return Expression.LessThan(arg1, arg2);
132                        case ExpressionType.LessThanOrEqual:
133                            return Expression.LessThanOrEqual(arg1, arg2);
134                        case ExpressionType.GreaterThan:
135                            return Expression.GreaterThan(arg1, arg2);
136                        case ExpressionType.GreaterThanOrEqual:
137                            return Expression.GreaterThanOrEqual(arg1, arg2);
138                        case ExpressionType.NotEqual:
139                            return Expression.NotEqual(arg1, arg2);
140                        default:
141                            return Expression.Equal(arg1, arg2);
142                    }
143                }
144            }
145            return exp;
146        }
147
148        protected override Expression VisitMemberExpression(MemberExpression mExp)
149        {
150            //Set the column name to the property mapping if there is one, 
151            //else use the property name for the column name
152            var columnName = (_columnMapping.ContainsKey(mExp.Member.Name)) ? 
153                _columnMapping[mExp.Member.Name] : 
154                mExp.Member.Name;
155            _whereClause.AppendFormat("[{0}]", columnName);
156            _columnNamesUsed.Add(columnName);
157            return mExp;
158        }
159
160        protected override Expression VisitConstantExpression(ConstantExpression cExp)
161        {
162            _params.Add(new OleDbParameter("?", cExp.Value));
163            _whereClause.Append("?");
164            return cExp;
165        }
166
167        /// <summary>
168        /// This method is visited when the LinqToExcel.Row type is used in the query
169        /// </summary>
170        protected override Expression VisitUnaryExpression(UnaryExpression uExp)
171        {
172            var columnName = GetColumnName(uExp.Operand);
173            _whereClause.Append(columnName);
174            return uExp;
175        }
176
177        /// <summary>
178        /// Only As<>() method calls on the LinqToExcel.Row type are support
179        /// </summary>
180        protected override Expression VisitMethodCallExpression(MethodCallExpression mExp)
181        {
182            if (_validStringMethods.Contains(mExp.Method.Name))
183                ProcessStringMethod(mExp);
184            else
185            {
186                var columnName = GetColumnName(mExp);
187                _whereClause.Append(columnName);
188                _columnNamesUsed.Add(columnName);
189            }
190            return mExp;
191        }
192
193        private void ProcessStringMethod(MethodCallExpression mExp)
194        {
195            switch (mExp.Method.Name)
196            {
197                case "Contains":
198                    AddStringMethodToWhereClause(mExp, "LIKE", "%{0}%");
199                    break;
200                case "StartsWith":
201                    AddStringMethodToWhereClause(mExp, "LIKE", "{0}%");
202                    break;
203                case "EndsWith":
204                    AddStringMethodToWhereClause(mExp, "LIKE", "%{0}");
205                    break;
206                case "Equals":
207                    AddStringMethodToWhereClause(mExp, "=", "{0}");
208                    break;
209            }
210        }
211
212        private void AddStringMethodToWhereClause(MethodCallExpression mExp, string operatorString, string argumentFormat)
213        {
214            _whereClause.Append("(");
215            VisitExpression(mExp.Object);
216            _whereClause.AppendFormat(" {0} ?)", operatorString);
217
218            var value = mExp.Arguments.First().ToString().Replace("\"", "");
219            var parameter = string.Format(argumentFormat, value);
220            _params.Add(new OleDbParameter("?", parameter));
221        }
222
223        /// <summary>
224        /// Retrieves the column name from a member expression or method call expression
225        /// </summary>
226        /// <param name="exp">Expression</param>
227        private string GetColumnName(Expression exp)
228        {
229            if (exp is MemberExpression)
230                return GetColumnName((MemberExpression)exp);
231            else
232                return GetColumnName((MethodCallExpression)exp);
233        }
234
235        /// <summary>
236        /// Retrieves the column name from a member expression
237        /// </summary>
238        /// <param name="mExp">Member Expression</param>
239        private string GetColumnName(MemberExpression mExp)
240        {
241            return "[" + mExp.Member.Name + "]";
242        }
243
244        /// <summary>
245        /// Retrieves the column name from a method call expression
246        /// </summary>
247        /// <param name="exp">Method Call Expression</param>
248        private string GetColumnName(MethodCallExpression mExp)
249        {
250            MethodCallExpression method = mExp;
251            if (mExp.Object is MethodCallExpression)
252                method = (MethodCallExpression)mExp.Object;
253
254            var arg = method.Arguments.First();
255            if (arg.Type == typeof(int))
256            {
257                if (_sheetType == typeof(RowNoHeader))
258                    return string.Format("F{0}", Int32.Parse(arg.ToString()) + 1);
259                else
260                    throw new ArgumentException("Can only use column indexes in WHERE clause when using WorksheetNoHeader");
261            }
262
263            var columnName = arg.ToString().ToCharArray();
264            columnName[0] = "[".ToCharArray().First();
265            columnName[columnName.Length - 1] = "]".ToCharArray().First();
266            return new string(columnName);
267        }
268    }
269}