PageRenderTime 50ms CodeModel.GetById 19ms RepoModel.GetById 0ms app.codeStats 1ms

/src/PowerTools/Handlers/ReverseEngineerCodeFirstHandler.cs

#
C# | 364 lines | 286 code | 58 blank | 20 comment | 30 complexity | c91d1c020b1d52bf837187b83916b3bb MD5 | raw file
  1. namespace Microsoft.DbContextPackage.Handlers
  2. {
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Configuration;
  6. using System.Data.Common;
  7. using System.Data.Entity.Design;
  8. using System.Data.Entity.Design.PluralizationServices;
  9. using System.Data.Metadata.Edm;
  10. using System.Data.SqlClient;
  11. using System.Diagnostics.Contracts;
  12. using System.Globalization;
  13. using System.IO;
  14. using System.Linq;
  15. using System.Text;
  16. using System.Xml;
  17. using Microsoft.DbContextPackage.Extensions;
  18. using Microsoft.DbContextPackage.Resources;
  19. using Microsoft.DbContextPackage.Utilities;
  20. using Microsoft.VisualStudio.Data.Core;
  21. using Microsoft.VisualStudio.Data.Services;
  22. using Microsoft.VisualStudio.Shell;
  23. using Project = EnvDTE.Project;
  24. internal class ReverseEngineerCodeFirstHandler
  25. {
  26. private static readonly IEnumerable<EntityStoreSchemaFilterEntry> _storeMetadataFilters = new[]
  27. {
  28. new EntityStoreSchemaFilterEntry(null, null, "EdmMetadata", EntityStoreSchemaFilterObjectTypes.Table, EntityStoreSchemaFilterEffect.Exclude),
  29. new EntityStoreSchemaFilterEntry(null, null, "__MigrationHistory", EntityStoreSchemaFilterObjectTypes.Table, EntityStoreSchemaFilterEffect.Exclude)
  30. };
  31. private readonly DbContextPackage _package;
  32. public ReverseEngineerCodeFirstHandler(DbContextPackage package)
  33. {
  34. Contract.Requires(package != null);
  35. _package = package;
  36. }
  37. public void ReverseEngineerCodeFirst(Project project)
  38. {
  39. Contract.Requires(project != null);
  40. try
  41. {
  42. // Show dialog with SqlClient selected by default
  43. var dialogFactory = _package.GetService<IVsDataConnectionDialogFactory>();
  44. var dialog = dialogFactory.CreateConnectionDialog();
  45. dialog.AddAllSources();
  46. dialog.SelectedSource = new Guid("067ea0d9-ba62-43f7-9106-34930c60c528");
  47. var dialogResult = dialog.ShowDialog(connect: true);
  48. if (dialogResult != null)
  49. {
  50. // Find connection string and provider
  51. _package.DTE2.StatusBar.Text = Strings.ReverseEngineer_LoadSchema;
  52. var connection = (DbConnection)dialogResult.GetLockedProviderObject();
  53. var connectionString = connection.ConnectionString;
  54. var providerManager = (IVsDataProviderManager)Package.GetGlobalService(typeof(IVsDataProviderManager));
  55. IVsDataProvider dp;
  56. providerManager.Providers.TryGetValue(dialogResult.Provider, out dp);
  57. var providerInvariant = (string)dp.GetProperty("InvariantName");
  58. // Load store schema
  59. var storeGenerator = new EntityStoreSchemaGenerator(providerInvariant, connectionString, "dbo");
  60. storeGenerator.GenerateForeignKeyProperties = true;
  61. var errors = storeGenerator.GenerateStoreMetadata(_storeMetadataFilters).Where(e => e.Severity == EdmSchemaErrorSeverity.Error);
  62. errors.HandleErrors(Strings.ReverseEngineer_SchemaError);
  63. // Generate default mapping
  64. _package.DTE2.StatusBar.Text = Strings.ReverseEngineer_GenerateMapping;
  65. var contextName = connection.Database.Replace(" ", string.Empty).Replace(".", string.Empty) + "Context";
  66. var modelGenerator = new EntityModelSchemaGenerator(storeGenerator.EntityContainer, "DefaultNamespace", contextName);
  67. modelGenerator.PluralizationService = PluralizationService.CreateService(new CultureInfo("en"));
  68. modelGenerator.GenerateForeignKeyProperties = true;
  69. modelGenerator.GenerateMetadata();
  70. // Pull out info about types to be generated
  71. var entityTypes = modelGenerator.EdmItemCollection.OfType<EntityType>().ToArray();
  72. var mappings = new EdmMapping(modelGenerator, storeGenerator.StoreItemCollection);
  73. // Find the project to add the code to
  74. var vsProject = (VSLangProj.VSProject)project.Object;
  75. var projectDirectory = new FileInfo(project.FileName).Directory;
  76. var projectNamespace = (string)project.Properties.Item("RootNamespace").Value;
  77. var references = vsProject.References.Cast<VSLangProj.Reference>();
  78. if (!references.Any(r => r.Name == "EntityFramework"))
  79. {
  80. // Add EF References
  81. _package.DTE2.StatusBar.Text = Strings.ReverseEngineer_InstallEntityFramework;
  82. try
  83. {
  84. project.InstallPackage("EntityFramework");
  85. }
  86. catch (Exception ex)
  87. {
  88. _package.LogError(Strings.ReverseEngineer_InstallEntityFrameworkError, ex);
  89. }
  90. }
  91. // Generate Entity Classes and Mappings
  92. _package.DTE2.StatusBar.Text = Strings.ReverseEngineer_GenerateClasses;
  93. var templateProcessor = new TemplateProcessor(project);
  94. var modelsNamespace = projectNamespace + ".Models";
  95. var modelsDirectory = Path.Combine(projectDirectory.FullName, "Models");
  96. var mappingNamespace = modelsNamespace + ".Mapping";
  97. var mappingDirectory = Path.Combine(modelsDirectory, "Mapping");
  98. var entityFrameworkVersion = GetEntityFrameworkVersion(references);
  99. foreach (var entityType in entityTypes)
  100. {
  101. // Generate the code file
  102. var entityHost = new EfTextTemplateHost
  103. {
  104. EntityType = entityType,
  105. EntityContainer = modelGenerator.EntityContainer,
  106. Namespace = modelsNamespace,
  107. ModelsNamespace = modelsNamespace,
  108. MappingNamespace = mappingNamespace,
  109. EntityFrameworkVersion = entityFrameworkVersion,
  110. TableSet = mappings.EntityMappings[entityType].Item1,
  111. PropertyToColumnMappings = mappings.EntityMappings[entityType].Item2,
  112. ManyToManyMappings = mappings.ManyToManyMappings
  113. };
  114. var entityContents = templateProcessor.Process(Templates.EntityTemplate, entityHost);
  115. var filePath = Path.Combine(modelsDirectory, entityType.Name + entityHost.FileExtension);
  116. project.AddNewFile(filePath, entityContents);
  117. var mappingHost = new EfTextTemplateHost
  118. {
  119. EntityType = entityType,
  120. EntityContainer = modelGenerator.EntityContainer,
  121. Namespace = mappingNamespace,
  122. ModelsNamespace = modelsNamespace,
  123. MappingNamespace = mappingNamespace,
  124. EntityFrameworkVersion = entityFrameworkVersion,
  125. TableSet = mappings.EntityMappings[entityType].Item1,
  126. PropertyToColumnMappings = mappings.EntityMappings[entityType].Item2,
  127. ManyToManyMappings = mappings.ManyToManyMappings
  128. };
  129. var mappingContents = templateProcessor.Process(Templates.MappingTemplate, mappingHost);
  130. var mappingFilePath = Path.Combine(mappingDirectory, entityType.Name + "Map" + mappingHost.FileExtension);
  131. project.AddNewFile(mappingFilePath, mappingContents);
  132. }
  133. // Generate Context
  134. _package.DTE2.StatusBar.Text = Strings.ReverseEngineer_GenerateContext;
  135. var contextHost = new EfTextTemplateHost
  136. {
  137. EntityContainer = modelGenerator.EntityContainer,
  138. Namespace = modelsNamespace,
  139. ModelsNamespace = modelsNamespace,
  140. MappingNamespace = mappingNamespace,
  141. EntityFrameworkVersion = entityFrameworkVersion
  142. };
  143. var contextContents = templateProcessor.Process(Templates.ContextTemplate, contextHost);
  144. var contextFilePath = Path.Combine(modelsDirectory, modelGenerator.EntityContainer.Name + contextHost.FileExtension);
  145. var contextItem = project.AddNewFile(contextFilePath, contextContents);
  146. AddConnectionStringToConfigFile(project, connectionString, providerInvariant, modelGenerator.EntityContainer.Name);
  147. if (contextItem != null)
  148. {
  149. // Open context class when done
  150. _package.DTE2.ItemOperations.OpenFile(contextFilePath);
  151. }
  152. _package.DTE2.StatusBar.Text = Strings.ReverseEngineer_Complete;
  153. }
  154. }
  155. catch (Exception exception)
  156. {
  157. _package.LogError(Strings.ReverseEngineer_Error, exception);
  158. }
  159. }
  160. private static Version GetEntityFrameworkVersion(IEnumerable<VSLangProj.Reference> references)
  161. {
  162. var entityFrameworkReference = references.FirstOrDefault(r => r.Name == "EntityFramework");
  163. if (entityFrameworkReference != null)
  164. {
  165. return new Version(entityFrameworkReference.Version);
  166. }
  167. return null;
  168. }
  169. private static void AddConnectionStringToConfigFile(Project project, string connectionString, string providerInvariant, string connectionStringName)
  170. {
  171. Contract.Requires(project != null);
  172. Contract.Requires(!string.IsNullOrWhiteSpace(providerInvariant));
  173. Contract.Requires(!string.IsNullOrWhiteSpace(connectionStringName));
  174. // Find App.config or Web.config
  175. var configFilePath = Path.Combine(
  176. project.GetProjectDir(),
  177. project.IsWebProject()
  178. ? "Web.config"
  179. : "App.config");
  180. // Either load up the existing file or create a blank file
  181. var config = ConfigurationManager.OpenMappedExeConfiguration(
  182. new ExeConfigurationFileMap { ExeConfigFilename = configFilePath },
  183. ConfigurationUserLevel.None);
  184. // Find or create the connectionStrings section
  185. var connectionStringSettings = config.ConnectionStrings
  186. .ConnectionStrings
  187. .Cast<ConnectionStringSettings>()
  188. .FirstOrDefault(css => css.Name == connectionStringName);
  189. if (connectionStringSettings == null)
  190. {
  191. connectionStringSettings = new ConnectionStringSettings
  192. {
  193. Name = connectionStringName
  194. };
  195. config.ConnectionStrings
  196. .ConnectionStrings
  197. .Add(connectionStringSettings);
  198. }
  199. // Add in the new connection string
  200. connectionStringSettings.ProviderName = providerInvariant;
  201. connectionStringSettings.ConnectionString = FixUpConnectionString(connectionString, providerInvariant);
  202. project.DTE.SourceControl.CheckOutItemIfNeeded(configFilePath);
  203. config.Save();
  204. // Add any new file to the project
  205. project.ProjectItems.AddFromFile(configFilePath);
  206. }
  207. private static string FixUpConnectionString(string connectionString, string providerName)
  208. {
  209. Contract.Requires(!string.IsNullOrWhiteSpace(providerName));
  210. if (providerName != "System.Data.SqlClient")
  211. {
  212. return connectionString;
  213. }
  214. var builder = new SqlConnectionStringBuilder(connectionString)
  215. {
  216. MultipleActiveResultSets = true
  217. };
  218. builder.Remove("Pooling");
  219. return builder.ToString();
  220. }
  221. private class EdmMapping
  222. {
  223. public EdmMapping(EntityModelSchemaGenerator mcGenerator, StoreItemCollection store)
  224. {
  225. Contract.Requires(mcGenerator != null);
  226. Contract.Requires(store != null);
  227. // Pull mapping xml out
  228. var mappingDoc = new XmlDocument();
  229. var mappingXml = new StringBuilder();
  230. using (var textWriter = new StringWriter(mappingXml))
  231. {
  232. mcGenerator.WriteStorageMapping(new XmlTextWriter(textWriter));
  233. }
  234. mappingDoc.LoadXml(mappingXml.ToString());
  235. var entitySets = mcGenerator.EntityContainer.BaseEntitySets.OfType<EntitySet>();
  236. var associationSets = mcGenerator.EntityContainer.BaseEntitySets.OfType<AssociationSet>();
  237. var tableSets = store.GetItems<EntityContainer>().Single().BaseEntitySets.OfType<EntitySet>();
  238. this.EntityMappings = BuildEntityMappings(mappingDoc, entitySets, tableSets);
  239. this.ManyToManyMappings = BuildManyToManyMappings(mappingDoc, associationSets, tableSets);
  240. }
  241. public Dictionary<EntityType, Tuple<EntitySet, Dictionary<EdmProperty, EdmProperty>>> EntityMappings { get; set; }
  242. public Dictionary<AssociationType, Tuple<EntitySet, Dictionary<RelationshipEndMember, Dictionary<EdmMember, string>>>> ManyToManyMappings { get; set; }
  243. private static Dictionary<AssociationType, Tuple<EntitySet, Dictionary<RelationshipEndMember, Dictionary<EdmMember, string>>>> BuildManyToManyMappings(XmlDocument mappingDoc, IEnumerable<AssociationSet> associationSets, IEnumerable<EntitySet> tableSets)
  244. {
  245. Contract.Requires(mappingDoc != null);
  246. Contract.Requires(associationSets != null);
  247. Contract.Requires(tableSets != null);
  248. // Build mapping for each association
  249. var mappings = new Dictionary<AssociationType, Tuple<EntitySet, Dictionary<RelationshipEndMember, Dictionary<EdmMember, string>>>>();
  250. var namespaceManager = new XmlNamespaceManager(mappingDoc.NameTable);
  251. namespaceManager.AddNamespace("ef", mappingDoc.ChildNodes[0].NamespaceURI);
  252. foreach (var associationSet in associationSets.Where(a => !a.ElementType.AssociationEndMembers.Where(e => e.RelationshipMultiplicity != RelationshipMultiplicity.Many).Any()))
  253. {
  254. var setMapping = mappingDoc.SelectSingleNode(string.Format("//ef:AssociationSetMapping[@Name=\"{0}\"]", associationSet.Name), namespaceManager);
  255. var tableName = setMapping.Attributes["StoreEntitySet"].Value;
  256. var tableSet = tableSets.Single(s => s.Name == tableName);
  257. var endMappings = new Dictionary<RelationshipEndMember, Dictionary<EdmMember, string>>();
  258. foreach (var end in associationSet.AssociationSetEnds)
  259. {
  260. var propertyToColumnMappings = new Dictionary<EdmMember, string>();
  261. var endMapping = setMapping.SelectSingleNode(string.Format("./ef:EndProperty[@Name=\"{0}\"]", end.Name), namespaceManager);
  262. foreach (XmlNode fk in endMapping.ChildNodes)
  263. {
  264. var propertyName = fk.Attributes["Name"].Value;
  265. var property = end.EntitySet.ElementType.Properties[propertyName];
  266. var columnName = fk.Attributes["ColumnName"].Value;
  267. propertyToColumnMappings.Add(property, columnName);
  268. }
  269. endMappings.Add(end.CorrespondingAssociationEndMember, propertyToColumnMappings);
  270. }
  271. mappings.Add(associationSet.ElementType, Tuple.Create(tableSet, endMappings));
  272. }
  273. return mappings;
  274. }
  275. private static Dictionary<EntityType, Tuple<EntitySet, Dictionary<EdmProperty, EdmProperty>>> BuildEntityMappings(XmlDocument mappingDoc, IEnumerable<EntitySet> entitySets, IEnumerable<EntitySet> tableSets)
  276. {
  277. Contract.Requires(mappingDoc != null);
  278. Contract.Requires(entitySets != null);
  279. Contract.Requires(tableSets != null);
  280. // Build mapping for each type
  281. var mappings = new Dictionary<EntityType, Tuple<EntitySet, Dictionary<EdmProperty, EdmProperty>>>();
  282. var namespaceManager = new XmlNamespaceManager(mappingDoc.NameTable);
  283. namespaceManager.AddNamespace("ef", mappingDoc.ChildNodes[0].NamespaceURI);
  284. foreach (var entitySet in entitySets)
  285. {
  286. // Post VS2010 builds use a different structure for mapping
  287. var setMapping = mappingDoc.ChildNodes[0].NamespaceURI == "http://schemas.microsoft.com/ado/2009/11/mapping/cs"
  288. ? mappingDoc.SelectSingleNode(string.Format("//ef:EntitySetMapping[@Name=\"{0}\"]/ef:EntityTypeMapping/ef:MappingFragment", entitySet.Name), namespaceManager)
  289. : mappingDoc.SelectSingleNode(string.Format("//ef:EntitySetMapping[@Name=\"{0}\"]", entitySet.Name), namespaceManager);
  290. var tableName = setMapping.Attributes["StoreEntitySet"].Value;
  291. var tableSet = tableSets.Single(s => s.Name == tableName);
  292. var propertyMappings = new Dictionary<EdmProperty, EdmProperty>();
  293. foreach (var prop in entitySet.ElementType.Properties)
  294. {
  295. var propMapping = setMapping.SelectSingleNode(string.Format("./ef:ScalarProperty[@Name=\"{0}\"]", prop.Name), namespaceManager);
  296. var columnName = propMapping.Attributes["ColumnName"].Value;
  297. var columnProp = tableSet.ElementType.Properties[columnName];
  298. propertyMappings.Add(prop, columnProp);
  299. }
  300. mappings.Add(entitySet.ElementType, Tuple.Create(tableSet, propertyMappings));
  301. }
  302. return mappings;
  303. }
  304. }
  305. }
  306. }