/checkmate.py

https://github.com/vonclites/checkmate
Python | 149 lines | 130 code | 2 blank | 17 comment | 0 complexity | 6eb309d87ed175a4841d09014eececda MD5 | raw file
  1. import os
  2. import glob
  3. import json
  4. import numpy as np
  5. import tensorflow as tf
  6. class BestCheckpointSaver(object):
  7. """Maintains a directory containing only the best n checkpoints
  8. Inside the directory is a best_checkpoints JSON file containing a dictionary
  9. mapping of the best checkpoint filepaths to the values by which the checkpoints
  10. are compared. Only the best n checkpoints are contained in the directory and JSON file.
  11. This is a light-weight wrapper class only intended to work in simple,
  12. non-distributed settings. It is not intended to work with the tf.Estimator
  13. framework.
  14. """
  15. def __init__(self, save_dir, num_to_keep=1, maximize=True, saver=None):
  16. """Creates a `BestCheckpointSaver`
  17. `BestCheckpointSaver` acts as a wrapper class around a `tf.train.Saver`
  18. Args:
  19. save_dir: The directory in which the checkpoint files will be saved
  20. num_to_keep: The number of best checkpoint files to retain
  21. maximize: Define 'best' values to be the highest values. For example,
  22. set this to True if selecting for the checkpoints with the highest
  23. given accuracy. Or set to False to select for checkpoints with the
  24. lowest given error rate.
  25. saver: A `tf.train.Saver` to use for saving checkpoints. A default
  26. `tf.train.Saver` will be created if none is provided.
  27. """
  28. self._num_to_keep = num_to_keep
  29. self._save_dir = save_dir
  30. self._save_path = os.path.join(save_dir, 'best.ckpt')
  31. self._maximize = maximize
  32. self._saver = saver if saver else tf.train.Saver(
  33. max_to_keep=None,
  34. save_relative_paths=True
  35. )
  36. if not os.path.exists(save_dir):
  37. os.makedirs(save_dir)
  38. self.best_checkpoints_file = os.path.join(save_dir, 'best_checkpoints')
  39. def handle(self, value, sess, global_step_tensor):
  40. """Updates the set of best checkpoints based on the given result.
  41. Args:
  42. value: The value by which to rank the checkpoint.
  43. sess: A tf.Session to use to save the checkpoint
  44. global_step_tensor: A `tf.Tensor` represent the global step
  45. """
  46. global_step = sess.run(global_step_tensor)
  47. current_ckpt = 'best.ckpt-{}'.format(global_step)
  48. value = float(value)
  49. if not os.path.exists(self.best_checkpoints_file):
  50. self._save_best_checkpoints_file({current_ckpt: value})
  51. self._saver.save(sess, self._save_path, global_step_tensor)
  52. return
  53. best_checkpoints = self._load_best_checkpoints_file()
  54. if len(best_checkpoints) < self._num_to_keep:
  55. best_checkpoints[current_ckpt] = value
  56. self._save_best_checkpoints_file(best_checkpoints)
  57. self._saver.save(sess, self._save_path, global_step_tensor)
  58. return
  59. if self._maximize:
  60. should_save = not all(current_best >= value
  61. for current_best in best_checkpoints.values())
  62. else:
  63. should_save = not all(current_best <= value
  64. for current_best in best_checkpoints.values())
  65. if should_save:
  66. best_checkpoint_list = self._sort(best_checkpoints)
  67. worst_checkpoint = os.path.join(self._save_dir,
  68. best_checkpoint_list.pop(-1)[0])
  69. self._remove_outdated_checkpoint_files(worst_checkpoint)
  70. self._update_internal_saver_state(best_checkpoint_list)
  71. best_checkpoints = dict(best_checkpoint_list)
  72. best_checkpoints[current_ckpt] = value
  73. self._save_best_checkpoints_file(best_checkpoints)
  74. self._saver.save(sess, self._save_path, global_step_tensor)
  75. def _save_best_checkpoints_file(self, updated_best_checkpoints):
  76. with open(self.best_checkpoints_file, 'w') as f:
  77. json.dump(updated_best_checkpoints, f, indent=3)
  78. def _remove_outdated_checkpoint_files(self, worst_checkpoint):
  79. os.remove(os.path.join(self._save_dir, 'checkpoint'))
  80. for ckpt_file in glob.glob(worst_checkpoint + '.*'):
  81. os.remove(ckpt_file)
  82. def _update_internal_saver_state(self, best_checkpoint_list):
  83. best_checkpoint_files = [
  84. (ckpt[0], np.inf) # TODO: Try to use actual file timestamp
  85. for ckpt in best_checkpoint_list
  86. ]
  87. self._saver.set_last_checkpoints_with_time(best_checkpoint_files)
  88. def _load_best_checkpoints_file(self):
  89. with open(self.best_checkpoints_file, 'r') as f:
  90. best_checkpoints = json.load(f)
  91. return best_checkpoints
  92. def _sort(self, best_checkpoints):
  93. best_checkpoints = [
  94. (ckpt, best_checkpoints[ckpt])
  95. for ckpt in sorted(best_checkpoints,
  96. key=best_checkpoints.get,
  97. reverse=self._maximize)
  98. ]
  99. return best_checkpoints
  100. def get_best_checkpoint(best_checkpoint_dir, select_maximum_value=True):
  101. """ Returns filepath to the best checkpoint
  102. Reads the best_checkpoints file in the best_checkpoint_dir directory.
  103. Returns the filepath in the best_checkpoints file associated with
  104. the highest value if select_maximum_value is True, or the filepath
  105. associated with the lowest value if select_maximum_value is False.
  106. Args:
  107. best_checkpoint_dir: Directory containing best_checkpoints JSON file
  108. select_maximum_value: If True, select the filepath associated
  109. with the highest value. Otherwise, select the filepath associated
  110. with the lowest value.
  111. Returns:
  112. The full path to the best checkpoint file
  113. """
  114. best_checkpoints_file = os.path.join(best_checkpoint_dir, 'best_checkpoints')
  115. assert os.path.exists(best_checkpoints_file)
  116. with open(best_checkpoints_file, 'r') as f:
  117. best_checkpoints = json.load(f)
  118. best_checkpoints = [
  119. ckpt for ckpt in sorted(best_checkpoints,
  120. key=best_checkpoints.get,
  121. reverse=select_maximum_value)
  122. ]
  123. return os.path.join(best_checkpoint_dir, best_checkpoints[0])