PageRenderTime 56ms CodeModel.GetById 2ms app.highlight 43ms RepoModel.GetById 1ms app.codeStats 1ms

/mcs/tools/sqlmetal/src/DbLinq/Data/Linq/DataContext.cs

http://github.com/mono/mono
C# | 1282 lines | 918 code | 158 blank | 206 comment | 147 complexity | fecb956f56965df978a192a53dd217ad MD5 | raw file

Large files are truncated, but you can click here to view the full file

   1#region MIT license
   2// 
   3// MIT license
   4//
   5// Copyright (c) 2007-2008 Jiri Moudry, Pascal Craponne
   6// 
   7// Permission is hereby granted, free of charge, to any person obtaining a copy
   8// of this software and associated documentation files (the "Software"), to deal
   9// in the Software without restriction, including without limitation the rights
  10// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  11// copies of the Software, and to permit persons to whom the Software is
  12// furnished to do so, subject to the following conditions:
  13// 
  14// The above copyright notice and this permission notice shall be included in
  15// all copies or substantial portions of the Software.
  16// 
  17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  18// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  19// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  20// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  21// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  22// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  23// THE SOFTWARE.
  24// 
  25#endregion
  26
  27using System;
  28using System.Collections;
  29using System.Data;
  30using System.Data.Common;
  31using System.Data.Linq;
  32using System.Data.Linq.Mapping;
  33using System.Linq.Expressions;
  34using System.Collections.Generic;
  35using System.IO;
  36using System.Linq;
  37using System.Reflection;
  38using System.Reflection.Emit;
  39
  40#if MONO_STRICT
  41using AttributeMappingSource  = System.Data.Linq.Mapping.AttributeMappingSource;
  42#else
  43using AttributeMappingSource  = DbLinq.Data.Linq.Mapping.AttributeMappingSource;
  44#endif
  45
  46using DbLinq;
  47using DbLinq.Data.Linq;
  48using DbLinq.Data.Linq.Database;
  49using DbLinq.Data.Linq.Database.Implementation;
  50using DbLinq.Data.Linq.Identity;
  51using DbLinq.Data.Linq.Implementation;
  52using DbLinq.Data.Linq.Mapping;
  53using DbLinq.Data.Linq.Sugar;
  54using DbLinq.Factory;
  55using DbLinq.Util;
  56using DbLinq.Vendor;
  57
  58#if MONO_STRICT
  59namespace System.Data.Linq
  60#else
  61namespace DbLinq.Data.Linq
  62#endif
  63{
  64    public partial class DataContext : IDisposable
  65    {
  66        //private readonly Dictionary<string, ITable> _tableMap = new Dictionary<string, ITable>();
  67        private readonly Dictionary<Type, ITable> _tableMap = new Dictionary<Type, ITable>();
  68
  69        public MetaModel Mapping { get; private set; }
  70        // PC question: at ctor, we get a IDbConnection and the Connection property exposes a DbConnection
  71        //              WTF?
  72        public DbConnection Connection { get { return DatabaseContext.Connection as DbConnection; } }
  73
  74        // all properties below are set public to optionally be injected
  75        internal IVendor Vendor { get; set; }
  76        internal IQueryBuilder QueryBuilder { get; set; }
  77        internal IQueryRunner QueryRunner { get; set; }
  78        internal IMemberModificationHandler MemberModificationHandler { get; set; }
  79        internal IDatabaseContext DatabaseContext { get; private set; }
  80        // /all properties...
  81
  82        private bool objectTrackingEnabled = true;
  83        private bool deferredLoadingEnabled = true;
  84
  85        private bool queryCacheEnabled = false;
  86
  87        /// <summary>
  88        /// Disable the QueryCache: this is surely good for rarely used Select, since preparing
  89        /// the SelectQuery to be cached could require more time than build the sql from scratch.
  90        /// </summary>
  91        [DBLinqExtended]
  92        public bool QueryCacheEnabled 
  93        {
  94            get { return queryCacheEnabled; }
  95            set { queryCacheEnabled = value; }
  96        }
  97
  98        private IEntityTracker currentTransactionEntities;
  99        private IEntityTracker CurrentTransactionEntities
 100        {
 101            get
 102            {
 103                if (this.currentTransactionEntities == null)
 104                {
 105                    if (this.ObjectTrackingEnabled)
 106                        this.currentTransactionEntities = new EntityTracker();
 107                    else
 108                        this.currentTransactionEntities = new DisabledEntityTracker();
 109                }
 110                return this.currentTransactionEntities;
 111            }
 112        }
 113
 114        private IEntityTracker allTrackedEntities;
 115        private IEntityTracker AllTrackedEntities
 116        {
 117            get
 118            {
 119                if (this.allTrackedEntities == null)
 120                {
 121                    allTrackedEntities = ObjectTrackingEnabled
 122                        ? (IEntityTracker) new EntityTracker()
 123                        : (IEntityTracker) new DisabledEntityTracker();
 124                }
 125                return this.allTrackedEntities;
 126            }
 127        }
 128
 129        private IIdentityReaderFactory identityReaderFactory;
 130        private readonly IDictionary<Type, IIdentityReader> identityReaders = new Dictionary<Type, IIdentityReader>();
 131
 132        /// <summary>
 133        /// The default behavior creates one MappingContext.
 134        /// </summary>
 135        [DBLinqExtended]
 136        internal virtual MappingContext _MappingContext { get; set; }
 137
 138        [DBLinqExtended]
 139        internal IVendorProvider _VendorProvider { get; set; }
 140
 141        public DataContext(IDbConnection connection, MappingSource mapping)
 142        {
 143            Profiler.At("START DataContext(IDbConnection, MappingSource)");
 144            Init(new DatabaseContext(connection), mapping, null);
 145            Profiler.At("END DataContext(IDbConnection, MappingSource)");
 146        }
 147
 148        public DataContext(IDbConnection connection)
 149        {
 150            Profiler.At("START DataContext(IDbConnection)");
 151            if (connection == null)
 152                throw new ArgumentNullException("connection");
 153
 154            Init(new DatabaseContext(connection), null, null);
 155            Profiler.At("END DataContext(IDbConnection)");
 156        }
 157
 158        [DbLinqToDo]
 159        public DataContext(string fileOrServerOrConnection, MappingSource mapping)
 160        {
 161            Profiler.At("START DataContext(string, MappingSource)");
 162            if (fileOrServerOrConnection == null)
 163                throw new ArgumentNullException("fileOrServerOrConnection");
 164            if (mapping == null)
 165                throw new ArgumentNullException("mapping");
 166
 167            if (File.Exists(fileOrServerOrConnection))
 168                throw new NotImplementedException("File names not supported.");
 169
 170            // Is this a decent server name check?
 171            // It assumes that the connection string will have at least 2
 172            // parameters (separated by ';')
 173            if (!fileOrServerOrConnection.Contains(";"))
 174                throw new NotImplementedException("Server name not supported.");
 175
 176            // Assume it's a connection string...
 177            IVendor ivendor = GetVendor(ref fileOrServerOrConnection);
 178
 179            IDbConnection dbConnection = ivendor.CreateDbConnection(fileOrServerOrConnection);
 180            Init(new DatabaseContext(dbConnection), mapping, ivendor);
 181            Profiler.At("END DataContext(string, MappingSource)");
 182        }
 183
 184        /// <summary>
 185        /// Construct DataContext, given a connectionString.
 186        /// To determine which DB type to go against, we look for 'DbLinqProvider=xxx' substring.
 187        /// If not found, we assume that we are dealing with MS Sql Server.
 188        /// 
 189        /// Valid values are names of provider DLLs (or any other DLL containing an IVendor implementation)
 190        /// DbLinqProvider=Mysql
 191        /// DbLinqProvider=Oracle etc.
 192        /// </summary>
 193        /// <param name="connectionString">specifies file or server connection</param>
 194        [DbLinqToDo]
 195        public DataContext(string fileOrServerOrConnection)
 196        {
 197            Profiler.At("START DataContext(string)");
 198            IVendor ivendor = GetVendor(ref fileOrServerOrConnection);
 199
 200            IDbConnection dbConnection = ivendor.CreateDbConnection(fileOrServerOrConnection);
 201            Init(new DatabaseContext(dbConnection), null, ivendor);
 202
 203            Profiler.At("END DataContext(string)");
 204        }
 205
 206        private IVendor GetVendor(ref string connectionString)
 207        {
 208            if (connectionString == null)
 209                throw new ArgumentNullException("connectionString");
 210
 211            Assembly assy;
 212            string vendorClassToLoad;
 213            GetVendorInfo(ref connectionString, out assy, out vendorClassToLoad);
 214
 215            var types =
 216                from type in assy.GetTypes()
 217                where type.Name.ToLowerInvariant() == vendorClassToLoad.ToLowerInvariant() &&
 218                    type.GetInterfaces().Contains(typeof(IVendor)) &&
 219                    type.GetConstructor(Type.EmptyTypes) != null
 220                select type;
 221            if (!types.Any())
 222            {
 223                throw new ArgumentException(string.Format("Found no IVendor class in assembly `{0}' named `{1}' having a default constructor.",
 224                    assy.GetName().Name, vendorClassToLoad));
 225            }
 226            else if (types.Count() > 1)
 227            {
 228                throw new ArgumentException(string.Format("Found too many IVendor classes in assembly `{0}' named `{1}' having a default constructor.",
 229                    assy.GetName().Name, vendorClassToLoad));
 230            }
 231            return (IVendor) Activator.CreateInstance(types.First());
 232        }
 233
 234        private void GetVendorInfo(ref string connectionString, out Assembly assembly, out string typeName)
 235        {
 236            System.Text.RegularExpressions.Regex reProvider
 237                = new System.Text.RegularExpressions.Regex(@"DbLinqProvider=([\w\.]+);?");
 238
 239            string assemblyName = null;
 240            string vendor;
 241            if (!reProvider.IsMatch(connectionString))
 242            {
 243                vendor       = "SqlServer";
 244                assemblyName = "DbLinq.SqlServer";
 245            }
 246            else
 247            {
 248                var match    = reProvider.Match(connectionString);
 249                vendor       = match.Groups[1].Value;
 250                assemblyName = "DbLinq." + vendor;
 251
 252                //plain DbLinq - non MONO: 
 253                //IVendor classes are in DLLs such as "DbLinq.MySql.dll"
 254                if (vendor.Contains("."))
 255                {
 256                    //already fully qualified DLL name?
 257                    throw new ArgumentException("Please provide a short name, such as 'MySql', not '" + vendor + "'");
 258                }
 259
 260                //shorten: "DbLinqProvider=X;Server=Y" -> ";Server=Y"
 261                connectionString = reProvider.Replace(connectionString, "");
 262            }
 263
 264            typeName = vendor + "Vendor";
 265
 266            try
 267            {
 268#if MONO_STRICT
 269                assembly = typeof (DataContext).Assembly; // System.Data.Linq.dll
 270#else
 271                assembly = Assembly.Load(assemblyName);
 272#endif
 273            }
 274            catch (Exception e)
 275            {
 276                throw new ArgumentException(
 277                        string.Format(
 278                            "Unable to load the `{0}' DbLinq vendor within assembly '{1}.dll'.",
 279                            assemblyName, vendor),
 280                        "connectionString", e);
 281            }
 282        }
 283
 284        private void Init(IDatabaseContext databaseContext, MappingSource mappingSource, IVendor vendor)
 285        {
 286            if (databaseContext == null)
 287                throw new ArgumentNullException("databaseContext");
 288
 289            // Yes, .NET throws an NRE for this.  Why it's not ArgumentNullException, I couldn't tell you.
 290            if (databaseContext.Connection.ConnectionString == null)
 291                throw new NullReferenceException();
 292
 293            string connectionString = databaseContext.Connection.ConnectionString;
 294            _VendorProvider = ObjectFactory.Get<IVendorProvider>();
 295            Vendor = vendor ?? 
 296                (connectionString != null ? GetVendor(ref connectionString) : null) ??
 297#if MOBILE
 298                _VendorProvider.FindVendorByProviderType(typeof(DbLinq.Sqlite.SqliteSqlProvider));
 299#else
 300                _VendorProvider.FindVendorByProviderType(typeof(SqlClient.Sql2005Provider));
 301#endif
 302            
 303            DatabaseContext = databaseContext;
 304
 305            MemberModificationHandler = ObjectFactory.Create<IMemberModificationHandler>(); // not a singleton: object is stateful
 306            QueryBuilder = ObjectFactory.Get<IQueryBuilder>();
 307            QueryRunner = ObjectFactory.Get<IQueryRunner>();
 308
 309            //EntityMap = ObjectFactory.Create<IEntityMap>();
 310            identityReaderFactory = ObjectFactory.Get<IIdentityReaderFactory>();
 311
 312            _MappingContext = new MappingContext();
 313
 314            // initialize the mapping information
 315            if (mappingSource == null)
 316                mappingSource = new AttributeMappingSource();
 317            Mapping = mappingSource.GetModel(GetType());
 318        }
 319
 320        /// <summary>
 321        /// Checks if the table is allready mapped or maps it if not.
 322        /// </summary>
 323        /// <param name="tableType">Type of the table.</param>
 324        /// <exception cref="InvalidOperationException">Thrown if the table is not mappable.</exception>
 325        private void CheckTableMapping(Type tableType)
 326        {
 327            //This will throw an exception if the table is not found
 328            if(Mapping.GetTable(tableType) == null)
 329            {
 330                throw new InvalidOperationException("The type '" + tableType.Name + "' is not mapped as a Table.");
 331            }
 332        }
 333
 334        /// <summary>
 335        /// Returns a Table for the type TEntity.
 336        /// </summary>
 337        /// <exception cref="InvalidOperationException">If the type TEntity is not mappable as a Table.</exception>
 338        /// <typeparam name="TEntity">The table type.</typeparam>
 339        public Table<TEntity> GetTable<TEntity>() where TEntity : class
 340        {
 341            return (Table<TEntity>)GetTable(typeof(TEntity));
 342        }
 343
 344        /// <summary>
 345        /// Returns a Table for the given type.
 346        /// </summary>
 347        /// <param name="type">The table type.</param>
 348        /// <exception cref="InvalidOperationException">If the type is not mappable as a Table.</exception>
 349        public ITable GetTable(Type type)
 350        {
 351            Profiler.At("DataContext.GetTable(typeof({0}))", type != null ? type.Name : null);
 352            ITable tableExisting;
 353            if (_tableMap.TryGetValue(type, out tableExisting))
 354                return tableExisting;
 355
 356            //Check for table mapping
 357            CheckTableMapping(type);
 358
 359            var tableNew = Activator.CreateInstance(
 360                              typeof(Table<>).MakeGenericType(type)
 361                              , BindingFlags.NonPublic | BindingFlags.Instance
 362                              , null
 363                              , new object[] { this }
 364                              , System.Globalization.CultureInfo.CurrentCulture) as ITable;
 365
 366            _tableMap[type] = tableNew;
 367            return tableNew;
 368        }
 369
 370        public void SubmitChanges()
 371        {
 372            SubmitChanges(ConflictMode.FailOnFirstConflict);
 373        }
 374
 375        /// <summary>
 376        /// Pings database
 377        /// </summary>
 378        /// <returns></returns>
 379        public bool DatabaseExists()
 380        {
 381            try
 382            {
 383                return Vendor.Ping(this);
 384            }
 385            catch (Exception)
 386            {
 387                return false;
 388            }
 389        }
 390
 391        /// <summary>
 392        /// Commits all pending changes to database 
 393        /// </summary>
 394        /// <param name="failureMode"></param>
 395        public virtual void SubmitChanges(ConflictMode failureMode)
 396        {
 397            if (this.objectTrackingEnabled == false)
 398                throw new InvalidOperationException("Object tracking is not enabled for the current data context instance.");
 399            using (DatabaseContext.OpenConnection()) //ConnMgr will close connection for us
 400            {
 401                if (Transaction != null)
 402                    SubmitChangesImpl(failureMode);
 403                else
 404                {
 405                    using (IDbTransaction transaction = DatabaseContext.CreateTransaction())
 406                    {
 407                        try
 408                        {
 409                            Transaction = (DbTransaction) transaction;
 410                            SubmitChangesImpl(failureMode);
 411                            // TODO: handle conflicts (which can only occur when concurrency mode is implemented)
 412                            transaction.Commit();
 413                        }
 414                        finally
 415                        {
 416                            Transaction = null;
 417                        }
 418                    }
 419                }
 420            }
 421        }
 422
 423        void SubmitChangesImpl(ConflictMode failureMode)
 424        {
 425            var queryContext = new QueryContext(this);
 426
 427            // There's no sense in updating an entity when it's going to 
 428            // be deleted in the current transaction, so do deletes first.
 429            foreach (var entityTrack in CurrentTransactionEntities.EnumerateAll().ToList())
 430            {
 431                switch (entityTrack.EntityState)
 432                {
 433                    case EntityState.ToDelete:
 434                        var deleteQuery = QueryBuilder.GetDeleteQuery(entityTrack.Entity, queryContext);
 435                        QueryRunner.Delete(entityTrack.Entity, deleteQuery);
 436
 437                        UnregisterDelete(entityTrack.Entity);
 438                        AllTrackedEntities.RegisterToDelete(entityTrack.Entity);
 439                        AllTrackedEntities.RegisterDeleted(entityTrack.Entity);
 440                        break;
 441                    default:
 442                        // ignore.
 443                        break;
 444                }
 445            }
 446            foreach (var entityTrack in CurrentTransactionEntities.EnumerateAll()
 447                    .Concat(AllTrackedEntities.EnumerateAll())
 448                    .ToList())
 449            {
 450                switch (entityTrack.EntityState)
 451                {
 452                    case EntityState.ToInsert:
 453                        foreach (var toInsert in GetReferencedObjects(entityTrack.Entity))
 454                        {
 455                            InsertEntity(toInsert, queryContext);
 456                        }
 457                        break;
 458                    case EntityState.ToWatch:
 459                        foreach (var toUpdate in GetReferencedObjects(entityTrack.Entity))
 460                        {
 461                            UpdateEntity(toUpdate, queryContext);
 462                        }
 463                        break;
 464                    default:
 465                        throw new ArgumentOutOfRangeException();
 466                }
 467            }
 468        }
 469
 470        private IEnumerable<object> GetReferencedObjects(object value)
 471        {
 472            var values = new EntitySet<object>();
 473            FillReferencedObjects(value, values);
 474            return values;
 475        }
 476
 477        // Breadth-first traversal of an object graph
 478        private void FillReferencedObjects(object parent, EntitySet<object> values)
 479        {
 480            if (parent == null)
 481                return;
 482            var children = new Queue<object>();
 483			children.Enqueue(parent);
 484			while (children.Count > 0)
 485			{
 486                object value = children.Dequeue();
 487                values.Add(value);
 488                IEnumerable<MetaAssociation> associationList = Mapping.GetMetaType(value.GetType()).Associations.Where(a => !a.IsForeignKey);
 489                if (associationList.Any())
 490			    {
 491				    foreach (MetaAssociation association in associationList)
 492                    {
 493                        var memberData = association.ThisMember;
 494                        var entitySetValue = memberData.Member.GetMemberValue(value);
 495
 496                        if (entitySetValue != null)
 497                        {
 498						    var hasLoadedOrAssignedValues = entitySetValue.GetType().GetProperty("HasLoadedOrAssignedValues");
 499						    if (!((bool)hasLoadedOrAssignedValues.GetValue(entitySetValue, null)))
 500							    continue;   // execution deferred; ignore.
 501						    foreach (var o in ((IEnumerable)entitySetValue))
 502							    children.Enqueue(o);
 503					    }
 504                    }
 505                }
 506			}
 507        }
 508
 509        private void InsertEntity(object entity, QueryContext queryContext)
 510        {
 511            var insertQuery = QueryBuilder.GetInsertQuery(entity, queryContext);
 512            QueryRunner.Insert(entity, insertQuery);
 513            Register(entity);
 514            UpdateReferencedObjects(entity);
 515            MoveToAllTrackedEntities(entity, true);
 516        }
 517
 518        private void UpdateEntity(object entity, QueryContext queryContext)
 519        {
 520            if (!AllTrackedEntities.ContainsReference(entity))
 521                InsertEntity(entity, queryContext);
 522            else if (MemberModificationHandler.IsModified(entity, Mapping))
 523            {
 524                var modifiedMembers = MemberModificationHandler.GetModifiedProperties(entity, Mapping);
 525                var updateQuery = QueryBuilder.GetUpdateQuery(entity, modifiedMembers, queryContext);
 526                QueryRunner.Update(entity, updateQuery, modifiedMembers);
 527
 528                RegisterUpdateAgain(entity);
 529                UpdateReferencedObjects(entity);
 530                MoveToAllTrackedEntities(entity, false);
 531            }
 532        }
 533
 534        private void UpdateReferencedObjects(object root)
 535        {
 536            var metaType = Mapping.GetMetaType(root.GetType());
 537            foreach (var assoc in metaType.Associations)
 538            {
 539                var memberData = assoc.ThisMember;
 540				//This is not correct - AutoSyncing applies to auto-updating columns, such as a TimeStamp, not to foreign key associations, which is always automatically synched
 541				//Confirmed against default .NET l2sql - association columns are always set, even if AutoSync==AutoSync.Never
 542				//if (memberData.Association.ThisKey.Any(m => (m.AutoSync != AutoSync.Always) && (m.AutoSync != sync)))
 543                //    continue;
 544                var oks = memberData.Association.OtherKey.Select(m => m.StorageMember).ToList();
 545                if (oks.Count == 0)
 546                    continue;
 547                var pks = memberData.Association.ThisKey
 548                    .Select(m => m.StorageMember.GetMemberValue(root))
 549                    .ToList();
 550                if (pks.Count != oks.Count)
 551                    throw new InvalidOperationException(
 552                        string.Format("Count of primary keys ({0}) doesn't match count of other keys ({1}).",
 553                            pks.Count, oks.Count));
 554                var members = memberData.Member.GetMemberValue(root) as IEnumerable;
 555                if (members == null)
 556                    continue;
 557                foreach (var member in members)
 558                {
 559                    for (int i = 0; i < pks.Count; ++i)
 560                    {
 561                        oks[i].SetMemberValue(member, pks[i]);
 562                    }
 563                }
 564            }
 565        }
 566
 567        private void MoveToAllTrackedEntities(object entity, bool insert)
 568        {
 569            if (!ObjectTrackingEnabled)
 570                return;
 571            if (CurrentTransactionEntities.ContainsReference(entity))
 572            {
 573                CurrentTransactionEntities.RegisterToDelete(entity);
 574                if (!insert)
 575                    CurrentTransactionEntities.RegisterDeleted(entity);
 576            }
 577            if (!AllTrackedEntities.ContainsReference(entity))
 578            {
 579                var identityReader = _GetIdentityReader(entity.GetType());
 580                AllTrackedEntities.RegisterToWatch(entity, identityReader.GetIdentityKey(entity));
 581            }
 582        }
 583
 584        /// <summary>
 585        /// TODO - allow generated methods to call into stored procedures
 586        /// </summary>
 587        [DBLinqExtended]
 588        internal IExecuteResult _ExecuteMethodCall(DataContext context, System.Reflection.MethodInfo method, params object[] sqlParams)
 589        {
 590            using (DatabaseContext.OpenConnection())
 591            {
 592                System.Data.Linq.IExecuteResult result = Vendor.ExecuteMethodCall(context, method, sqlParams);
 593                return result;
 594            }
 595        }
 596
 597        [DbLinqToDo]
 598        protected IExecuteResult ExecuteMethodCall(object instance, System.Reflection.MethodInfo methodInfo, params object[] parameters)
 599        {
 600            throw new NotImplementedException();
 601        }
 602
 603        #region Identity management
 604
 605        [DBLinqExtended]
 606        internal IIdentityReader _GetIdentityReader(Type t)
 607        {
 608            IIdentityReader identityReader;
 609            if (!identityReaders.TryGetValue(t, out identityReader))
 610            {
 611                identityReader = identityReaderFactory.GetReader(t, this);
 612                identityReaders[t] = identityReader;
 613            }
 614            return identityReader;
 615        }
 616
 617        [DBLinqExtended]
 618        internal object _GetRegisteredEntity(object entity)
 619        {
 620            // TODO: check what is faster: by identity or by ref
 621            var identityReader = _GetIdentityReader(entity.GetType());
 622            var identityKey = identityReader.GetIdentityKey(entity);
 623            if (identityKey == null) // if we don't have an entitykey here, it means that the entity has no PK
 624                return entity;
 625            // even 
 626            var registeredEntityTrack = 
 627                CurrentTransactionEntities.FindByIdentity(identityKey) ??
 628                AllTrackedEntities.FindByIdentity(identityKey);
 629            if (registeredEntityTrack != null)
 630                return registeredEntityTrack.Entity;
 631            return null;
 632        }
 633
 634        //internal object GetRegisteredEntityByKey(IdentityKey identityKey)
 635        //{
 636        //    return EntityMap[identityKey];
 637        //}
 638
 639        /// <summary>
 640        /// Registers an entity in a watch state
 641        /// </summary>
 642        /// <param name="entity"></param>
 643        /// <returns></returns>
 644        [DBLinqExtended]
 645        internal object _GetOrRegisterEntity(object entity)
 646        {
 647            var identityReader = _GetIdentityReader(entity.GetType());
 648            var identityKey = identityReader.GetIdentityKey(entity);
 649            SetEntitySetsQueries(entity);
 650            SetEntityRefQueries(entity);
 651
 652            // if we have no identity, we can't track it
 653            if (identityKey == null)
 654                return entity;
 655
 656            // try to find an already registered entity and return it
 657            var registeredEntityTrack = 
 658                CurrentTransactionEntities.FindByIdentity(identityKey) ??
 659                AllTrackedEntities.FindByIdentity(identityKey);
 660            if (registeredEntityTrack != null)
 661                return registeredEntityTrack.Entity;
 662
 663            // otherwise, register and return
 664            AllTrackedEntities.RegisterToWatch(entity, identityKey);
 665            return entity;
 666        }
 667
 668        readonly IDataMapper DataMapper = ObjectFactory.Get<IDataMapper>();
 669		private void SetEntityRefQueries(object entity)
 670		{
 671            if (!this.deferredLoadingEnabled)
 672                return;
 673
 674            // BUG: This is ignoring External Mappings from XmlMappingSource.
 675
 676			Type thisType = entity.GetType();
 677			IEnumerable<MetaAssociation> associationList = Mapping.GetMetaType(entity.GetType()).Associations.Where(a => a.IsForeignKey);
 678			foreach (MetaAssociation association in associationList)
 679			{
 680				//example of entityRef:Order.Employee
 681				var memberData = association.ThisMember;
 682				Type otherTableType = association.OtherType.Type;
 683				ParameterExpression p = Expression.Parameter(otherTableType, "other");
 684
 685				var otherTable = GetTable(otherTableType);
 686
 687				//ie:EmployeeTerritories.EmployeeID
 688				var foreignKeys = memberData.Association.ThisKey;
 689				BinaryExpression predicate = null;
 690				var otherPKs = memberData.Association.OtherKey;
 691				IEnumerator<MetaDataMember> otherPKEnumerator = otherPKs.GetEnumerator();
 692
 693				if (otherPKs.Count != foreignKeys.Count)
 694					throw new InvalidOperationException("Foreign keys don't match ThisKey");
 695				foreach (MetaDataMember key in foreignKeys)
 696				{
 697					otherPKEnumerator.MoveNext();
 698
 699					var thisForeignKeyProperty = (PropertyInfo)key.Member;
 700					object thisForeignKeyValue = thisForeignKeyProperty.GetValue(entity, null);
 701
 702					if (thisForeignKeyValue != null)
 703					{
 704						BinaryExpression keyPredicate;
 705						if (!(thisForeignKeyProperty.PropertyType.IsNullable()))
 706						{
 707							keyPredicate = Expression.Equal(Expression.MakeMemberAccess(p, otherPKEnumerator.Current.Member),
 708																		Expression.Constant(thisForeignKeyValue));
 709						}
 710						else
 711						{
 712							var ValueProperty = thisForeignKeyProperty.PropertyType.GetProperty("Value");
 713							keyPredicate = Expression.Equal(Expression.MakeMemberAccess(p, otherPKEnumerator.Current.Member),
 714																	 Expression.Constant(ValueProperty.GetValue(thisForeignKeyValue, null)));
 715						}
 716
 717						if (predicate == null)
 718							predicate = keyPredicate;
 719						else
 720							predicate = Expression.And(predicate, keyPredicate);
 721					}
 722				}
 723				IEnumerable query = null;
 724				if (predicate != null)
 725				{
 726					query = GetOtherTableQuery(predicate, p, otherTableType, otherTable) as IEnumerable;
 727					//it would be interesting surround the above query with a .Take(1) expression for performance.
 728				}
 729
 730				// If no separate Storage is specified, use the member directly
 731				MemberInfo storage = memberData.StorageMember;
 732				if (storage == null)
 733					storage = memberData.Member;
 734
 735				 // Check that the storage is a field or a writable property
 736				if (!(storage is FieldInfo) && !(storage is PropertyInfo && ((PropertyInfo)storage).CanWrite)) {
 737					throw new InvalidOperationException(String.Format(
 738						"Member {0}.{1} is not a field nor a writable property",
 739						storage.DeclaringType, storage.Name));
 740				}
 741
 742				Type storageType = storage.GetMemberType();
 743
 744				object entityRefValue = null;
 745				if (query != null)
 746					entityRefValue = Activator.CreateInstance(storageType, query);
 747				else
 748					entityRefValue = Activator.CreateInstance(storageType);
 749
 750				storage.SetMemberValue(entity, entityRefValue);
 751			}
 752		}
 753
 754        /// <summary>
 755        /// This method is executed when the entity is being registered. Each EntitySet property has a internal query that can be set using the EntitySet.SetSource method.
 756        /// Here we set the query source of each EntitySetProperty
 757        /// </summary>
 758        /// <param name="entity"></param>
 759        private void SetEntitySetsQueries(object entity)
 760        {
 761            if (!this.deferredLoadingEnabled)
 762                return;
 763
 764            // BUG: This is ignoring External Mappings from XmlMappingSource.
 765
 766			IEnumerable<MetaAssociation> associationList = Mapping.GetMetaType(entity.GetType()).Associations.Where(a => !a.IsForeignKey);
 767
 768			if (associationList.Any())
 769			{
 770				foreach (MetaAssociation association in associationList)
 771                {
 772					//example of entitySet: Employee.EmployeeTerritories
 773					var memberData = association.ThisMember;
 774					Type otherTableType = association.OtherType.Type;
 775                    ParameterExpression p = Expression.Parameter(otherTableType, "other");
 776
 777                    //other table:EmployeeTerritories
 778                    var otherTable = GetTable(otherTableType);
 779
 780					var otherKeys = memberData.Association.OtherKey;
 781					var thisKeys = memberData.Association.ThisKey;
 782                    if (otherKeys.Count != thisKeys.Count)
 783                        throw new InvalidOperationException("This keys don't match OtherKey");
 784                    BinaryExpression predicate = null;
 785                    IEnumerator<MetaDataMember> thisKeyEnumerator = thisKeys.GetEnumerator();
 786					foreach (MetaDataMember otherKey in otherKeys)
 787                    {
 788                        thisKeyEnumerator.MoveNext();
 789                        //other table member:EmployeeTerritories.EmployeeID
 790						var otherTableMember = (PropertyInfo)otherKey.Member;
 791
 792                        BinaryExpression keyPredicate;
 793                        if (!(otherTableMember.PropertyType.IsNullable()))
 794                        {
 795                            keyPredicate = Expression.Equal(Expression.MakeMemberAccess(p, otherTableMember),
 796                                                                        Expression.Constant(thisKeyEnumerator.Current.Member.GetMemberValue(entity)));
 797                        }
 798                        else
 799                        {
 800                            var ValueProperty = otherTableMember.PropertyType.GetProperty("Value");
 801                            keyPredicate = Expression.Equal(Expression.MakeMemberAccess(
 802                                                                        Expression.MakeMemberAccess(p, otherTableMember),
 803                                                                        ValueProperty),
 804                                                                     Expression.Constant(thisKeyEnumerator.Current.Member.GetMemberValue(entity)));
 805                        }
 806                        if (predicate == null)
 807                            predicate = keyPredicate;
 808                        else
 809                            predicate = Expression.And(predicate, keyPredicate);
 810                    }
 811
 812                    var query = GetOtherTableQuery(predicate, p, otherTableType, otherTable);
 813
 814					var entitySetValue = memberData.Member.GetMemberValue(entity);
 815
 816                    if (entitySetValue == null)
 817                    {
 818						entitySetValue = Activator.CreateInstance(memberData.Member.GetMemberType());
 819						memberData.Member.SetMemberValue(entity, entitySetValue);
 820                    }
 821
 822                    var hasLoadedOrAssignedValues = entitySetValue.GetType().GetProperty("HasLoadedOrAssignedValues");
 823                    if ((bool)hasLoadedOrAssignedValues.GetValue(entitySetValue, null))
 824                        continue;
 825
 826                    var setSourceMethod = entitySetValue.GetType().GetMethod("SetSource");
 827                    setSourceMethod.Invoke(entitySetValue, new[] { query });
 828                    //employee.EmployeeTerritories.SetSource(Table[EmployeesTerritories].Where(other=>other.employeeID="WARTH"))
 829                }
 830            }
 831        }
 832
 833		private static MethodInfo _WhereMethod;
 834        internal object GetOtherTableQuery(Expression predicate, ParameterExpression parameter, Type otherTableType, IQueryable otherTable)
 835        {
 836            if (_WhereMethod == null)
 837                System.Threading.Interlocked.CompareExchange (ref _WhereMethod, typeof(Queryable).GetMethods().First(m => m.Name == "Where"), null);
 838
 839            //predicate: other.EmployeeID== "WARTH"
 840            Expression lambdaPredicate = Expression.Lambda(predicate, parameter);
 841            //lambdaPredicate: other=>other.EmployeeID== "WARTH"
 842
 843			Expression call = Expression.Call(_WhereMethod.MakeGenericMethod(otherTableType), otherTable.Expression, lambdaPredicate);
 844            //Table[EmployeesTerritories].Where(other=>other.employeeID="WARTH")
 845
 846            return otherTable.Provider.CreateQuery(call);
 847        }
 848
 849        #endregion
 850
 851        #region Insert/Update/Delete management
 852
 853        /// <summary>
 854        /// Registers an entity for insert
 855        /// </summary>
 856        /// <param name="entity"></param>
 857        internal void RegisterInsert(object entity)
 858        {
 859            CurrentTransactionEntities.RegisterToInsert(entity);
 860        }
 861
 862        private void DoRegisterUpdate(object entity)
 863        {
 864            if (entity == null)
 865                throw new ArgumentNullException("entity");
 866
 867            if (!this.objectTrackingEnabled)
 868                return;
 869
 870            var identityReader = _GetIdentityReader(entity.GetType());
 871            var identityKey = identityReader.GetIdentityKey(entity);
 872            // if we have no key, we can not watch
 873            if (identityKey == null || identityKey.Keys.Count == 0)
 874                return;
 875            // register entity
 876            AllTrackedEntities.RegisterToWatch(entity, identityKey);
 877        }
 878
 879        /// <summary>
 880        /// Registers an entity for update
 881        /// The entity will be updated only if some of its members have changed after the registration
 882        /// </summary>
 883        /// <param name="entity"></param>
 884        internal void RegisterUpdate(object entity)
 885        {
 886            DoRegisterUpdate(entity);
 887			MemberModificationHandler.Register(entity, Mapping);
 888        }
 889
 890        /// <summary>
 891        /// Registers or re-registers an entity and clears its state
 892        /// </summary>
 893        /// <param name="entity"></param>
 894        /// <returns></returns>
 895        internal object Register(object entity)
 896        {
 897            if (! this.objectTrackingEnabled)
 898                return entity;
 899            var registeredEntity = _GetOrRegisterEntity(entity);
 900            // the fact of registering again clears the modified state, so we're... clear with that
 901            MemberModificationHandler.Register(registeredEntity, Mapping);
 902            return registeredEntity;
 903        }
 904
 905        /// <summary>
 906        /// Registers an entity for update
 907        /// The entity will be updated only if some of its members have changed after the registration
 908        /// </summary>
 909        /// <param name="entity"></param>
 910        /// <param name="entityOriginalState"></param>
 911        internal void RegisterUpdate(object entity, object entityOriginalState)
 912        {
 913            if (!this.objectTrackingEnabled)
 914                return;
 915            DoRegisterUpdate(entity);
 916            MemberModificationHandler.Register(entity, entityOriginalState, Mapping);
 917        }
 918
 919        /// <summary>
 920        /// Clears the current state, and marks the object as clean
 921        /// </summary>
 922        /// <param name="entity"></param>
 923        internal void RegisterUpdateAgain(object entity)
 924        {
 925            if (!this.objectTrackingEnabled)
 926                return;
 927            MemberModificationHandler.ClearModified(entity, Mapping);
 928        }
 929
 930        /// <summary>
 931        /// Registers an entity for delete
 932        /// </summary>
 933        /// <param name="entity"></param>
 934        internal void RegisterDelete(object entity)
 935        {
 936            if (!this.objectTrackingEnabled)
 937                return;
 938            CurrentTransactionEntities.RegisterToDelete(entity);
 939        }
 940
 941        /// <summary>
 942        /// Unregisters entity after deletion
 943        /// </summary>
 944        /// <param name="entity"></param>
 945        internal void UnregisterDelete(object entity)
 946        {
 947            if (!this.objectTrackingEnabled)
 948                return;
 949            CurrentTransactionEntities.RegisterDeleted(entity);
 950        }
 951
 952        #endregion
 953
 954        /// <summary>
 955        /// Changed object determine 
 956        /// </summary>
 957        /// <returns>Lists of inserted, updated, deleted objects</returns>
 958        public ChangeSet GetChangeSet()
 959        {
 960            var inserts = new List<object>();
 961            var updates = new List<object>();
 962            var deletes = new List<object>();
 963            foreach (var entityTrack in CurrentTransactionEntities.EnumerateAll()
 964                    .Concat(AllTrackedEntities.EnumerateAll()))
 965            {
 966                switch (entityTrack.EntityState)
 967                {
 968                    case EntityState.ToInsert:
 969                        inserts.Add(entityTrack.Entity);
 970                        break;
 971                    case EntityState.ToWatch:
 972                        if (MemberModificationHandler.IsModified(entityTrack.Entity, Mapping))
 973                            updates.Add(entityTrack.Entity);
 974                        break;
 975                    case EntityState.ToDelete:
 976                        deletes.Add(entityTrack.Entity);
 977                        break;
 978                    default:
 979                        throw new ArgumentOutOfRangeException();
 980                }
 981            }
 982            return new ChangeSet(inserts, updates, deletes);
 983        }
 984
 985        /// <summary>
 986        /// use ExecuteCommand to call raw SQL
 987        /// </summary>
 988        public int ExecuteCommand(string command, params object[] parameters)
 989        {
 990            var directQuery = QueryBuilder.GetDirectQuery(command, new QueryContext(this));
 991            return QueryRunner.Execute(directQuery, parameters);
 992        }
 993
 994        /// <summary>
 995        /// Execute raw SQL query and return object
 996        /// </summary>
 997        public IEnumerable<TResult> ExecuteQuery<TResult>(string query, params object[] parameters) where TResult : new()
 998        {
 999            if (query == null)
1000                throw new ArgumentNullException("query");
1001
1002            return CreateExecuteQueryEnumerable<TResult>(query, parameters);
1003        }
1004
1005        private IEnumerable<TResult> CreateExecuteQueryEnumerable<TResult>(string query, object[] parameters)
1006            where TResult : new()
1007        {
1008            foreach (TResult result in ExecuteQuery(typeof(TResult), query, parameters))
1009                yield return result;
1010        }
1011
1012        public IEnumerable ExecuteQuery(Type elementType, string query, params object[] parameters)
1013        {
1014            if (elementType == null)
1015                throw new ArgumentNullException("elementType");
1016            if (query == null)
1017                throw new ArgumentNullException("query");
1018
1019            var queryContext = new QueryContext(this);
1020            var directQuery = QueryBuilder.GetDirectQuery(query, queryContext);
1021            return QueryRunner.ExecuteSelect(elementType, directQuery, parameters);
1022        }
1023
1024        /// <summary>
1025        /// Gets or sets the load options
1026        /// </summary>
1027        [DbLinqToDo]
1028		public DataLoadOptions LoadOptions
1029		{
1030			get { throw new NotImplementedException(); }
1031			set { throw new NotImplementedException(); }
1032		}
1033
1034        public DbTransaction Transaction {
1035            get { return (DbTransaction) DatabaseContext.CurrentTransaction; }
1036            set { DatabaseContext.CurrentTransaction = value; }
1037        }
1038
1039        /// <summary>
1040        /// Runs the given reader and returns columns.
1041        /// </summary>
1042        /// <typeparam name="TResult">The type of the result.</typeparam>
1043        /// <param name="reader">The reader.</param>
1044        /// <returns></returns>
1045        public IEnumerable<TResult> Translate<TResult>(DbDataReader reader)
1046        {
1047            if (reader == null)
1048                throw new ArgumentNullException("reader");
1049            return CreateTranslateIterator<TResult>(reader);
1050        }
1051
1052        IEnumerable<TResult> CreateTranslateIterator<TResult>(DbDataReader reader)
1053        {
1054            foreach (TResult result in Translate(typeof(TResult), reader))
1055                yield return result;
1056        }
1057
1058        public IMultipleResults Translate(DbDataReader reader)
1059        {
1060            throw new NotImplementedException();
1061        }
1062
1063        public IEnumerable Translate(Type elementType, DbDataReader reader)
1064        {
1065            if (elementType == null)
1066                throw new ArgumentNullException("elementType");
1067            if (reader == null)
1068                throw new ArgumentNullException("reader");
1069
1070            return QueryRunner.EnumerateResult(elementType, reader, this);
1071        }
1072
1073        public void Dispose()
1074        {
1075            //connection closing should not be done here.
1076            //read: http://msdn2.microsoft.com/en-us/library/bb292288.aspx
1077
1078			//We own the instance of MemberModificationHandler - we must unregister listeners of entities we attached to
1079			MemberModificationHandler.UnregisterAll();
1080        }
1081
1082        [DbLinqToDo]
1083        protected virtual void Dispose(bool disposing)
1084        {
1085            throw new NotImplementedException();
1086        }
1087
1088        /// <summary>
1089        /// Creates a IDbDataAdapter. Used internally by Vendors
1090        /// </summary>
1091        /// <returns></returns>
1092        internal IDbDataAdapter CreateDataAdapter()
1093        {
1094            return DatabaseContext.CreateDataAdapter();
1095        }
1096
1097        /// <summary>
1098        /// Sets a TextWriter where generated SQL commands are written
1099        /// </summary>
1100        public TextWriter Log { get; set; }
1101
1102        /// <summary>
1103        /// Writes text on Log (if not null)
1104        /// Internal helper
1105        /// </summary>
1106        /// <param name="text"></param>
1107        internal void WriteLog(string text)
1108        {
1109            if (Log != null)
1110                Log.WriteLine(text);
1111        }
1112
1113        /// <summary>
1114        /// Write an IDbCommand to Log (if non null)
1115        /// </summary>
1116        /// <param name="command"></param>
1117        internal void WriteLog(IDbCommand command)
1118        {
1119            if (Log != null)
1120            {
1121                Log.WriteLine(command.CommandText);
1122                foreach (IDbDataParameter parameter in command.Parameters)
1123                    WriteLog(parameter);
1124                Log.Write("--");
1125                Log.Write(" Context: {0}", Vendor.VendorName);
1126                Log.Write(" Model: {0}", Mapping.GetType().Name);
1127                Log.Write(" Build: {0}", Assembly.GetExecutingAssembly().GetName().Version);
1128                Log.WriteLine();
1129            }
1130        }
1131
1132        /// <summary>
1133        /// Writes and IDbDataParameter to Log (if non null)
1134        /// </summary>
1135        /// <param name="parameter"></param>
1136        internal void WriteLog(IDbDataParameter parameter)
1137        {
1138            if (Log != null)
1139            {
1140                // -- @p0: Input Int (Size = 0; Prec = 0; Scale = 0) [2]
1141                // -- <name>: <direction> <type> (...) [<value>]
1142                Log.WriteLine("-- {0}: {1} {2} (Size = {3}; Prec = {4}; Scale = {5}) [{6}]",
1143                    parameter.ParameterName, parameter.Direction, parameter.DbType,
1144                    parameter.Size, parameter.Precision, parameter.Scale, parameter.Value);
1145            }
1146        }
1147
1148        public bool ObjectTrackingEnabled
1149        {
1150            get { return this.objectTrackingEnabled; }
1151            set 
1152            {
1153                if (this.currentTransactionEntities != null && value != this.objectTrackingEnabled)
1154                    throw new InvalidOperationException("Data context options cannot be modified after results have been returned from a query.");
1155                this.objectTrackingEnabled = value;
1156            }
1157        }
1158
1159        [DbLinqToDo]
1160        public int CommandTimeout
1161        {
1162            get { throw new NotImplementedException(); }
1163            set { throw new NotImplementedException(); }
1164        }
1165
1166        public bool DeferredLoadingEnabled
1167        {
1168            get { return this.deferredLoadingEnabled; }
1169            set
1170            {
1171                if (this.currentTransactionEntities != null && value != this.deferredLoadingEnabled)
1172                    throw new InvalidOperationException("Data context options cannot be modified after results have been returned from a query.");
1173                this.deferredLoadingEnabled = value;
1174            }
1175        }
1176
1177        [DbLinqToDo]
1178        public ChangeConflictCollection ChangeConflicts
1179        {
1180            get { throw new NotImplementedException(); }
1181        }
1182
1183        [DbLinqToDo]
1184        public DbCommand GetCommand(IQueryable query)
1185        {
1186            DbCommand dbCommand = GetIDbCommand(query) as DbCommand;
1187            if (dbCommand == null)
1188                throw new InvalidOperationException();
1189
1190            return dbCommand;
1191        }
1192
1193        [DBLinqExtended]
1194        public IDbCommand GetIDbCommand(IQueryable query)
1195        {
1196            if (query == null)
1197                throw new ArgumentNullException("query");
1198
1199            var qp = query.Provider as QueryProvider;
1200            if (qp == null)
1201                throw new InvalidOperationException();
1202
1203            if (qp.ExpressionChain.Expressions.Count == 0)
1204                qp.ExpressionChain.Expressions.Add(CreateDefaultQuery(query));
1205
1206            return qp.GetQuery(null).GetCommand().Command;
1207        }
1208
1209        private Expression CreateDefaultQuery(IQueryable query)
1210        {
1211            // Manually create the expression tree for: IQueryable<TableType>.Select(e => e)
1212            var identityParameter = Expression.Parameter(query.ElementType, "e");
1213            var identityBody = Expression.Lambda(
1214                typeof(Func<,>).MakeGenericType(query.ElementType, query.ElementType),
1215                identityParameter,
1216                new[] { identityParameter }
1217            );
1218
1219            return Expression.Call(
1220                typeof(Queryable),
1221                "Select",
1222                new[] { query.El

Large files files are truncated, but you can click here to view the full file