# /src/purchases_to_jrows.py
# Source: https://github.com/datagym-ru/retailhero-recomender-baseline
  1. import json
  2. import os
  3. import pandas as pd
  4. from tqdm import tqdm
  5. import config as cfg
  6. from utils import md5_hash
  7. class Transaction:
  8. def __init__(self, transaction_id, transaction_datetime, **kwargs):
  9. self.data = {
  10. **{"tid": transaction_id, "datetime": transaction_datetime, "products": [],},
  11. **kwargs,
  12. }
  13. def add_item(
  14. self, product_id: str, product_quantity: float, trn_sum_from_iss: float, trn_sum_from_red: float,
  15. ) -> None:
  16. p = {
  17. "product_id": product_id,
  18. "quantity": product_quantity,
  19. "s": trn_sum_from_iss,
  20. "r": "0" if trn_sum_from_red is None or pd.isna(trn_sum_from_red) else trn_sum_from_red,
  21. }
  22. self.data["products"].append(p)
  23. def as_dict(self,):
  24. return self.data
  25. def transaction_id(self,):
  26. return self.data["tid"]
  27. class ClientHistory:
  28. def __init__(
  29. self, client_id,
  30. ):
  31. self.data = {
  32. "client_id": client_id,
  33. "transaction_history": [],
  34. }
  35. def add_transaction(
  36. self, transaction,
  37. ):
  38. self.data["transaction_history"].append(transaction)
  39. def as_dict(self,):
  40. return self.data
  41. def client_id(self,):
  42. return self.data["client_id"]
  43. class RowSplitter:
  44. def __init__(
  45. self, output_path, n_shards=16,
  46. ):
  47. self.n_shards = n_shards
  48. os.makedirs(
  49. output_path, exist_ok=True,
  50. )
  51. self.outs = []
  52. for i in range(self.n_shards):
  53. self.outs.append(open(output_path + "/{:02d}.jsons".format(i), "w",))
  54. self._client = None
  55. self._transaction = None
  56. def finish(self,):
  57. self.flush()
  58. for outs in self.outs:
  59. outs.close()
  60. def flush(self,):
  61. if self._client is not None:
  62. self._client.add_transaction(self._transaction.as_dict())
  63. # rows are sharded by cliend_id
  64. shard_idx = md5_hash(self._client.client_id()) % self.n_shards
  65. data = self._client.as_dict()
  66. self.outs[shard_idx].write(json.dumps(data) + "\n")
  67. self._client = None
  68. self._transaction = None
  69. def consume_row(
  70. self, row,
  71. ):
  72. if self._client is not None and self._client.client_id() != row.client_id:
  73. self.flush()
  74. if self._client is None:
  75. self._client = ClientHistory(client_id=row.client_id)
  76. if self._transaction is not None and self._transaction.transaction_id() != row.transaction_id:
  77. self._client.add_transaction(self._transaction.as_dict())
  78. self._transaction = None
  79. if self._transaction is None:
  80. self._transaction = Transaction(
  81. transaction_id=row.transaction_id,
  82. transaction_datetime=row.transaction_datetime,
  83. rpr=row.regular_points_received,
  84. epr=row.express_points_received,
  85. rps=row.regular_points_spent,
  86. eps=row.express_points_spent,
  87. sum=row.purchase_sum,
  88. store_id=row.store_id,
  89. )
  90. self._transaction.add_item(
  91. product_id=row.product_id,
  92. product_quantity=row.product_quantity,
  93. trn_sum_from_iss=row.trn_sum_from_iss,
  94. trn_sum_from_red=row.trn_sum_from_red,
  95. )
  96. def split_data_to_chunks(
  97. input_path, output_dir, n_shards=16,
  98. ):
  99. splitter = RowSplitter(output_path=output_dir, n_shards=n_shards,)
  100. print("split_data_to_chunks: {} -> {}".format(input_path, output_dir,))
  101. for df in tqdm(pd.read_csv(input_path, chunksize=500000,)):
  102. for row in df.itertuples():
  103. splitter.consume_row(row)
  104. splitter.finish()
  105. def calculate_unique_clients_from_input(input_path,):
  106. client_set = set()
  107. print("calculate_unique_clients_from: {}".format(input_path))
  108. for df in tqdm(pd.read_csv(input_path, chunksize=500000,)):
  109. client_set.update(set([row.client_id for row in df.itertuples()]))
  110. return len(client_set)
  111. def calculate_unique_clients_from_output(output_dir,):
  112. import glob
  113. client_cnt = 0
  114. print("calculate_unique_clients_from: {}".format(output_dir))
  115. for js_file in glob.glob(output_dir + "/*.jsons"):
  116. for _ in open(js_file):
  117. client_cnt += 1
  118. return client_cnt
  119. if __name__ == "__main__":
  120. purchases_csv_path = cfg.PURCHASE_CSV_PATH
  121. output_jsons_dir = cfg.JSONS_DIR
  122. split_data_to_chunks(
  123. purchases_csv_path, output_jsons_dir, n_shards=16,
  124. )
  125. # check splitting for correctness
  126. _from_input = calculate_unique_clients_from_input(purchases_csv_path)
  127. _from_output = calculate_unique_clients_from_output(output_jsons_dir)
  128. assert _from_input == _from_output