
/drivers/python/rethinkdb/_import.py

https://gitlab.com/freesoftware/rethinkdb
  1. #!/usr/bin/env python
  2. '''`rethinkdb import` loads data into a RethinkDB cluster'''
  3. from __future__ import print_function
  4. import codecs, collections, csv, ctypes, json, multiprocessing
  5. import optparse, os, re, signal, sys, time, traceback
  6. from . import ast, errors, query, utils_common
  7. try:
  8. unicode
  9. except NameError:
  10. unicode = str
  11. try:
  12. from Queue import Empty, Full
  13. except ImportError:
  14. from queue import Empty, Full
  15. try:
  16. from multiprocessing import Queue, SimpleQueue
  17. except ImportError:
  18. from multiprocessing.queues import Queue, SimpleQueue
  19. #json parameters
  20. json_read_chunk_size = 128 * 1024
  21. json_max_buffer_size = 128 * 1024 * 1024
  22. max_nesting_depth = 100
  23. Error = collections.namedtuple("Error", ["message", "traceback", "file"])
  24. class SourceFile(object):
  25. format = None # set by subclasses
  26. name = None
  27. db = None
  28. table = None
  29. primary_key = None
  30. indexes = None
  31. write_hook = None
  32. source_options = None
  33. start_time = None
  34. end_time = None
  35. query_runner = None
  36. _source = None # open filehandle for the source
  37. # - internal synchronization variables
  38. _bytes_size = None
  39. _bytes_read = None # -1 until started
  40. _total_rows = None # -1 until known
  41. _rows_read = None
  42. _rows_written = None
  43. def __init__(self, source, db, table, query_runner, primary_key=None, indexes=None, write_hook=None, source_options=None):
  44. assert self.format is not None, 'Subclass %s must have a format' % self.__class__.__name__
  45. assert db != 'rethinkdb', "Error: Cannot import tables into the system database: 'rethinkdb'"
  46. # query_runner
  47. assert isinstance(query_runner, utils_common.RetryQuery)
  48. self.query_runner = query_runner
  49. # reporting information
  50. self._bytes_size = multiprocessing.Value(ctypes.c_longlong, -1)
  51. self._bytes_read = multiprocessing.Value(ctypes.c_longlong, -1)
  52. self._total_rows = multiprocessing.Value(ctypes.c_longlong, -1)
  53. self._rows_read = multiprocessing.Value(ctypes.c_longlong, 0)
  54. self._rows_written = multiprocessing.Value(ctypes.c_longlong, 0)
  55. # source
  56. sourceLength = 0
  57. if hasattr(source, 'read'):
  58. if unicode != str or 'b' in source.mode:
  59. # Python2.x or binary file, assume utf-8 encoding
  60. self._source = codecs.getreader("utf-8")(source)
  61. else:
  62. # assume that it has the right encoding on it
  63. self._source = source
  64. else:
  65. try:
  66. self._source = codecs.open(source, mode="r", encoding="utf-8")
  67. except IOError as e:
  68. raise ValueError('Unable to open source file "%s": %s' % (str(source), str(e)))
  69. if hasattr(self._source, 'name') and self._source.name and os.path.isfile(self._source.name):
  70. self._bytes_size.value = os.path.getsize(self._source.name)
  71. if self._bytes_size.value == 0:
  72. raise ValueError('Source is zero-length: %s' % source)
  73. # table info
  74. self.db = db
  75. self.table = table
  76. self.primary_key = primary_key
  77. self.indexes = indexes or []
  78. self.write_hook = write_hook or []
  79. # options
  80. self.source_options = source_options or {}
  81. # name
  82. if hasattr(self._source, 'name') and self._source.name:
  83. self.name = os.path.basename(self._source.name)
  84. else:
  85. self.name = '%s.%s' % (self.db, self.table)
  86. def __hash__(self):
  87. return hash((self.db, self.table))
  88. def get_line(self):
  89. '''Returns a single line from the file'''
  90. raise NotImplementedError('This needs to be implemented on the %s subclass' % self.format)
  91. # - bytes
  92. @property
  93. def bytes_size(self):
  94. return self._bytes_size.value
  95. @bytes_size.setter
  96. def bytes_size(self, value):
  97. self._bytes_size.value = value
  98. @property
  99. def bytes_read(self):
  100. return self._bytes_read.value
  101. @bytes_read.setter
  102. def bytes_read(self, value):
  103. self._bytes_read.value = value
  104. # - rows
  105. @property
  106. def total_rows(self):
  107. return self._total_rows.value
  108. @total_rows.setter
  109. def total_rows(self, value):
  110. self._total_rows.value = value
  111. @property
  112. def rows_read(self):
  113. return self._rows_read.value
  114. @rows_read.setter
  115. def rows_read(self, value):
  116. self._rows_read.value = value
  117. @property
  118. def rows_written(self):
  119. return self._rows_written.value
  120. def add_rows_written(self, increment): # we have multiple writers to coordinate
  121. with self._rows_written.get_lock():
  122. self._rows_written.value += increment
  123. # - percent done
  124. @property
  125. def percentDone(self):
  126. '''return a float between 0 and 1 for a reasonable guess of percentage complete'''
  127. # assume that reading takes 50% of the time and writing the other 50%
  128. completed = 0.0 # of 2.0
  129. # - add read percentage
  130. if self._bytes_size.value <= 0 or self._bytes_size.value <= self._bytes_read.value:
  131. completed += 1.0
  132. elif self._bytes_read.value < 0 and self._total_rows.value >= 0:
  133. # done by rows read
  134. if self._rows_read.value > 0:
  135. completed += float(self._rows_read.value) / float(self._total_rows.value)
  136. else:
  137. # done by bytes read
  138. if self._bytes_read.value > 0:
  139. completed += float(self._bytes_read.value) / float(self._bytes_size.value)
  140. read = completed
  141. # - add written percentage
  142. if self._rows_read.value or self._rows_written.value:
  143. totalRows = float(self._total_rows.value)
  144. if totalRows == 0:
  145. completed += 1.0
  146. elif totalRows < 0:
  147. # a guesstimate
  148. perRowSize = float(self._bytes_read.value) / float(self._rows_read.value)
  149. totalRows = float(self._rows_read.value) + (float(self._bytes_size.value - self._bytes_read.value) / perRowSize)
  150. completed += float(self._rows_written.value) / totalRows
  151. else:
  152. # accurate count
  153. completed += float(self._rows_written.value) / totalRows
  154. # - return the value
  155. return completed * 0.5
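# Worked example of the estimate above: with a 100-byte source of which 40 bytes
# have been read, and 10 of a known 20 total rows written, percentDone returns
# (40/100 + 10/20) * 0.5 = 0.45, i.e. roughly 45% complete.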
  156. def setup_table(self):
  157. '''Ensure that the db, table, and indexes exist and are correct'''
  158. # - ensure the table exists and is ready
  159. self.query_runner(
  160. "create table: %s.%s" % (self.db, self.table),
  161. ast.expr([self.table]).set_difference(query.db(self.db).table_list()).for_each(query.db(self.db).table_create(query.row, **(self.source_options["create_args"] if 'create_args' in self.source_options else {})))
  162. )
  163. self.query_runner("wait for %s.%s" % (self.db, self.table), query.db(self.db).table(self.table).wait(timeout=30))
  164. # - ensure that the primary key on the table is correct
  165. primaryKey = self.query_runner(
  166. "primary key %s.%s" % (self.db, self.table),
  167. query.db(self.db).table(self.table).info()["primary_key"],
  168. )
  169. if self.primary_key is None:
  170. self.primary_key = primaryKey
  171. elif primaryKey != self.primary_key:
  172. raise RuntimeError("Error: table %s.%s primary key was `%s` rather than the expected: %s" % (self.db, self.table, primaryKey, self.primary_key))
  173. def restore_indexes(self, warning_queue):
  174. # recreate secondary indexes - dropping existing on the assumption they are wrong
  175. if self.indexes:
  176. existing_indexes = self.query_runner("indexes from: %s.%s" % (self.db, self.table), query.db(self.db).table(self.table).index_list())
  177. try:
  178. created_indexes = []
  179. for index in self.indexes:
  180. if index["index"] in existing_indexes: # drop existing versions
  181. self.query_runner(
  182. "drop index: %s.%s:%s" % (self.db, self.table, index["index"]),
  183. query.db(self.db).table(self.table).index_drop(index["index"])
  184. )
  185. self.query_runner(
  186. "create index: %s.%s:%s" % (self.db, self.table, index["index"]),
  187. query.db(self.db).table(self.table).index_create(index["index"], index["function"])
  188. )
  189. created_indexes.append(index["index"])
  190. # wait for all of the created indexes to build
  191. self.query_runner(
  192. "waiting for indexes on %s.%s" % (self.db, self.table),
  193. query.db(self.db).table(self.table).index_wait(query.args(created_indexes))
  194. )
  195. except RuntimeError as e:
  196. ex_type, ex_class, tb = sys.exc_info()
  197. warning_queue.put((ex_type, ex_class, traceback.extract_tb(tb), self._source.name))
  198. existing_hook = self.query_runner("Write hook from: %s.%s" % (self.db, self.table), query.db(self.db).table(self.table).get_write_hook())
  199. try:
  200. created_hook = []
  201. if self.write_hook != []:
  202. self.query_runner(
  203. "drop hook: %s.%s" % (self.db, self.table),
  204. query.db(self.db).table(self.table).set_write_hook(None)
  205. )
  206. self.query_runner(
  207. "create hook: %s.%s:%s" % (self.db, self.table, self.write_hook),
  208. query.db(self.db).table(self.table).set_write_hook(self.write_hook["function"])
  209. )
  210. except RuntimeError:
  211. ex_type, ex_class, tb = sys.exc_info()
  212. warning_queue.put((ex_type, ex_class, traceback.extract_tb(tb), self._source.name))
  213. def batches(self, batch_size=None, warning_queue=None):
  214. # setup table
  215. self.setup_table()
  216. # default batch_size
  217. if batch_size is None:
  218. batch_size = utils_common.default_batch_size
  219. else:
  220. batch_size = int(batch_size)
  221. assert batch_size > 0
  222. # setup
  223. self.setup_file(warning_queue=warning_queue)
  224. # - yield batches
  225. batch = []
  226. try:
  227. needMoreData = False
  228. while True:
  229. if needMoreData:
  230. self.fill_buffer()
  231. needMoreData = False
  232. while len(batch) < batch_size:
  233. try:
  234. row = self.get_line()
  235. # ToDo: validate the line
  236. batch.append(row)
  237. except NeedMoreData:
  238. needMoreData = True
  239. break
  240. except Exception:
  241. raise
  242. else:
  243. yield batch
  244. batch = []
  245. except StopIteration as e:
  246. # yield any final batch
  247. if batch:
  248. yield batch
  249. # - check the end of the file
  250. self.teardown()
  251. # - rebuild indexes
  252. if self.indexes:
  253. self.restore_indexes(warning_queue)
  254. # -
  255. raise e
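# Illustrative sketch of how this generator is typically driven (the names
# 'runner', 'warnings' and 'work_queue' here are placeholders, not module APIs):
#
#   source = JsonSourceFile(source='users.json', db='test', table='users',
#                           query_runner=runner)  # runner: a utils_common.RetryQuery
#   for batch in source.batches(batch_size=200, warning_queue=warnings):
#       work_queue.put((source.db, source.table, batch))
#
# read_to_queue() below wraps essentially this loop, adding timing, the fields
# filter, and error reporting.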
  256. def setup_file(self, warning_queue=None):
  257. raise NotImplementedError("Subclasses need to implement this")
  258. def teardown(self):
  259. pass
  260. def read_to_queue(self, work_queue, exit_event, error_queue, warning_queue, timing_queue, fields=None, ignore_signals=True, batch_size=None):
  261. if ignore_signals: # ToDo: work out when we are in a worker process automatically
  262. signal.signal(signal.SIGINT, signal.SIG_IGN) # workers should ignore these
  263. if batch_size is None:
  264. batch_size = utils_common.default_batch_size
  265. self.start_time = time.time()
  266. try:
  267. timePoint = time.time()
  268. for batch in self.batches(warning_queue=warning_queue):
  269. timing_queue.put(('reader_work', time.time() - timePoint))
  270. timePoint = time.time()
  271. # apply the fields filter
  272. if fields:
  273. for row in batch:
  274. for key in [x for x in row.keys() if x not in fields]:
  275. del row[key]
  276. while not exit_event.is_set():
  277. try:
  278. work_queue.put((self.db, self.table, batch), timeout=0.1)
  279. self._rows_read.value += len(batch)
  280. break
  281. except Full:
  282. pass
  283. else:
  284. break
  285. timing_queue.put(('reader_wait', time.time() - timePoint))
  286. timePoint = time.time()
  287. # - report relevant errors
  288. except Exception as e:
  289. error_queue.put(Error(str(e), traceback.format_exc(), self.name))
  290. exit_event.set()
  291. raise
  292. finally:
  293. self.end_time = time.time()
  294. class NeedMoreData(Exception):
  295. pass
  296. class JsonSourceFile(SourceFile):
  297. format = 'json'
  298. decoder = json.JSONDecoder()
  299. json_array = None
  300. found_first = False
  301. _buffer_size = json_read_chunk_size
  302. _buffer_str = None
  303. _buffer_pos = None
  304. _buffer_end = None
  305. def fill_buffer(self):
  306. if self._buffer_str is None:
  307. self._buffer_str = ''
  308. self._buffer_pos = 0
  309. self._buffer_end = 0
  310. elif self._buffer_pos == 0:
  311. # double the buffer under the assumption that the documents are too large to fit
  312. if self._buffer_size == json_max_buffer_size:
  313. raise Exception("Error: JSON max buffer size exceeded on file %s (from position %d). Use '--max-document-size' to extend your buffer." % (self.name, self.bytes_read))
  314. self._buffer_size = min(self._buffer_size * 2, json_max_buffer_size)
  315. # add more data
  316. readTarget = self._buffer_size - self._buffer_end + self._buffer_pos
  317. assert readTarget > 0
  318. newChunk = self._source.read(readTarget)
  319. if len(newChunk) == 0:
  320. raise StopIteration() # file ended
  321. self._buffer_str = self._buffer_str[self._buffer_pos:] + newChunk
  322. self._bytes_read.value += len(newChunk)
  323. # reset markers
  324. self._buffer_pos = 0
  325. self._buffer_end = len(self._buffer_str) - 1
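# Note: the buffer grows geometrically -- json_read_chunk_size (128KiB), then
# 256KiB, 512KiB, ... -- capped at json_max_buffer_size (128MiB unless raised
# via --max-document-size), so a single document larger than the cap is a hard
# error rather than a silent truncation.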
  326. def get_line(self):
  327. '''Return a line from the current _buffer_str, or raise NeedMoreData trying'''
  328. # advance over any whitespace
  329. self._buffer_pos = json.decoder.WHITESPACE.match(self._buffer_str, self._buffer_pos).end()
  330. if self._buffer_pos >= self._buffer_end:
  331. raise NeedMoreData()
  332. # read over a comma if we are not the first item in a json_array
  333. if self.json_array and self.found_first and self._buffer_str[self._buffer_pos] == ",":
  334. self._buffer_pos += 1
  335. if self._buffer_pos >= self._buffer_end:
  336. raise NeedMoreData()
  337. # advance over any post-comma whitespace
  338. self._buffer_pos = json.decoder.WHITESPACE.match(self._buffer_str, self._buffer_pos).end()
  339. if self._buffer_pos >= self._buffer_end:
  340. raise NeedMoreData()
  341. # parse and return an object
  342. try:
  343. row, self._buffer_pos = self.decoder.raw_decode(self._buffer_str, idx=self._buffer_pos)
  344. self.found_first = True
  345. return row
  346. except (ValueError, IndexError) as e:
  347. raise NeedMoreData()
  348. def setup_file(self, warning_queue=None):
  349. # - move to the first record
  350. # advance through any leading whitespace
  351. while True:
  352. self.fill_buffer()
  353. self._buffer_pos = json.decoder.WHITESPACE.match(self._buffer_str, 0).end()
  354. if self._buffer_pos == 0:
  355. break
  356. # check the first character
  357. try:
  358. if self._buffer_str[0] == "[":
  359. self.json_array = True
  360. self._buffer_pos = 1
  361. elif self._buffer_str[0] == "{":
  362. self.json_array = False
  363. else:
  364. raise ValueError("Error: JSON format not recognized - file does not begin with an object or array")
  365. except IndexError:
  366. raise ValueError("Error: JSON file was empty of content")
  367. def teardown(self):
  368. # - check the end of the file
  369. # note: fill_buffer should have guaranteed that we have only the data in the end
  370. # advance through any leading whitespace
  371. self._buffer_pos = json.decoder.WHITESPACE.match(self._buffer_str, self._buffer_pos).end()
  372. # check the end of the array if we have it
  373. if self.json_array:
  374. if self._buffer_str[self._buffer_pos] != "]":
  375. snippit = self._buffer_str[self._buffer_pos:]
  376. extra = '' if len(snippit) <= 100 else ' and %d more characters' % (len(snippit) - 100)
  377. raise ValueError("Error: JSON array did not end cleanly, rather with: <<%s>>%s" % (snippit[:100], extra))
  378. self._buffer_pos += 1
  379. # advance through any trailing whitespace
  380. self._buffer_pos = json.decoder.WHITESPACE.match(self._buffer_str, self._buffer_pos).end()
  381. snippit = self._buffer_str[self._buffer_pos:]
  382. if len(snippit) > 0:
  383. extra = '' if len(snippit) <= 100 else ' and %d more characters' % (len(snippit) - 100)
  384. raise ValueError("Error: extra data after JSON data: <<%s>>%s" % (snippit[:100], extra))
  385. class CsvSourceFile(SourceFile):
  386. format = "csv"
  387. no_header_row = False
  388. custom_header = None
  389. _reader = None # instance of csv.reader
  390. _columns = None # name of the columns
  391. def __init__(self, *args, **kwargs):
  392. if 'source_options' in kwargs and isinstance(kwargs['source_options'], dict):
  393. if 'no_header_row' in kwargs['source_options']:
  394. self.no_header_row = kwargs['source_options']['no_header_row'] == True
  395. if 'custom_header' in kwargs['source_options']:
  396. self.custom_header = kwargs['source_options']['custom_header']
  397. super(CsvSourceFile, self).__init__(*args, **kwargs)
  398. def byte_counter(self):
  399. '''Generator for getting a byte count on a file being used'''
  400. for line in self._source:
  401. self._bytes_read.value += len(line)
  402. if unicode != str:
  403. yield line.encode("utf-8") # Python2.x csv module does not really handle unicode
  404. else:
  405. yield line
  406. def setup_file(self, warning_queue=None):
  407. # - setup csv.reader with a byte counter wrapper
  408. self._reader = csv.reader(self.byte_counter())
  409. # - get the header information for column names
  410. if not self.no_header_row:
  411. self._columns = next(self._reader)
  412. # field names may override fields from the header
  413. if self.custom_header is not None:
  414. if not self.no_header_row:
  415. warning_queue.put("Ignoring header row on %s: %s" % (self.name, str(self._columns)))
  416. self._columns = self.custom_header
  417. elif self.no_header_row:
  418. raise ValueError("Error: No field name information available")
  419. def get_line(self):
  420. rowRaw = next(self._reader)
  421. if len(self._columns) != len(rowRaw):
  422. raise Exception("Error: '%s' line %d has an inconsistent number of columns: %s" % (self.name, self._reader.line_num, str(rowRaw)))
  423. row = {}
  424. for key, value in zip(self._columns, rowRaw): # note: we import all csv fields as strings
  425. # treat empty fields as no entry rather than empty string
  426. if value == '':
  427. continue
  428. row[key] = value if str == unicode else unicode(value, encoding="utf-8")
  429. return row
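# Illustrative example: with a header row of "id,name,score" and a data line of
# "1,,42", get_line() returns {'id': '1', 'score': '42'} -- every value is kept
# as a string, and the empty 'name' field is omitted rather than stored as ''.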
  430. # ==
  431. usage = """rethinkdb import -d DIR [-c HOST:PORT] [--tls-cert FILENAME] [-p] [--password-file FILENAME]
  432. [--force] [-i (DB | DB.TABLE)] [--clients NUM]
  433. [--shards NUM_SHARDS] [--replicas NUM_REPLICAS]
  434. rethinkdb import -f FILE --table DB.TABLE [-c HOST:PORT] [--tls-cert FILENAME] [-p] [--password-file FILENAME]
  435. [--force] [--clients NUM] [--format (csv | json)] [--pkey PRIMARY_KEY]
  436. [--shards NUM_SHARDS] [--replicas NUM_REPLICAS]
  437. [--delimiter CHARACTER] [--custom-header FIELD,FIELD... [--no-header]]"""
  438. help_epilog = '''
  439. EXAMPLES:
  440. rethinkdb import -d rdb_export -c mnemosyne:39500 --clients 128
  441. Import data into a cluster running on host 'mnemosyne' with a client port at 39500,
  442. using 128 client connections and the named export directory.
  443. rethinkdb import -f site_history.csv --format csv --table test.history --pkey count
  444. Import data into a local cluster and the table 'history' in the 'test' database,
  445. using the named CSV file, and using the 'count' field as the primary key.
  446. rethinkdb import -d rdb_export -c hades -p -i test
  447. Import data into a cluster running on host 'hades' which requires a password,
  448. using only the database 'test' from the named export directory.
  449. rethinkdb import -f subscriber_info.json --fields id,name,hashtag --force
  450. Import data into a local cluster using the named JSON file, and only the fields
  451. 'id', 'name', and 'hashtag', overwriting any existing rows with the same primary key.
  452. rethinkdb import -f user_data.csv --delimiter ';' --no-header --custom-header id,name,number
  453. Import data into a local cluster using the named CSV file with no header and instead
  454. use the fields 'id', 'name', and 'number'; the delimiter is a semicolon (rather than
  455. a comma).
  456. '''
  457. def parse_options(argv, prog=None):
  458. parser = utils_common.CommonOptionsParser(usage=usage, epilog=help_epilog, prog=prog)
  459. parser.add_option("--clients", dest="clients", metavar="CLIENTS", default=8, help="client connections to use (default: 8)", type="pos_int")
  460. parser.add_option("--hard-durability", dest="durability", action="store_const", default="soft", help="use hard durability writes (slower, uses less memory)", const="hard")
  461. parser.add_option("--force", dest="force", action="store_true", default=False, help="import even if a table already exists, overwriting duplicate primary keys")
  462. parser.add_option("--batch-size", dest="batch_size", default=utils_common.default_batch_size, help=optparse.SUPPRESS_HELP, type="pos_int")
  463. # Replication settings
  464. replicationOptionsGroup = optparse.OptionGroup(parser, "Replication Options")
  465. replicationOptionsGroup.add_option("--shards", dest="create_args", metavar="SHARDS", help="shards to setup on created tables (default: 1)", type="pos_int", action="add_key")
  466. replicationOptionsGroup.add_option("--replicas", dest="create_args", metavar="REPLICAS", help="replicas to setup on created tables (default: 1)", type="pos_int", action="add_key")
  467. parser.add_option_group(replicationOptionsGroup)
  468. # Directory import options
  469. dirImportGroup = optparse.OptionGroup(parser, "Directory Import Options")
  470. dirImportGroup.add_option("-d", "--directory", dest="directory", metavar="DIRECTORY", default=None, help="directory to import data from")
  471. dirImportGroup.add_option("-i", "--import", dest="db_tables", metavar="DB|DB.TABLE", default=[], help="restore only the given database or table (may be specified multiple times)", action="append", type="db_table")
  472. dirImportGroup.add_option("--no-secondary-indexes", dest="indexes", action="store_false", default=None, help="do not create secondary indexes")
  473. parser.add_option_group(dirImportGroup)
  474. # File import options
  475. fileImportGroup = optparse.OptionGroup(parser, "File Import Options")
  476. fileImportGroup.add_option("-f", "--file", dest="file", metavar="FILE", default=None, help="file to import data from", type="file")
  477. fileImportGroup.add_option("--table", dest="import_table", metavar="DB.TABLE", default=None, help="table to import the data into")
  478. fileImportGroup.add_option("--fields", dest="fields", metavar="FIELD,...", default=None, help="limit which fields to use when importing one table")
  479. fileImportGroup.add_option("--format", dest="format", metavar="json|csv", default=None, help="format of the file (default: json, accepts newline delimited json)", type="choice", choices=["json", "csv"])
  480. fileImportGroup.add_option("--pkey", dest="create_args", metavar="PRIMARY_KEY", default=None, help="field to use as the primary key in the table", action="add_key")
  481. parser.add_option_group(fileImportGroup)
  482. # CSV import options
  483. csvImportGroup = optparse.OptionGroup(parser, "CSV Options")
  484. csvImportGroup.add_option("--delimiter", dest="delimiter", metavar="CHARACTER", default=None, help="character separating fields, or '\\t' for tab")
  485. csvImportGroup.add_option("--no-header", dest="no_header", action="store_true", default=None, help="do not read in a header of field names")
  486. csvImportGroup.add_option("--custom-header", dest="custom_header", metavar="FIELD,...", default=None, help="header to use (overriding file header), must be specified if --no-header")
  487. parser.add_option_group(csvImportGroup)
  488. # JSON import options
  489. jsonOptionsGroup = optparse.OptionGroup(parser, "JSON Options")
  490. jsonOptionsGroup.add_option("--max-document-size", dest="max_document_size", metavar="MAX_SIZE", default=0, help="maximum allowed size (bytes) for a single JSON document (default: 128MiB)", type="pos_int")
  491. jsonOptionsGroup.add_option("--max-nesting-depth", dest="max_nesting_depth", metavar="MAX_DEPTH", default=0, help="maximum depth of the JSON documents (default: 100)", type="pos_int")
  492. parser.add_option_group(jsonOptionsGroup)
  493. options, args = parser.parse_args(argv)
  494. # Check validity of arguments
  495. if len(args) != 0:
  496. parser.error("No positional arguments supported. Unrecognized option(s): %s" % args)
  497. # - create_args
  498. if options.create_args is None:
  499. options.create_args = {}
  500. # - options based on file/directory import
  501. if options.directory and options.file:
  502. parser.error("-f/--file and -d/--directory can not be used together")
  503. elif options.directory:
  504. if not os.path.exists(options.directory):
  505. parser.error("-d/--directory does not exist: %s" % options.directory)
  506. if not os.path.isdir(options.directory):
  507. parser.error("-d/--directory is not a directory: %s" % options.directory)
  508. options.directory = os.path.realpath(options.directory)
  509. # disallow invalid options
  510. if options.import_table:
  511. parser.error("--table option is not valid when importing a directory")
  512. if options.fields:
  513. parser.error("--fields option is not valid when importing a directory")
  514. if options.format:
  515. parser.error("--format option is not valid when importing a directory")
  516. if options.create_args:
  517. parser.error("--pkey option is not valid when importing a directory")
  518. if options.delimiter:
  519. parser.error("--delimiter option is not valid when importing a directory")
  520. if options.no_header:
  521. parser.error("--no-header option is not valid when importing a directory")
  522. if options.custom_header:
  523. parser.error("--custom-header option is not valid when importing a directory")
  524. # check valid options
  525. if not os.path.isdir(options.directory):
  526. parser.error("Directory to import does not exist: %s" % options.directory)
  527. if options.fields and (len(options.db_tables) > 1 or options.db_tables[0].table is None):
  528. parser.error("--fields option can only be used when importing a single table")
  529. elif options.file:
  530. if not os.path.exists(options.file):
  531. parser.error("-f/--file does not exist: %s" % options.file)
  532. if not os.path.isfile(options.file):
  533. parser.error("-f/--file is not a file: %s" % options.file)
  534. options.file = os.path.realpath(options.file)
  535. # format
  536. if options.format is None:
  537. options.format = os.path.splitext(options.file)[1].lstrip('.')
  538. # import_table
  539. if options.import_table:
  540. res = utils_common._tableNameRegex.match(options.import_table)
  541. if res and res.group("table"):
  542. options.import_table = utils_common.DbTable(res.group("db"), res.group("table"))
  543. else:
  544. parser.error("Invalid --table option: %s" % options.import_table)
  545. else:
  546. parser.error("A value is required for --table when importing from a file")
  547. # fields
  548. options.fields = options.fields.split(",") if options.fields else None
  549. # disallow invalid options
  550. if options.db_tables:
  551. parser.error("-i/--import can only be used when importing a directory")
  552. if options.indexes:
  553. parser.error("--no-secondary-indexes can only be used when importing a directory")
  554. if options.format == "csv":
  555. # disallow invalid options
  556. if options.max_document_size:
  557. parser.error("--max-document-size only affects importing JSON documents")
  558. # delimiter
  559. if options.delimiter is None:
  560. options.delimiter = ","
  561. elif options.delimiter == "\\t":
  562. options.delimiter = "\t"
  563. elif len(options.delimiter) != 1:
  564. parser.error("Specify exactly one character for the --delimiter option: %s" % options.delimiter)
  565. # no_header
  566. if options.no_header is None:
  567. options.no_header = False
  568. elif options.custom_header is None:
  569. parser.error("--custom-header is required if --no-header is specified")
  570. # custom_header
  571. if options.custom_header:
  572. options.custom_header = options.custom_header.split(",")
  573. elif options.format == "json":
  574. # disallow invalid options
  575. if options.delimiter is not None:
  576. parser.error("--delimiter option is not valid for json files")
  577. if options.no_header:
  578. parser.error("--no-header option is not valid for json files")
  579. if options.custom_header is not None:
  580. parser.error("--custom-header option is not valid for json files")
  581. # default options
  582. options.format = "json"
  583. if options.max_document_size > 0:
  584. global json_max_buffer_size
  585. json_max_buffer_size = options.max_document_size
  586. options.file = os.path.abspath(options.file)
  587. else:
  588. parser.error("Unrecognized file format: %s" % options.format)
  589. else:
  590. parser.error("Either -f/--file or -d/--directory is required")
  591. # --
  592. # max_nesting_depth
  593. if options.max_nesting_depth > 0:
  594. global max_nesting_depth
  595. max_nesting_depth = options.max_nesting_depth
  596. # --
  597. return options
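# Illustrative usage (the file name is a placeholder and must exist on disk):
#
#   options = parse_options(['-f', 'users.json', '--table', 'test.users'])
#
# would yield options.format == 'json' (inferred from the extension),
# options.import_table as DbTable('test', 'users'), and the remaining
# create_args/fields/etc. defaults filled in.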
  598. # This is run for each client requested, and accepts tasks from the reader processes
  599. def table_writer(tables, options, work_queue, error_queue, warning_queue, exit_event, timing_queue):
  600. signal.signal(signal.SIGINT, signal.SIG_IGN) # workers should ignore these
  601. db = table = batch = None
  602. try:
  603. conflict_action = "replace" if options.force else "error"
  604. timePoint = time.time()
  605. while not exit_event.is_set():
  606. # get a batch
  607. try:
  608. db, table, batch = work_queue.get(timeout=0.1)
  609. except Empty:
  610. continue
  611. timing_queue.put(('writer_wait', time.time() - timePoint))
  612. timePoint = time.time()
  613. # shut down when appropriate
  614. if isinstance(batch, StopIteration):
  615. return
  616. # find the table we are working on
  617. table_info = tables[(db, table)]
  618. tbl = query.db(db).table(table)
  619. # write the batch to the database
  620. try:
  621. res = options.retryQuery(
  622. "write batch to %s.%s" % (db, table),
  623. tbl.insert(ast.expr(batch, nesting_depth=max_nesting_depth), durability=options.durability, conflict=conflict_action, ignore_write_hook=True)
  624. )
  625. if res["errors"] > 0:
  626. raise RuntimeError("Error when importing into table '%s.%s': %s" % (db, table, res["first_error"]))
  627. modified = res["inserted"] + res["replaced"] + res["unchanged"]
  628. if modified != len(batch):
  629. raise RuntimeError("The inserted/replaced/unchanged number did not match when importing into table '%s.%s': %s" % (db, table, res["first_error"]))
  630. table_info.add_rows_written(modified)
  631. except errors.ReqlError:
  632. # the error might have been caused by a comm or temporary error causing a partial batch write
  633. for row in batch:
  634. if table_info.primary_key not in row:
  635. raise RuntimeError("Connection error while importing. Current row does not have the specified primary key (%s), so cannot guarantee absence of duplicates" % table_info.primary_key)
  636. res = None
  637. if conflict_action == "replace":
  638. res = options.retryQuery(
  639. "write row to %s.%s" % (db, table),
  640. tbl.insert(ast.expr(row, nesting_depth=max_nesting_depth), durability=options.durability, conflict=conflict_action, ignore_write_hook=True)
  641. )
  642. else:
  643. existingRow = options.retryQuery(
  644. "read row from %s.%s" % (db, table),
  645. tbl.get(row[table_info.primary_key])
  646. )
  647. if not existingRow:
  648. res = options.retryQuery(
  649. "write row to %s.%s" % (db, table),
  650. tbl.insert(ast.expr(row, nesting_depth=max_nesting_depth), durability=options.durability, conflict=conflict_action, ignore_write_hook=True)
  651. )
  652. elif existingRow != row:
  653. raise RuntimeError("Duplicate primary key `%s`:\n%s\n%s" % (table_info.primary_key, str(row), str(existingRow)))
  654. if res["errors"] > 0:
  655. raise RuntimeError("Error when importing into table '%s.%s': %s" % (db, table, res["first_error"]))
  656. if res["inserted"] + res["replaced"] + res["unchanged"] != 1:
  657. raise RuntimeError("The inserted/replaced/unchanged number was not 1 when inserting on '%s.%s': %s" % (db, table, res))
  658. table_info.add_rows_written(1)
  659. timing_queue.put(('writer_work', time.time() - timePoint))
  660. timePoint = time.time()
  661. except Exception as e:
  662. error_queue.put(Error(str(e), traceback.format_exc(), "%s.%s" % (db , table)))
  663. exit_event.set()
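# Note: the insert result checked above is the standard ReQL write summary, e.g.
# {'inserted': 200, 'replaced': 0, 'unchanged': 0, 'errors': 0, ...}; when a
# batch write fails partway through (ReqlError), the row-by-row path re-inserts
# or verifies each row by primary key so nothing is skipped or double-counted.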
  664. def update_progress(tables, debug, exit_event, sleep=0.2):
  665. signal.signal(signal.SIGINT, signal.SIG_IGN) # workers should not get these
  666. # give weights to each of the tables based on file size
  667. totalSize = sum([x.bytes_size for x in tables])
  668. for table in tables:
  669. table.weight = float(table.bytes_size) / totalSize
  670. lastComplete = None
  671. startTime = time.time()
  672. readWrites = collections.deque(maxlen=5) # (time, read, write)
  673. readWrites.append((startTime, 0, 0))
  674. readRate = None
  675. writeRate = None
  676. while True:
  677. try:
  678. if exit_event.is_set():
  679. break
  680. complete = read = write = 0
  681. currentTime = time.time()
  682. for table in tables:
  683. complete += table.percentDone * table.weight
  684. if debug:
  685. read += table.rows_read
  686. write += table.rows_written
  687. readWrites.append((currentTime, read, write))
  688. if complete != lastComplete:
  689. timeDelta = readWrites[-1][0] - readWrites[0][0]
  690. if debug and len(readWrites) > 1 and timeDelta > 0:
  691. readRate = max((readWrites[-1][1] - readWrites[0][1]) / timeDelta, 0)
  692. writeRate = max((readWrites[-1][2] - readWrites[0][2]) / timeDelta, 0)
  693. utils_common.print_progress(complete, indent=2, read=readRate, write=writeRate)
  694. lastComplete = complete
  695. time.sleep(sleep)
  696. except KeyboardInterrupt: break
  697. except Exception as e:
  698. if debug:
  699. print(e)
  700. traceback.print_exc()
  701. def import_tables(options, sources, files_ignored=None):
  702. # Make sure this isn't a pre-`reql_admin` cluster - which could result in data loss
  703. # if the user has a database named 'rethinkdb'
  704. utils_common.check_minimum_version(options, "1.6")
  705. start_time = time.time()
  706. tables = dict(((x.db, x.table), x) for x in sources) # (db, table) => table
  707. work_queue = Queue(options.clients * 3)
  708. error_queue = SimpleQueue()
  709. warning_queue = SimpleQueue()
  710. exit_event = multiprocessing.Event()
  711. interrupt_event = multiprocessing.Event()
  712. timing_queue = SimpleQueue()
  713. errors = []
  714. warnings = []
  715. timingSums = {}
  716. pools = []
  717. progressBar = None
  718. progressBarSleep = 0.2
  719. # - setup KeyboardInterupt handler
  720. signal.signal(signal.SIGINT, lambda a, b: utils_common.abort(pools, exit_event))
  721. # - queue draining
  722. def drainQueues():
  723. # error_queue
  724. while not error_queue.empty():
  725. errors.append(error_queue.get())
  726. # warning_queue
  727. while not warning_queue.empty():
  728. warnings.append(warning_queue.get())
  729. # timing_queue
  730. while not timing_queue.empty():
  731. key, value = timing_queue.get()
  732. if key not in timingSums:
  733. timingSums[key] = value
  734. else:
  735. timingSums[key] += value
  736. # - setup dbs and tables
  737. # create missing dbs
  738. needed_dbs = set([x.db for x in sources])
  739. if "rethinkdb" in needed_dbs:
  740. raise RuntimeError("Error: Cannot import tables into the system database: 'rethinkdb'")
  741. options.retryQuery("ensure dbs: %s" % ", ".join(needed_dbs), ast.expr(needed_dbs).set_difference(query.db_list()).for_each(query.db_create(query.row)))
  742. # check for existing tables, or if --force is enabled ones with mis-matched primary keys
  743. existing_tables = dict([
  744. ((x["db"], x["name"]), x["primary_key"]) for x in
  745. options.retryQuery("list tables", query.db("rethinkdb").table("table_config").pluck(["db", "name", "primary_key"]))
  746. ])
  747. already_exist = []
  748. for source in sources:
  749. if (source.db, source.table) in existing_tables:
  750. if not options.force:
  751. already_exist.append("%s.%s" % (source.db, source.table))
  752. elif source.primary_key is None:
  753. source.primary_key = existing_tables[(source.db, source.table)]
  754. elif source.primary_key != existing_tables[(source.db, source.table)]:
  755. raise RuntimeError("Error: Table '%s.%s' already exists with a different primary key: %s (expected: %s)" % (source.db, source.table, existing_tables[(source.db, source.table)], source.primary_key))
  756. if len(already_exist) == 1:
  757. raise RuntimeError("Error: Table '%s' already exists, run with --force to import into the existing table" % already_exist[0])
  758. elif len(already_exist) > 1:
  759. already_exist.sort()
  760. raise RuntimeError("Error: The following tables already exist, run with --force to import into the existing tables:\n %s" % "\n ".join(already_exist))
  761. # - start the import
  762. try:
  763. # - start the progress bar
  764. if not options.quiet:
  765. progressBar = multiprocessing.Process(
  766. target=update_progress,
  767. name="progress bar",
  768. args=(sources, options.debug, exit_event, progressBarSleep)
  769. )
  770. progressBar.start()
  771. pools.append([progressBar])
  772. # - start the writers
  773. writers = []
  774. pools.append(writers)
  775. for i in range(options.clients):
  776. writer = multiprocessing.Process(
  777. target=table_writer, name="table writer %d" % i,
  778. kwargs={
  779. "tables":tables, "options":options,
  780. "work_queue":work_queue, "error_queue":error_queue, "warning_queue":warning_queue, "timing_queue":timing_queue,
  781. "exit_event":exit_event
  782. }
  783. )
  784. writers.append(writer)
  785. writer.start()
  786. # - read the tables options.clients at a time
  787. readers = []
  788. pools.append(readers)
  789. fileIter = iter(sources)
  790. try:
  791. while not exit_event.is_set():
  792. # add a workers to fill up the readers pool
  793. while len(readers) < options.clients:
  794. table = next(fileIter)
  795. reader = multiprocessing.Process(
  796. target=table.read_to_queue, name="table reader %s.%s" % (table.db, table.table),
  797. kwargs={
  798. "fields":options.fields, "batch_size":options.batch_size,
  799. "work_queue":work_queue, "error_queue":error_queue, "warning_queue":warning_queue, "timing_queue":timing_queue,
  800. "exit_event":exit_event
  801. }
  802. )
  803. readers.append(reader)
  804. reader.start()
  805. # drain the queues
  806. drainQueues()
  807. # reap completed tasks
  808. for reader in readers[:]:
  809. if not reader.is_alive():
  810. readers.remove(reader)
  811. if len(readers) == options.clients:
  812. time.sleep(.05)
  813. except StopIteration:
  814. pass # ran out of new tables
  815. # - wait for the last batch of readers to complete
  816. while readers:
  817. # drain the queues
  818. drainQueues()
  819. # drain the work queue to prevent readers from stalling on exit
  820. if exit_event.is_set():
  821. try:
  822. while True:
  823. work_queue.get(timeout=0.1)
  824. except Empty: pass
  825. # watch the readers
  826. for reader in readers[:]:
  827. try:
  828. reader.join(.1)
  829. except Exception: pass
  830. if not reader.is_alive():
  831. readers.remove(reader)
  832. # - append enough StopIterations to signal all writers
  833. for _ in writers:
  834. while True:
  835. if exit_event.is_set():
  836. break
  837. try:
  838. work_queue.put((None, None, StopIteration()), timeout=0.1)
  839. break
  840. except Full: pass
  841. # - wait for all of the writers
  842. for writer in writers[:]:
  843. while writer.is_alive():
  844. writer.join(0.1)
  845. writers.remove(writer)
  846. # - stop the progress bar
  847. if progressBar:
  848. progressBar.join(progressBarSleep * 2)
  849. if not interrupt_event.is_set():
  850. utils_common.print_progress(1, indent=2)
  851. if progressBar.is_alive():
  852. progressBar.terminate()
  853. # - drain queues
  854. drainQueues()
  855. # - final reporting
  856. if not options.quiet:
  857. # if successful, make sure 100% progress is reported
  858. if len(errors) == 0 and not interrupt_event.is_set():
  859. utils_common.print_progress(1.0, indent=2)
  860. # advance past the progress bar
  861. print('')
  862. # report statistics
  863. plural = lambda num, text: "%d %s%s" % (num, text, "" if num == 1 else "s")
  864. print(" %s imported to %s in %.2f secs" % (plural(sum(x.rows_written for x in sources), "row"), plural(len(sources), "table"), time.time() - start_time))
  865. # report debug statistics
  866. if options.debug:
  867. print('Debug timing:')
  868. for key, value in sorted(timingSums.items(), key=lambda x: x[0]):
  869. print(' %s: %.2f' % (key, value))
  870. finally:
  871. signal.signal(signal.SIGINT, signal.SIG_DFL)
  872. drainQueues()
  873. for error in errors:
  874. print("%s" % error.message, file=sys.stderr)
  875. if options.debug and error.traceback:
  876. print(" Traceback:\n%s" % error.traceback, file=sys.stderr)
  877. if error.file:
  878. print(" In file: %s" % error.file, file=sys.stderr)
  879. for warning in warnings:
  880. print("%s" % warning[1], file=sys.stderr)
  881. if options.debug:
  882. print("%s traceback: %s" % (warning[0].__name__, warning[2]), file=sys.stderr)
  883. if len(warning) == 4:
  884. print("In file: %s" % warning[3], file=sys.stderr)
  885. if interrupt_event.is_set():
  886. raise RuntimeError("Interrupted")
  887. if errors:
  888. raise RuntimeError("Errors occurred during import")
  889. if warnings:
  890. raise RuntimeError("Warnings occurred during import")
  891. def parse_sources(options, files_ignored=None):
  892. def parseInfoFile(path):
  893. primary_key = None
  894. indexes, write_hook = [], None
  895. with open(path, 'r') as info_file:
  896. metadata = json.load(info_file)
  897. if "primary_key" in metadata:
  898. primary_key = metadata["primary_key"]
  899. if "indexes" in metadata and options.indexes is not False:
  900. indexes = metadata["indexes"]
  901. if "write_hook" in metadata:
  902. write_hook = metadata["write_hook"]
  903. return primary_key, indexes, write_hook
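# Illustrative .info file contents (the exact index function blob is elided):
#
#   {"primary_key": "id",
#    "indexes": [{"index": "by_name", "function": {...}}],
#    "write_hook": null}
#
# Only the three keys read above are consulted; anything else in the file is ignored.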
  904. sources = set()
  905. if files_ignored is None:
  906. files_ignored = []
  907. if options.directory and options.file:
  908. raise RuntimeError("Error: Both --directory and --file cannot be specified together")
  909. elif options.file:
  910. db, table = options.import_table
  911. path, ext = os.path.splitext(options.file)
  912. tableTypeOptions = None
  913. if ext == ".json":
  914. tableType = JsonSourceFile
  915. elif ext == ".csv":
  916. tableType = CsvSourceFile
  917. tableTypeOptions = {
  918. 'no_header_row': options.no_header,
  919. 'custom_header': options.custom_header
  920. }
  921. else:
  922. raise Exception("The table type is not recognised: %s" % ext)
  923. # - parse the info file if it exists
  924. primary_key = options.create_args.get('primary_key', None) if options.create_args else None
  925. indexes = []
  926. write_hook = None
  927. infoPath = path + ".info"
  928. if (primary_key is None or options.indexes is not False) and os.path.isfile(infoPath):
  929. infoPrimaryKey, infoIndexes, infoWriteHook = parseInfoFile(infoPath)
  930. if primary_key is None:
  931. primary_key = infoPrimaryKey
  932. if options.indexes is not False:
  933. indexes = infoIndexes
  934. if write_hook is None:
  935. write_hook = infoWriteHook
  936. sources.add(
  937. tableType(
  938. source=options.file,
  939. db=db, table=table,
  940. query_runner=options.retryQuery,
  941. primary_key=primary_key,
  942. indexes=indexes,
  943. write_hook=write_hook,
  944. source_options=tableTypeOptions
  945. )
  946. )
  947. elif options.directory:
  948. # Scan for all files, make sure no duplicated tables with different formats
  949. dbs = False
  950. files_ignored = []
  951. for root, dirs, files in os.walk(options.directory):
  952. if not dbs:
  953. files_ignored.extend([os.path.join(root, f) for f in files])
  954. # The first iteration through sh