PageRenderTime 67ms CodeModel.GetById 20ms RepoModel.GetById 0ms app.codeStats 0ms

/data_io.py

https://github.com/KateDavis/CauseEffectPairsChallenge
Python | 63 lines | 51 code | 12 blank | 0 comment | 2 complexity | 81cf3d11e604c9830285e9624aedb122 MD5 | raw file
  1. import csv
  2. import json
  3. import numpy as np
  4. import os
  5. import pandas as pd
  6. import pickle
  7. def get_paths():
  8. paths = json.loads(open("SETTINGS.json").read())
  9. for key in paths:
  10. paths[key] = os.path.expandvars(paths[key])
  11. return paths
  12. def parse_dataframe(df):
  13. parse_cell = lambda cell: np.fromstring(cell, dtype=np.float, sep=" ")
  14. df = df.applymap(parse_cell)
  15. return df
  16. def read_train_pairs():
  17. train_path = get_paths()["train_pairs_path"]
  18. return parse_dataframe(pd.read_csv(train_path, index_col="SampleID"))
  19. def read_train_target():
  20. path = get_paths()["train_target_path"]
  21. df = pd.read_csv(path, index_col="SampleID")
  22. df = df.rename(columns = dict(zip(df.columns, ["Target", "Details"])))
  23. return df
  24. def read_train_info():
  25. path = get_paths()["train_info_path"]
  26. return pd.read_csv(path, index_col="SampleID")
  27. def read_valid_pairs():
  28. valid_path = get_paths()["valid_pairs_path"]
  29. return parse_dataframe(pd.read_csv(valid_path, index_col="SampleID"))
  30. def read_valid_info():
  31. path = get_paths()["valid_info_path"]
  32. return pd.read_csv(path, index_col="SampleID")
  33. def read_solution():
  34. solution_path = get_paths()["solution_path"]
  35. return pd.read_csv(solution_path, index_col="SampleID")
  36. def save_model(model):
  37. out_path = get_paths()["model_path"]
  38. pickle.dump(model, open(out_path, "w"))
  39. def load_model():
  40. in_path = get_paths()["model_path"]
  41. return pickle.load(open(in_path))
  42. def read_submission():
  43. submission_path = get_paths()["submission_path"]
  44. return pd.read_csv(submission_path, index_col="SampleID")
  45. def write_submission(predictions):
  46. submission_path = get_paths()["submission_path"]
  47. writer = csv.writer(open(submission_path, "w"), lineterminator="\n")
  48. valid = read_valid_pairs()
  49. rows = [x for x in zip(valid.index, predictions)]
  50. writer.writerow(("SampleID", "Target"))
  51. writer.writerows(rows)