/Rhino.Etl.Core/Operations/SqlBulkInsertOperation.cs

http://github.com/ayende/rhino-etl · C# · 387 lines · 274 code · 40 blank · 73 comment · 24 complexity · dd7dce5c7a82fabf831468d6fd0f549c MD5 · raw file

  1. using System.Configuration;
  2. using Rhino.Etl.Core.Infrastructure;
  3. namespace Rhino.Etl.Core.Operations
  4. {
  5. using System;
  6. using System.Linq;
  7. using System.Collections.Generic;
  8. using System.Data;
  9. using System.Data.SqlClient;
  10. using DataReaders;
  11. /// <summary>
  12. /// Executes an operation that performs a bulk insert into a SQL Server database
  13. /// </summary>
  14. public abstract class SqlBulkInsertOperation : AbstractDatabaseOperation
  15. {
  /// <summary>
  /// Column names of the destination table mapped to their CLR types.
  /// Exposed through the <see cref="Schema"/> property.
  /// </summary>
  private IDictionary<string, Type> _schema = new Dictionary<string, Type>();

  /// <summary>
  /// Schema handed to the data reader adapter; filled by <see cref="CreateInputSchema"/>
  /// from <see cref="Mappings"/> and the destination schema.
  /// </summary>
  private readonly IDictionary<string, Type> _inputSchema = new Dictionary<string, Type>();

  /// <summary>
  /// Maps a column name in the row (key) to a column name in the database schema (value).
  /// Important: The column name in the database is case sensitive!
  /// </summary>
  public IDictionary<string, string> Mappings = new Dictionary<string, string>();

  /// <summary>Bulk copy instance created for the current execution.</summary>
  private SqlBulkCopy sqlBulkCopy;

  /// <summary>Options passed to the <see cref="SqlBulkCopy"/> constructor.</summary>
  private SqlBulkCopyOptions bulkCopyOptions = SqlBulkCopyOptions.Default;

  // Backing fields for the TargetTable, Timeout, BatchSize and NotifyBatchSize properties.
  private string targetTable;
  private int timeout;
  private int batchSize;
  private int notifyBatchSize;
  /// <summary>
  /// Timeout used by the constructor overloads that do not take an explicit timeout.
  /// The value ends up in <see cref="SqlBulkCopy.BulkCopyTimeout"/>.
  /// </summary>
  private const int DefaultTimeout = 600;

  /// <summary>
  /// Initializes a new instance of the <see cref="SqlBulkInsertOperation"/> class
  /// using a named connection string and the default timeout.
  /// </summary>
  /// <param name="connectionStringName">Name of the connection string.</param>
  /// <param name="targetTable">The target table.</param>
  protected SqlBulkInsertOperation(string connectionStringName, string targetTable)
      : this(ConfigurationManager.ConnectionStrings[connectionStringName], targetTable)
  {
  }

  /// <summary>
  /// Initializes a new instance of the <see cref="SqlBulkInsertOperation"/> class
  /// using the given connection string settings and the default timeout.
  /// </summary>
  /// <param name="connectionStringSettings">Connection string settings to use.</param>
  /// <param name="targetTable">The target table.</param>
  protected SqlBulkInsertOperation(ConnectionStringSettings connectionStringSettings, string targetTable)
      : this(connectionStringSettings, targetTable, DefaultTimeout)
  {
  }

  /// <summary>
  /// Initializes a new instance of the <see cref="SqlBulkInsertOperation"/> class
  /// using a named connection string and an explicit timeout.
  /// </summary>
  /// <param name="connectionStringName">Name of the connection string.</param>
  /// <param name="targetTable">The target table.</param>
  /// <param name="timeout">The timeout.</param>
  protected SqlBulkInsertOperation(string connectionStringName, string targetTable, int timeout)
      : this(ConfigurationManager.ConnectionStrings[connectionStringName], targetTable, timeout)
  {
      // Validation and field assignment are done by the chained constructor;
      // repeating them here was redundant.
  }

  /// <summary>
  /// Initializes a new instance of the <see cref="SqlBulkInsertOperation"/> class.
  /// All other overloads chain to this one.
  /// </summary>
  /// <param name="connectionStringSettings">Connection string settings to use.</param>
  /// <param name="targetTable">The target table; must not be null or empty.</param>
  /// <param name="timeout">The timeout.</param>
  protected SqlBulkInsertOperation(ConnectionStringSettings connectionStringSettings, string targetTable, int timeout)
      : base(connectionStringSettings)
  {
      Guard.Against(string.IsNullOrEmpty(targetTable), "TargetTable was not set, but it is mandatory");
      this.targetTable = targetTable;
      this.timeout = timeout;
  }
  /// <summary>The timeout value of the bulk insert operation.</summary>
  public virtual int Timeout
  {
      get
      {
          return this.timeout;
      }
      set
      {
          this.timeout = value;
      }
  }

  /// <summary>The batch size value of the bulk insert operation.</summary>
  public virtual int BatchSize
  {
      get
      {
          return this.batchSize;
      }
      set
      {
          this.batchSize = value;
      }
  }

  /// <summary>
  /// The number of rows after which a rows-copied notification is raised.
  /// When no positive value has been set, the batch size is returned instead.
  /// </summary>
  public virtual int NotifyBatchSize
  {
      get
      {
          if (this.notifyBatchSize > 0)
          {
              return this.notifyBatchSize;
          }

          return this.batchSize;
      }
      set
      {
          this.notifyBatchSize = value;
      }
  }

  /// <summary>The table or view to bulk load the data into.</summary>
  public string TargetTable
  {
      get
      {
          return this.targetTable;
      }
      set
      {
          this.targetTable = value;
      }
  }
  /// <summary>Gets or sets whether the <see cref="SqlBulkCopyOptions.TableLock"/> option is on.</summary>
  public virtual bool LockTable
  {
      get
      {
          return this.IsOptionOn(SqlBulkCopyOptions.TableLock);
      }
      set
      {
          this.ToggleOption(SqlBulkCopyOptions.TableLock, value);
      }
  }

  /// <summary>Gets or sets whether the <see cref="SqlBulkCopyOptions.KeepIdentity"/> option is on.</summary>
  public virtual bool KeepIdentity
  {
      get
      {
          return this.IsOptionOn(SqlBulkCopyOptions.KeepIdentity);
      }
      set
      {
          this.ToggleOption(SqlBulkCopyOptions.KeepIdentity, value);
      }
  }

  /// <summary>Gets or sets whether the <see cref="SqlBulkCopyOptions.KeepNulls"/> option is on.</summary>
  public virtual bool KeepNulls
  {
      get
      {
          return this.IsOptionOn(SqlBulkCopyOptions.KeepNulls);
      }
      set
      {
          this.ToggleOption(SqlBulkCopyOptions.KeepNulls, value);
      }
  }

  /// <summary>Gets or sets whether the <see cref="SqlBulkCopyOptions.CheckConstraints"/> option is on.</summary>
  public virtual bool CheckConstraints
  {
      get
      {
          return this.IsOptionOn(SqlBulkCopyOptions.CheckConstraints);
      }
      set
      {
          this.ToggleOption(SqlBulkCopyOptions.CheckConstraints, value);
      }
  }

  /// <summary>Gets or sets whether the <see cref="SqlBulkCopyOptions.FireTriggers"/> option is on.</summary>
  public virtual bool FireTriggers
  {
      get
      {
          return this.IsOptionOn(SqlBulkCopyOptions.FireTriggers);
      }
      set
      {
          this.ToggleOption(SqlBulkCopyOptions.FireTriggers, value);
      }
  }
  /// <summary>Switches a bulk copy option on or off according to <paramref name="on"/>.</summary>
  /// <param name="option">The <see cref="SqlBulkCopyOptions"/> value to switch.</param>
  /// <param name="on"><c>true</c> turns <paramref name="option"/> on; <c>false</c> turns it off.</param>
  protected void ToggleOption(SqlBulkCopyOptions option, bool on)
  {
      if (!on)
      {
          TurnOptionOff(option);
          return;
      }

      TurnOptionOn(option);
  }
  /// <summary>Tests whether every bit of <paramref name="option"/> is set in the current bulk copy options.</summary>
  /// <param name="option">The <see cref="SqlBulkCopyOptions"/> value to test for.</param>
  /// <returns><c>true</c> when the option is turned on, otherwise <c>false</c>.</returns>
  protected bool IsOptionOn(SqlBulkCopyOptions option)
  {
      SqlBulkCopyOptions matchingBits = bulkCopyOptions & option;
      return matchingBits == option;
  }
  /// <summary>Sets the bits of <paramref name="option"/> in the current bulk copy options.</summary>
  /// <param name="option">The <see cref="SqlBulkCopyOptions"/> value to turn on.</param>
  protected void TurnOptionOn(SqlBulkCopyOptions option)
  {
      bulkCopyOptions = bulkCopyOptions | option;
  }
  /// <summary>Turns the <paramref name="option"/> off.</summary>
  /// <param name="option">
  /// The <see cref="SqlBulkCopyOptions"/> value to clear. May be a combination of flags;
  /// every bit it contains is cleared, whether or not all of them were set.
  /// </param>
  protected void TurnOptionOff(SqlBulkCopyOptions option)
  {
      // Mask the bits out instead of XOR-ing behind an "all bits set" check:
      // the old form silently did nothing when a combined option was only partially on.
      bulkCopyOptions &= ~option;
  }
  /// <summary>The schema (column name to type) of the destination table or view.</summary>
  public IDictionary<string, Type> Schema
  {
      get
      {
          return this._schema;
      }
      set
      {
          this._schema = value;
      }
  }
  /// <summary>
  /// Prepares the mapping for use. By default every column of the schema
  /// is mapped to itself. This is the preferred approach.
  /// </summary>
  public virtual void PrepareMapping()
  {
      foreach (var schemaColumn in _schema)
      {
          string columnName = schemaColumn.Key;
          Mappings[columnName] = columnName;
      }
  }
  /// <summary>Use the destination Schema and Mappings to create the
  /// operation's input schema so it can build the adapter for sending
  /// to the WriteToServer method.</summary>
  /// <remarks>
  /// Safe to call more than once: existing entries are overwritten rather than
  /// re-added, so a repeated call no longer fails with a duplicate-key exception.
  /// </remarks>
  /// <exception cref="KeyNotFoundException">
  /// A mapping refers to a destination column that is not present in the schema.
  /// </exception>
  public virtual void CreateInputSchema()
  {
      foreach (KeyValuePair<string, string> pair in Mappings)
      {
          // Key is the row column, Value is the destination column whose type we look up.
          _inputSchema[pair.Key] = _schema[pair.Value];
      }
  }
  /// <summary>
  /// Executes this operation: prepares the schema and mappings, then bulk copies
  /// all <paramref name="rows"/> into the target table inside a transaction.
  /// </summary>
  /// <param name="rows">The rows to insert; must not be null.</param>
  /// <returns>An empty sequence; this operation yields no rows.</returns>
  /// <remarks>
  /// This is an iterator method, so nothing runs until the result is enumerated.
  /// The transaction is rolled back when the pipeline reports errors, otherwise committed.
  /// </remarks>
  public override IEnumerable<Row> Execute(IEnumerable<Row> rows)
  {
      Guard.Against<ArgumentException>(rows == null, "SqlBulkInsertOperation cannot accept a null enumerator");
      PrepareSchema();
      PrepareMapping();
      CreateInputSchema();
      using (SqlConnection connection = (SqlConnection)Use.Connection(ConnectionStringSettings))
      using (SqlTransaction transaction = (SqlTransaction) BeginTransaction(connection))
      {
          sqlBulkCopy = CreateSqlBulkCopy(connection, transaction);
          try
          {
              DictionaryEnumeratorDataReader adapter = new DictionaryEnumeratorDataReader(_inputSchema, rows);
              try
              {
                  sqlBulkCopy.WriteToServer(adapter);
              }
              catch (InvalidOperationException)
              {
                  // Try to produce a more descriptive error (missing column, type mismatch,
                  // oversized string); if nothing is found the original exception is rethrown.
                  CompareSqlColumns(connection, transaction, rows);
                  throw;
              }

              if (PipelineExecuter.HasErrors)
              {
                  Warn("Rolling back transaction in {0}", Name);
                  if (transaction != null) transaction.Rollback();
                  Warn("Rolled back transaction in {0}", Name);
              }
              else
              {
                  Debug("Committing {0}", Name);
                  if (transaction != null) transaction.Commit();
                  Debug("Committed {0}", Name);
              }
          }
          finally
          {
              // SqlBulkCopy is disposable; release it on both the success and failure paths.
              sqlBulkCopy.Close();
          }
      }
      yield break;
  }
  /// <summary>
  /// Handler for the bulk copy's rows-copied notification; logs the number of rows copied.
  /// </summary>
  /// <param name="sender">The object raising the event.</param>
  /// <param name="e">Event data carrying the number of rows copied.</param>
  protected virtual void onSqlRowsCopied(object sender, SqlRowsCopiedEventArgs e)
  {
      long copiedSoFar = e.RowsCopied;
      Debug("{0} rows copied to database", copiedSoFar);
  }
  /// <summary>
  /// Implemented by derived classes to describe the target table's schema.
  /// Called at the start of <see cref="Execute"/>, before the mappings are prepared.
  /// </summary>
  protected abstract void PrepareSchema();
  /// <summary>
  /// Builds a <see cref="SqlBulkCopy"/> configured from this operation's options,
  /// batch sizes, column mappings, target table and timeout.
  /// </summary>
  /// <param name="connection">The open connection to copy through.</param>
  /// <param name="transaction">The transaction the copy takes part in.</param>
  /// <returns>The configured bulk copy instance.</returns>
  private SqlBulkCopy CreateSqlBulkCopy(SqlConnection connection, SqlTransaction transaction)
  {
      var bulkCopy = new SqlBulkCopy(connection, bulkCopyOptions, transaction)
      {
          BatchSize = batchSize
      };

      foreach (var mapping in Mappings)
      {
          // Key is the source (row) column, Value is the destination column.
          bulkCopy.ColumnMappings.Add(mapping.Key, mapping.Value);
      }

      bulkCopy.NotifyAfter = NotifyBatchSize;
      bulkCopy.SqlRowsCopied += new SqlRowsCopiedEventHandler(onSqlRowsCopied);
      bulkCopy.DestinationTableName = TargetTable;
      bulkCopy.BulkCopyTimeout = Timeout;
      return bulkCopy;
  }
  /// <summary>
  /// Diagnoses a failed bulk copy by comparing the configured schema and the row data
  /// with the actual columns of the target table.
  /// </summary>
  /// <param name="connection">The open connection used for the bulk copy.</param>
  /// <param name="transaction">The transaction the schema query must run in.</param>
  /// <param name="rows">The rows that were being inserted; enumerated again to check string lengths.</param>
  /// <exception cref="InvalidOperationException">
  /// A schema column is missing from the table, a column type differs,
  /// or a string value is longer than its target column allows.
  /// </exception>
  /// <remarks>Returns normally when no discrepancy is found.</remarks>
  private void CompareSqlColumns(SqlConnection connection, SqlTransaction transaction, IEnumerable<Row> rows)
  {
      // The command is disposable; the original never released it.
      using (var command = connection.CreateCommand())
      {
          // NOTE(review): the table name is spliced into the SQL text; identifiers cannot be
          // parameterized, so TargetTable is assumed to come from trusted configuration.
          command.CommandText = "select * from {TargetTable} where 1=0".Replace("{TargetTable}", TargetTable);
          command.CommandType = CommandType.Text;
          command.Transaction = transaction;
          using (var reader = command.ExecuteReader(CommandBehavior.KeyInfo))
          {
              var schemaTable = reader.GetSchemaTable();
              var databaseColumns = schemaTable.Rows
                  .OfType<DataRow>()
                  .Select(r => new
                  {
                      Name = (string)r["ColumnName"],
                      Type = (Type)r["DataType"],
                      IsNullable = (bool)r["AllowDBNull"],
                      MaxLength = (int)r["ColumnSize"]
                  })
                  .ToArray();

              // Materialize once so the query is not evaluated for both the check and the message.
              var missingColumns = _schema.Keys
                  .Except(databaseColumns.Select(c => c.Name))
                  .ToArray();
              if (missingColumns.Length > 0)
                  throw new InvalidOperationException(
                      "The following columns are not in the target table: " +
                      string.Join(", ", missingColumns));

              var differentColumns = _schema
                  .Select(s => new
                  {
                      Name = s.Key,
                      SchemaType = s.Value,
                      DatabaseType = databaseColumns.Single(c => c.Name == s.Key)
                  })
                  .Where(c => !TypesMatch(c.SchemaType, c.DatabaseType.Type, c.DatabaseType.IsNullable))
                  .ToArray();
              if (differentColumns.Length > 0)
                  throw new InvalidOperationException(
                      "The following columns have different types in the target table: " +
                      string.Join(", ", differentColumns
                          .Select(c => string.Format("{0}: is {1}, but should be {2}{3}.", c.Name,
                              GetFriendlyName(c.SchemaType), GetFriendlyName(c.DatabaseType.Type),
                              (c.DatabaseType.IsNullable ? "?" : "")))
                          .ToArray()
                      ));

              var stringsTooLong =
                  (from column in databaseColumns
                   where column.Type == typeof(string)
                   from mapping in Mappings
                   where mapping.Value == column.Name
                   let name = mapping.Key
                   from row in rows
                   // "as" instead of a hard cast: a non-string value must not abort the
                   // diagnostics with an InvalidCastException that hides the real error.
                   let value = row[name] as string
                   where value != null && value.Length > column.MaxLength
                   select new { column.Name, column.MaxLength, Value = value })
                  .ToArray();
              if (stringsTooLong.Length > 0)
                  throw new InvalidOperationException(
                      "The folowing columns have values too long for the target table: " +
                      string.Join(", ", stringsTooLong
                          // string.Format rather than chained Replace calls, so a column name or
                          // value that happens to contain a placeholder cannot corrupt the message.
                          .Select(s => string.Format("{0}: max length is {1}, value is {2}.",
                              s.Name, s.MaxLength, s.Value))
                          .ToArray()));
          }
      }
  }
  /// <summary>
  /// Produces a readable name for <paramref name="type"/>: the plain type name for
  /// non-generic types, or <c>Name&lt;Arg1, Arg2&gt;</c> for generic types, with the
  /// arguments formatted recursively.
  /// </summary>
  /// <param name="type">The type to describe.</param>
  /// <returns>The friendly name of the type.</returns>
  private static string GetFriendlyName(Type type)
  {
      if (!type.IsGenericType)
      {
          return type.Name;
      }

      string baseName = type.Name;

      // Generic type names carry an arity suffix such as "`1"; cut it off when present.
      int tickPosition = baseName.IndexOf('`');
      if (tickPosition > 0)
      {
          baseName = baseName.Substring(0, tickPosition);
      }

      var argumentNames = new List<string>();
      foreach (Type argument in type.GetGenericArguments())
      {
          argumentNames.Add(GetFriendlyName(argument));
      }

      return baseName + "<" + string.Join(", ", argumentNames.ToArray()) + ">";
  }
  /// <summary>
  /// Decides whether a schema type is compatible with a database column type.
  /// </summary>
  /// <param name="schemaType">The type declared in the operation's schema.</param>
  /// <param name="databaseType">The CLR type reported for the database column.</param>
  /// <param name="isNullable">Whether the database column allows nulls.</param>
  /// <returns>
  /// <c>true</c> when the types are identical, or when the column is nullable and
  /// <paramref name="schemaType"/> is <c>Nullable&lt;databaseType&gt;</c>; otherwise <c>false</c>.
  /// </returns>
  private static bool TypesMatch(Type schemaType, Type databaseType, bool isNullable)
  {
      if (schemaType == databaseType)
          return true;

      if (!isNullable || schemaType == null || databaseType == null)
          return false;

      // Unwrap the schema type instead of constructing Nullable<databaseType>:
      // MakeGenericType throws ArgumentException for reference types such as
      // string or byte[], which used to break the diagnostics for nullable columns.
      return Nullable.GetUnderlyingType(schemaType) == databaseType;
  }
  346. }
  347. }