/src/EntityFramework/Migrations/IDbSetExtensions.cs

# · C# · 170 lines · 119 code · 24 blank · 27 comment · 21 complexity · 965bd15120187b683a572ff92efeed06 MD5 · raw file

  1. namespace System.Data.Entity.Migrations
  2. {
  3. using System.Collections.Generic;
  4. using System.Data.Entity.Internal.Linq;
  5. using System.Data.Entity.ModelConfiguration.Utilities;
  6. using System.Data.Entity.Resources;
  7. using System.Data.Entity.Utilities;
  8. using System.Diagnostics.CodeAnalysis;
  9. using System.Diagnostics.Contracts;
  10. using System.Linq;
  11. using System.Linq.Expressions;
  12. using System.Reflection;
  13. /// <summary>
  14. /// A set of extension methods for <see cref = "IDbSet{TEntity}" />
  15. /// </summary>
  16. public static class IDbSetExtensions
  17. {
  // Reflection flags used when looking up key properties on an entity type:
  // instance members of any accessibility (public and non-public).
  private const BindingFlags KeyPropertyBindingFlags =
      BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance;
  /// <summary>
  /// Adds or updates entities by key when SaveChanges is called. Equivalent to an "upsert" operation
  /// from database terminology.
  /// This method can be useful when seeding data using Migrations.
  /// </summary>
  /// <typeparam name="TEntity">The type of the entities held by the set.</typeparam>
  /// <param name="set">The set to which the entities are added or in which they are updated.</param>
  /// <param name="entities">The entities to add or update.</param>
  /// <remarks>
  /// When <paramref name="set"/> is not a <see cref="DbSet{TEntity}"/> (for example a custom or fake
  /// IDbSet implementation), this method looks up a public method named "AddOrUpdate" on the set's
  /// runtime type that accepts a <typeparamref name="TEntity"/> array and invokes it via reflection.
  /// If no such method is found, the error created by Error.UnableToDispatchAddOrUpdate is thrown.
  /// </remarks>
  public static void AddOrUpdate<TEntity>(
      this IDbSet<TEntity> set, params TEntity[] entities)
      where TEntity : class
  {
      Contract.Requires(set != null);
      Contract.Requires(entities != null);

      var concreteSet = set as DbSet<TEntity>;

      if (concreteSet == null)
      {
          // Not a real DbSet: dispatch to an AddOrUpdate method declared on the implementation itself.
          var runtimeType = set.GetType();
          var dispatchTarget = runtimeType.GetMethod("AddOrUpdate", new[] { typeof(TEntity[]) });

          if (dispatchTarget == null)
          {
              throw Error.UnableToDispatchAddOrUpdate(runtimeType);
          }

          dispatchTarget.Invoke(set, new[] { entities });
          return;
      }

      // Real DbSet: identify existing rows by the entity's key properties.
      var adapter = (IInternalSetAdapter)concreteSet;
      var backingSet = (InternalSet<TEntity>)adapter.InternalSet;
      var keyPaths = GetKeyProperties(typeof(TEntity), backingSet);

      concreteSet.AddOrUpdate(keyPaths, entities);
  }
  /// <summary>
  /// Adds or updates entities by a custom identification expression when SaveChanges is called.
  /// Equivalent to an "upsert" operation from database terminology.
  /// This method can be useful when seeding data using Migrations.
  /// </summary>
  /// <typeparam name="TEntity">The type of the entities held by the set.</typeparam>
  /// <param name="set">The set to which the entities are added or in which they are updated.</param>
  /// <param name="identifierExpression">
  /// An expression specifying the properties that should be used when determining
  /// whether an Add or Update operation should be performed.
  /// </param>
  /// <param name="entities">The entities to add or update.</param>
  /// <remarks>
  /// When <paramref name="set"/> is not a <see cref="DbSet{TEntity}"/> (for example a custom or fake
  /// IDbSet implementation), this method looks up a public method named "AddOrUpdate" on the set's
  /// runtime type that accepts the identifier expression and a <typeparamref name="TEntity"/> array,
  /// and invokes it via reflection. If no such method is found, the error created by
  /// Error.UnableToDispatchAddOrUpdate is thrown.
  /// </remarks>
  [SuppressMessage("Microsoft.Design", "CA1006:DoNotNestGenericTypesInMemberSignatures")]
  public static void AddOrUpdate<TEntity>(
      this IDbSet<TEntity> set, Expression<Func<TEntity, object>> identifierExpression, params TEntity[] entities)
      where TEntity : class
  {
      Contract.Requires(set != null);
      Contract.Requires(identifierExpression != null);
      Contract.Requires(entities != null);

      var concreteSet = set as DbSet<TEntity>;

      if (concreteSet == null)
      {
          // Not a real DbSet: dispatch to an AddOrUpdate method declared on the implementation itself.
          var runtimeType = set.GetType();
          var signature = new[] { typeof(Expression<Func<TEntity, object>>), typeof(TEntity[]) };
          var dispatchTarget = runtimeType.GetMethod("AddOrUpdate", signature);

          if (dispatchTarget == null)
          {
              throw Error.UnableToDispatchAddOrUpdate(runtimeType);
          }

          dispatchTarget.Invoke(set, new object[] { identifierExpression, entities });
          return;
      }

      // Real DbSet: translate the expression into the list of properties used for matching.
      var matchPaths = identifierExpression.GetPropertyAccessList();

      concreteSet.AddOrUpdate(matchPaths, entities);
  }
  /// <summary>
  /// Shared implementation of the public AddOrUpdate overloads. For each entity, queries
  /// <paramref name="set"/> for a single existing entity whose identifying property values equal
  /// those of the supplied entity. If one is found, its key values are copied onto the supplied
  /// entity and the supplied entity's values are then applied to the existing one; otherwise the
  /// supplied entity is added to the set.
  /// </summary>
  /// <typeparam name="TEntity">The type of the entities held by the set.</typeparam>
  /// <param name="set">The set that is queried and to which new entities are added.</param>
  /// <param name="identifyingProperties">
  /// The property paths whose values decide whether an entity already exists. The last property of
  /// each path is compared directly against the entity parameter.
  /// </param>
  /// <param name="entities">The entities to add or update.</param>
  /// <remarks>
  /// The lookup uses SingleOrDefault, so the identifying properties are expected to match at most
  /// one existing entity.
  /// </remarks>
  private static void AddOrUpdate<TEntity>(
      this DbSet<TEntity> set, IEnumerable<PropertyPath> identifyingProperties, params TEntity[] entities)
      where TEntity : class
  {
      Contract.Requires(set != null);
      Contract.Requires(identifyingProperties != null);
      Contract.Requires(entities != null);

      var internalSet = (InternalSet<TEntity>)((IInternalSetAdapter)set).InternalSet;
      var keyProperties = GetKeyProperties(typeof(TEntity), internalSet);
      var parameter = Expression.Parameter(typeof(TEntity));

      foreach (var entity in entities)
      {
          // Build "p.A == a && p.B == b ..." from the identifying properties of this entity.
          var matchExpression
              = identifyingProperties.Select(
                  pi =>
                      {
                          var property = pi.Last();

                          // The constant must carry the property's declared type. An untyped constant
                          // takes the runtime type of the boxed value (or object for null), which makes
                          // Expression.Equal throw for nullable value-type properties and for null
                          // reference values.
                          return Expression.Equal(
                              Expression.Property(parameter, property),
                              Expression.Constant(property.GetValue(entity, null), property.PropertyType));
                      })
                  .Aggregate<BinaryExpression, Expression>(
                      null,
                      (current, predicate)
                      => (current == null)
                             ? predicate
                             : Expression.AndAlso(current, predicate));

          var existing
              = set.SingleOrDefault(Expression.Lambda<Func<TEntity, bool>>(matchExpression, new[] { parameter }));

          if (existing != null)
          {
              // Carry the existing key values over to the incoming entity first, so that applying
              // the incoming values below does not overwrite the existing entity's key.
              foreach (var keyProperty in keyProperties)
              {
                  keyProperty.Single().SetValue(entity, keyProperty.Single().GetValue(existing, null), null);
              }

              internalSet.InternalContext.Owner.Entry(existing).CurrentValues.SetValues(entity);
          }
          else
          {
              internalSet.Add(entity);
          }
      }
  }
  /// <summary>
  /// Produces one <see cref="PropertyPath"/> per key member of the entity set that the context
  /// associates with <typeparamref name="TEntity"/>. Each path wraps the property of
  /// <paramref name="entityType"/> that has the key member's name, looked up with
  /// KeyPropertyBindingFlags.
  /// </summary>
  /// <typeparam name="TEntity">The entity type whose entity set is inspected.</typeparam>
  /// <param name="entityType">The CLR type on which the key properties are resolved by name.</param>
  /// <param name="internalSet">The internal set whose context supplies the entity set metadata.</param>
  /// <returns>A lazily evaluated sequence of property paths, one per key member.</returns>
  private static IEnumerable<PropertyPath> GetKeyProperties<TEntity>(
      Type entityType, InternalSet<TEntity> internalSet)
      where TEntity : class
  {
      Contract.Requires(entityType != null);
      Contract.Requires(internalSet != null);

      var context = internalSet.InternalContext;
      var entitySet = context.GetEntitySetAndBaseTypeForType(typeof(TEntity)).EntitySet;
      var keyMembers = entitySet.ElementType.KeyMembers;

      return from keyMember in keyMembers
             select new PropertyPath(entityType.GetProperty(keyMember.Name, KeyPropertyBindingFlags));
  }
  145. }
  146. }