PageRenderTime 47ms CodeModel.GetById 18ms RepoModel.GetById 0ms app.codeStats 0ms

/zipline/utils/data_source_tables_gen.py

https://gitlab.com/lbennett/zipline
Python | 210 lines | 185 code | 11 blank | 14 comment | 1 complexity | 6633a871828117ff4dee77873a53682e MD5 | raw file
  1. #
  2. # Copyright 2014 Quantopian, Inc.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import sys
  16. import getopt
  17. import traceback
  18. import numpy as np
  19. import pandas as pd
  20. import datetime
  21. import logging
  22. import tables
  23. import gzip
  24. import glob
  25. import os
  26. import random
  27. import csv
  28. import time
  29. from six import print_
  30. FORMAT = "%(asctime)-15s -8s %(message)s"
  31. logging.basicConfig(format=FORMAT, level=logging.INFO)
  32. class Usage(Exception):
  33. def __init__(self, msg):
  34. self.msg = msg
  35. OHLCTableDescription = {'sid': tables.StringCol(14, pos=2),
  36. 'dt': tables.Int64Col(pos=1),
  37. 'open': tables.Float64Col(dflt=np.NaN, pos=3),
  38. 'high': tables.Float64Col(dflt=np.NaN, pos=4),
  39. 'low': tables.Float64Col(dflt=np.NaN, pos=5),
  40. 'close': tables.Float64Col(dflt=np.NaN, pos=6),
  41. "volume": tables.Int64Col(dflt=0, pos=7)}
  42. def process_line(line):
  43. dt = np.datetime64(line["dt"]).astype(np.int64)
  44. sid = line["sid"]
  45. open_p = float(line["open"])
  46. high_p = float(line["high"])
  47. low_p = float(line["low"])
  48. close_p = float(line["close"])
  49. volume = int(line["volume"])
  50. return (dt, sid, open_p, high_p, low_p, close_p, volume)
  51. def parse_csv(csv_reader):
  52. previous_date = None
  53. data = []
  54. dtype = [('dt', 'int64'), ('sid', '|S14'), ('open', float),
  55. ('high', float), ('low', float), ('close', float),
  56. ('volume', int)]
  57. for line in csv_reader:
  58. row = process_line(line)
  59. current_date = line["dt"][:10].replace("-", "")
  60. if previous_date and previous_date != current_date:
  61. rows = np.array(data, dtype=dtype).view(np.recarray)
  62. yield current_date, rows
  63. data = []
  64. data.append(row)
  65. previous_date = current_date
  66. def merge_all_files_into_pytables(file_dir, file_out):
  67. """
  68. process each file into pytables
  69. """
  70. start = None
  71. start = datetime.datetime.now()
  72. out_h5 = tables.openFile(file_out,
  73. mode="w",
  74. title="bars",
  75. filters=tables.Filters(complevel=9,
  76. complib='zlib'))
  77. table = None
  78. for file_in in glob.glob(file_dir + "/*.gz"):
  79. gzip_file = gzip.open(file_in)
  80. expected_header = ["dt", "sid", "open", "high", "low", "close",
  81. "volume"]
  82. csv_reader = csv.DictReader(gzip_file)
  83. header = csv_reader.fieldnames
  84. if header != expected_header:
  85. logging.warn("expected header %s\n" % (expected_header))
  86. logging.warn("header_found %s" % (header))
  87. return
  88. for current_date, rows in parse_csv(csv_reader):
  89. table = out_h5.createTable("/TD", "date_" + current_date,
  90. OHLCTableDescription,
  91. expectedrows=len(rows),
  92. createparents=True)
  93. table.append(rows)
  94. table.flush()
  95. if table is not None:
  96. table.flush()
  97. end = datetime.datetime.now()
  98. diff = (end - start).seconds
  99. logging.debug("finished it took %d." % (diff))
  100. def create_fake_csv(file_in):
  101. fields = ["dt", "sid", "open", "high", "low", "close", "volume"]
  102. gzip_file = gzip.open(file_in, "w")
  103. dict_writer = csv.DictWriter(gzip_file, fieldnames=fields)
  104. current_dt = datetime.date.today() - datetime.timedelta(days=2)
  105. current_dt = pd.Timestamp(current_dt).replace(hour=9)
  106. current_dt = current_dt.replace(minute=30)
  107. end_time = pd.Timestamp(datetime.date.today())
  108. end_time = end_time.replace(hour=16)
  109. last_price = 10.0
  110. while current_dt < end_time:
  111. row = {}
  112. row["dt"] = current_dt
  113. row["sid"] = "test"
  114. last_price += random.randint(-20, 100) / 10000.0
  115. row["close"] = last_price
  116. row["open"] = last_price - 0.01
  117. row["low"] = last_price - 0.02
  118. row["high"] = last_price + 0.02
  119. row["volume"] = random.randint(10, 1000) * 10
  120. dict_writer.writerow(row)
  121. current_dt += datetime.timedelta(minutes=1)
  122. if current_dt.hour > 16:
  123. current_dt += datetime.timedelta(days=1)
  124. current_dt = current_dt.replace(hour=9)
  125. current_dt = current_dt.replace(minute=30)
  126. gzip_file.close()
  127. def main(argv=None):
  128. """
  129. This script cleans minute bars into pytables file
  130. data_source_tables_gen.py
  131. [--tz_in] sets time zone of data only reasonably fast way to use
  132. time.tzset()
  133. [--dir_in] iterates through directory provided of csv files in gzip form
  134. in form:
  135. dt, sid, open, high, low, close, volume
  136. 2012-01-01T12:30:30,1234HT,1, 2,3,4.0
  137. [--fake_csv] creates a fake sample csv to iterate through
  138. [--file_out] determines output file
  139. """
  140. if argv is None:
  141. argv = sys.argv
  142. try:
  143. dir_in = None
  144. file_out = "./all.h5"
  145. fake_csv = None
  146. try:
  147. opts, args = getopt.getopt(argv[1:], "hdft",
  148. ["help",
  149. "dir_in=",
  150. "debug",
  151. "tz_in=",
  152. "fake_csv=",
  153. "file_out="])
  154. except getopt.error as msg:
  155. raise Usage(msg)
  156. for opt, value in opts:
  157. if opt in ("--help", "-h"):
  158. print_(main.__doc__)
  159. if opt in ("-d", "--debug"):
  160. logging.basicConfig(format=FORMAT,
  161. level=logging.DEBUG)
  162. if opt in ("-d", "--dir_in"):
  163. dir_in = value
  164. if opt in ("-o", "--file_out"):
  165. file_out = value
  166. if opt in ("--fake_csv"):
  167. fake_csv = value
  168. if opt in ("--tz_in"):
  169. os.environ['TZ'] = value
  170. time.tzset()
  171. try:
  172. if dir_in:
  173. merge_all_files_into_pytables(dir_in, file_out)
  174. if fake_csv:
  175. create_fake_csv(fake_csv)
  176. except Exception:
  177. error = "An unhandled error occured in the"
  178. error += "data_source_tables_gen.py script."
  179. error += "\n\nTraceback:\n"
  180. error += '-' * 70 + "\n"
  181. error += "".join(traceback.format_tb(sys.exc_info()[2]))
  182. error += repr(sys.exc_info()[1]) + "\n"
  183. error += str(sys.exc_info()[1]) + "\n"
  184. error += '-' * 70 + "\n"
  185. print_(error)
  186. except Usage as err:
  187. print_(err.msg)
  188. print_("for help use --help")
  189. return 2
# Script entry point: the process exit status is main()'s return value
# (None is treated as 0 by sys.exit; 2 signals a command-line usage error).
if __name__ == "__main__":
    sys.exit(main())