/NMT/iterator.py

https://github.com/VectorFist/RNN-NMT · Python

import collections

import tensorflow as tf

# Container for one batch: the iterator initializer plus the tensors a model consumes.
BatchedInput = collections.namedtuple(
    'BatchedInput',
    ['initializer', 'source', 'target_input', 'target_output',
     'source_sequence_length', 'target_sequence_length'])


def get_iterator(src_dataset, tgt_dataset, src_vocab_table, tgt_vocab_table,
                 batch_size, sos, eos, reshuffle_each_iteration=True,
                 src_max_len=None, tgt_max_len=None):
    # Resolve the special-token ids once, outside the mapped functions.
    src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant(eos)), tf.int32)
    tgt_sos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(sos)), tf.int32)
    tgt_eos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(eos)), tf.int32)
    output_buffer_size = batch_size * 1000

    # Pair source and target lines and shuffle the sentence pairs.
    src_tgt_dataset = tf.data.Dataset.zip((src_dataset, tgt_dataset))
    src_tgt_dataset = src_tgt_dataset.shuffle(
        output_buffer_size, reshuffle_each_iteration=reshuffle_each_iteration)

    # Split each line into whitespace-separated tokens.
    src_tgt_dataset = src_tgt_dataset.map(
        lambda src, tgt: (tf.string_split([src]).values,
                          tf.string_split([tgt]).values),
        num_parallel_calls=4).prefetch(output_buffer_size)

    # Drop pairs where either sentence is empty.
    src_tgt_dataset = src_tgt_dataset.filter(
        lambda src, tgt: tf.logical_and(tf.size(src) > 0, tf.size(tgt) > 0))

    # Optionally truncate overlong sentences.
    if src_max_len:
        src_tgt_dataset = src_tgt_dataset.map(
            lambda src, tgt: (src[:src_max_len], tgt),
            num_parallel_calls=4).prefetch(output_buffer_size)
    if tgt_max_len:
        src_tgt_dataset = src_tgt_dataset.map(
            lambda src, tgt: (src, tgt[:tgt_max_len]),
            num_parallel_calls=4).prefetch(output_buffer_size)

    # Convert tokens to vocabulary ids.
    src_tgt_dataset = src_tgt_dataset.map(
        lambda src, tgt: (tf.cast(src_vocab_table.lookup(src), tf.int32),
                          tf.cast(tgt_vocab_table.lookup(tgt), tf.int32)),
        num_parallel_calls=4).prefetch(output_buffer_size)

    # Build decoder input (<sos> + target) and decoder output (target + <eos>).
    src_tgt_dataset = src_tgt_dataset.map(
        lambda src, tgt: (src,
                          tf.concat(([tgt_sos_id], tgt), 0),
                          tf.concat((tgt, [tgt_eos_id]), 0)),
        num_parallel_calls=4).prefetch(output_buffer_size)

    # Attach source and target lengths for dynamic unrolling and loss masking.
    src_tgt_dataset = src_tgt_dataset.map(
        lambda src, tgt_in, tgt_out: (src, tgt_in, tgt_out,
                                      tf.size(src), tf.size(tgt_in)),
        num_parallel_calls=4).prefetch(output_buffer_size)

    def key_func(unused_1, unused_2, unused_3, src_len, tgt_len):
        # Bucket sentence pairs by length so padding inside a batch stays small.
        if src_max_len:
            bucket_width = (src_max_len + 5 - 1) // 5
        else:
            bucket_width = 10
        bucket_id = tf.maximum(src_len // bucket_width, tgt_len // bucket_width)
        return tf.to_int64(tf.minimum(5, bucket_id))

    def reduce_func(unused_key, windowed_data):
        # Pad every example in a bucket to the longest example in that bucket.
        return windowed_data.padded_batch(
            batch_size,
            padded_shapes=(tf.TensorShape([None]), tf.TensorShape([None]),
                           tf.TensorShape([None]), tf.TensorShape([]),
                           tf.TensorShape([])),
            padding_values=(src_eos_id, tgt_eos_id, tgt_eos_id, 0, 0))

    batched_dataset = src_tgt_dataset.apply(tf.contrib.data.group_by_window(
        key_func=key_func, reduce_func=reduce_func, window_size=batch_size))

    batched_iterator = batched_dataset.make_initializable_iterator()
    (src_ids, tgt_input_ids, tgt_output_ids,
     src_seq_len, tgt_seq_len) = batched_iterator.get_next()
    batched_input = BatchedInput(initializer=batched_iterator.initializer,
                                 source=src_ids,
                                 target_input=tgt_input_ids,
                                 target_output=tgt_output_ids,
                                 source_sequence_length=src_seq_len,
                                 target_sequence_length=tgt_seq_len)
    return batched_input


def get_infer_iterator(src_dataset, src_vocab_table, batch_size, eos,
                       src_max_len=None):
    # Inference-time pipeline: source only, no shuffling and no bucketing.
    src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant(eos)), tf.int32)
    src_dataset = src_dataset.map(lambda src: tf.string_split([src]).values)
    if src_max_len:
        src_dataset = src_dataset.map(lambda src: src[:src_max_len])
    src_dataset = src_dataset.map(
        lambda src: tf.cast(src_vocab_table.lookup(src), tf.int32))
    src_dataset = src_dataset.map(lambda src: (src, tf.size(src)))
    batched_dataset = src_dataset.padded_batch(
        batch_size,
        padded_shapes=(tf.TensorShape([None]), tf.TensorShape([])),
        padding_values=(src_eos_id, 0))
    batched_iterator = batched_dataset.make_initializable_iterator()
    src_ids, src_seq_len = batched_iterator.get_next()
    batched_input = BatchedInput(initializer=batched_iterator.initializer,
                                 source=src_ids,
                                 target_input=None,
                                 target_output=None,
                                 source_sequence_length=src_seq_len,
                                 target_sequence_length=None)
    return batched_input
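

# Usage sketch (a minimal, non-authoritative example): how these iterators could be
# driven under TensorFlow 1.x. The corpus/vocabulary paths ('train.en', 'train.vi',
# 'test.en', 'vocab.en', 'vocab.vi'), the '<s>'/'</s>' special tokens, and the use of
# tf.contrib.lookup.index_table_from_file are assumptions for illustration, not part
# of this file.
if __name__ == '__main__':
    # One tokenized sentence per line in each corpus file (hypothetical paths).
    src_dataset = tf.data.TextLineDataset('train.en')
    tgt_dataset = tf.data.TextLineDataset('train.vi')

    # Lookup tables mapping tokens to integer ids; unknown tokens fall back to id 0.
    src_vocab_table = tf.contrib.lookup.index_table_from_file('vocab.en', default_value=0)
    tgt_vocab_table = tf.contrib.lookup.index_table_from_file('vocab.vi', default_value=0)

    # Training iterator with length-based bucketing.
    batch = get_iterator(src_dataset, tgt_dataset, src_vocab_table, tgt_vocab_table,
                         batch_size=32, sos='<s>', eos='</s>',
                         src_max_len=50, tgt_max_len=50)

    # Inference iterator over a source-only file (also a hypothetical path).
    infer_batch = get_infer_iterator(tf.data.TextLineDataset('test.en'),
                                     src_vocab_table, batch_size=32, eos='</s>',
                                     src_max_len=50)

    with tf.Session() as sess:
        # Both the lookup tables and the dataset iterator need explicit initialization.
        sess.run(tf.tables_initializer())
        sess.run(batch.initializer)
        src, src_len = sess.run([batch.source, batch.source_sequence_length])
        print(src.shape, src_len)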