/scripts/pack.py

https://github.com/LeelaChessZero/lczero-training
Python | 85 lines | 65 code | 18 blank | 2 comment | 11 complexity | 6ef3bf237cdd36ea2a8a0509a8ee3334 MD5 | raw file
  1. #!/usr/bin/env python3
  2. import glob
  3. import os
  4. import argparse
  5. import gzip
  6. import bz2
  7. import struct
  8. import numpy as np
  9. from multiprocessing import Pool
# Size in bytes of one uncompressed training record; pack() divides each
# chunk's uncompressed byte size by this to get its record count.
RECORD_SIZE = 8276
  11. def get_uncompressed_size(filename):
  12. with open(filename, 'rb') as f:
  13. f.seek(-4, 2)
  14. return struct.unpack('I', f.read(4))[0]
  15. def get_sorted_chunk_ids(dirs):
  16. ids = []
  17. for d in dirs:
  18. for f in glob.glob(os.path.join(d, "training.*.gz")):
  19. ids.append(int(os.path.basename(f).split('.')[-2]))
  20. ids.sort()
  21. return ids
  22. def pack(ids):
  23. plies = []
  24. fout_name = os.path.join(argv.output, '{}-{}.bz2'.format(ids[0], ids[-1]))
  25. with bz2.open(fout_name, 'xb') as fout:
  26. for tid in ids:
  27. fin_name = os.path.join(argv.input, 'training.{}.gz'.format(tid))
  28. plies.append(get_uncompressed_size(fin_name) // RECORD_SIZE)
  29. with gzip.open(fin_name, 'rb') as fin:
  30. fout.write(fin.read())
  31. if argv.remove:
  32. os.remove(fin_name)
  33. plylist = np.array(plies, dtype=np.int16)
  34. size = struct.pack('I', len(plylist) * 2)
  35. fout.write(plylist.tobytes())
  36. fout.write(size)
  37. print("Written '{}' {} records".format(fout_name, np.sum(plies)))
  38. def main():
  39. if not os.path.exists(argv.output):
  40. os.makedirs(argv.output)
  41. print("Created directory '{}'".format(argv.output))
  42. ids = get_sorted_chunk_ids([argv.input])
  43. n = len(ids) // argv.number
  44. m = argv.number
  45. print("Processing {} ids, {} - {} ({}x{})".format(len(ids), ids[0],
  46. ids[-1], n, m))
  47. packs = [ids[i * m:i * m + m] for i in range(n)]
  48. # add remaining ids to last pack
  49. packs[-1] += ids[n * m + m:]
  50. with Pool() as pool:
  51. pool.map(pack, packs)
  52. if __name__ == "__main__":
  53. argparser = argparse.ArgumentParser(description=\
  54. 'Repack training.*.gz files in batches of bz2 format.')
  55. argparser.add_argument('-i', '--input', type=str, help='input directory')
  56. argparser.add_argument('-o', '--output', type=str, help='output directory')
  57. argparser.add_argument('-r',
  58. '--remove',
  59. action='store_true',
  60. help='remove input files while processing')
  61. argparser.add_argument('-n',
  62. '--number',
  63. type=int,
  64. default=1000,
  65. help='number of games to repack per bz2 package')
  66. argv = argparser.parse_args()
  67. main()