/autoseries/ingestion/dataset.py

https://github.com/DenisVorotyntsev/AutoSeries
Python | 145 lines | 100 code | 25 blank | 20 comment | 10 complexity | 06e6463c16bd2547ab1a02670ff58c1e MD5 | raw file
  1. """
  2. AutoWSL datasets.
  3. """
  4. import copy
  5. from os.path import join
  6. from datetime import datetime
  7. import numpy as np
  8. import pandas as pd
  9. import yaml
  10. from common import get_logger
# Maps each schema type name found in the metadata to the dtype that
# Dataset._read_dataset hands to pandas.read_csv for that column.
TYPE_MAP = {
    'cat': str,
    'multi-cat': str,
    'str': str,
    'num': np.float64,
    # Timestamp columns are read as text; Dataset._read_dataset additionally
    # lists them as date columns so they end up as datetimes.
    'timestamp': 'str'
}

# Logging setup for this module (the logger is not referenced elsewhere in
# the visible part of this file).
VERBOSITY_LEVEL = 'WARNING'
LOGGER = get_logger(VERBOSITY_LEVEL, __file__)

# Metadata keys: the name of the primary timestamp column and the name of
# the label column are looked up under these keys.
PRIMARY_TIMESTAMP = 'primary_timestamp'
LABEL_NAME = 'label'

# Schema type name that marks a column as a timestamp.
TIMESTAMP_TYPE_NAME = 'timestamp'

# File names, all resolved relative to the dataset directory.
TRAIN_FILE = 'train.data'
TEST_FILE = 'test.data'
TIME_FILE = 'test_time.data'
INFO_FILE = 'info.yaml'
  27. def _date_parser(millisecs):
  28. if np.isnan(float(millisecs)):
  29. return millisecs
  30. return datetime.fromtimestamp(float(millisecs))
  31. class Dataset:
  32. """"Dataset"""
  33. def __init__(self, dataset_dir):
  34. """
  35. train_dataset, test_dataset: list of strings
  36. train_label: np.array
  37. """
  38. self.dataset_dir_ = dataset_dir
  39. self.metadata_ = self._read_metadata(join(dataset_dir, INFO_FILE))
  40. self.train_dataset = None
  41. self.test_dataset = None
  42. self.get_train()
  43. self.get_test()
  44. self._pred_time = self._get_pred_time()
  45. self._primary_timestamp = self.metadata_[PRIMARY_TIMESTAMP]
  46. def get_train(self):
  47. """get train"""
  48. if self.train_dataset is None:
  49. self.train_dataset = self._read_dataset(
  50. join(self.dataset_dir_, TRAIN_FILE))
  51. return copy.deepcopy(self.train_dataset)
  52. def get_test(self):
  53. """get test"""
  54. if self.test_dataset is None:
  55. self.test_dataset = self._read_dataset(
  56. join(self.dataset_dir_, TEST_FILE))
  57. return copy.deepcopy(self.test_dataset)
  58. def get_metadata(self):
  59. """get metadata"""
  60. return copy.deepcopy(self.metadata_)
  61. def is_end(self, idx):
  62. """whether time idx is the end of data"""
  63. return idx == len(self._pred_time)
  64. def _get_period(self, idx1, idx2):
  65. next_time = self._get_time_point(idx2)
  66. timestamp = self.test_dataset[self._primary_timestamp]
  67. select = timestamp < next_time
  68. if idx1 is not None:
  69. last_time = self._get_time_point(idx1)
  70. select &= timestamp >= last_time
  71. return self.test_dataset[select]
  72. def get_history(self, idx):
  73. """get the new history before time idx"""
  74. if idx > 0:
  75. ret = self._get_period(idx-1, idx)
  76. else:
  77. ret = self._get_period(None, idx)
  78. return copy.deepcopy(ret)
  79. def get_next_pred(self, idx):
  80. """get the next pred time point (idx) (maybe batch data)"""
  81. next_time = self._get_time_point(idx)
  82. select = self.test_dataset[self._primary_timestamp] == next_time
  83. data = self.test_dataset[select].drop(
  84. self.metadata_[LABEL_NAME], axis=1)
  85. return copy.deepcopy(data)
  86. def get_all_history(self, idx):
  87. """get all history before idx"""
  88. return copy.deepcopy(self._get_period(None, idx))
  89. def _get_pred_time(self):
  90. """get the pred time point"""
  91. return pd.read_csv(join(self.dataset_dir_, TIME_FILE),
  92. parse_dates=[self.metadata_[PRIMARY_TIMESTAMP]],
  93. date_parser=_date_parser)
  94. def _get_time_point(self, idx):
  95. return self._pred_time.iloc[idx, 0]
  96. @staticmethod
  97. def _read_metadata(metadata_path):
  98. with open(metadata_path, 'r') as ftmp:
  99. return yaml.safe_load(ftmp)
  100. def _read_dataset(self, dataset_path):
  101. schema = self.metadata_['schema']
  102. table_dtype = {key: TYPE_MAP[val] for key, val in schema.items()}
  103. date_list = [key for key, val in schema.items()
  104. if val == TIMESTAMP_TYPE_NAME]
  105. dataset = pd.read_csv(
  106. dataset_path, sep='\t', dtype=table_dtype,
  107. parse_dates=date_list, date_parser=_date_parser)
  108. return dataset
  109. def get_train_num(self):
  110. """ return the number of train instance """
  111. return self.metadata_["train_num"]
  112. def get_test_num(self):
  113. """ return the number of test instance """
  114. return self.metadata_["test_num"]
  115. def get_test_timestamp(self):
  116. """get timestamps of test data"""
  117. return copy.deepcopy(self.test_dataset[self._primary_timestamp])
  118. def get_pred_timestamp(self):
  119. """get timestamps of pred data"""
  120. return copy.deepcopy(self._pred_time)