/pytorch/vadam/datasets.py

https://github.com/emtiyaz/vadam · Python · 406 lines · 296 code · 97 blank · 13 comment · 48 complexity · be3f66d9be922785bd84e51660f6cb20 MD5 · raw file

  1. import torch
  2. import numpy as np
  3. import torch.utils.data as data
  4. from torch.utils.data.dataloader import DataLoader
  5. import torchvision.datasets as dset
  6. import torchvision.transforms as transforms
  7. import sklearn.model_selection as modsel
  8. ##########################################################
  9. ## PyTorch Dataset for presplit classification datasets ##
  10. ##########################################################
  11. from vadam.data_classes.classification.australian_presplit import AustralianPresplit
  12. from vadam.data_classes.classification.breast_cancer_presplit import BreastCancerPresplit
  13. ######################################################
  14. ## PyTorch Dataset for presplit regression datasets ##
  15. ######################################################
  16. from vadam.data_classes.regression.boston_presplit import BostonPresplit
  17. from vadam.data_classes.regression.concrete_presplit import ConcretePresplit
  18. from vadam.data_classes.regression.energy_presplit import EnergyPresplit
  19. from vadam.data_classes.regression.kin8nm_presplit import Kin8nmPresplit
  20. from vadam.data_classes.regression.naval_presplit import NavalPresplit
  21. from vadam.data_classes.regression.powerplant_presplit import PowerplantPresplit
  22. from vadam.data_classes.regression.wine_presplit import WinePresplit
  23. from vadam.data_classes.regression.yacht_presplit import YachtPresplit
  24. DEFAULT_DATA_FOLDER = "./data"
  25. ################################################
  26. ## Construct class for dealing with data sets ##
  27. ################################################
  28. class Dataset():
  29. def __init__(self, data_set, data_folder=DEFAULT_DATA_FOLDER):
  30. super(type(self), self).__init__()
  31. if data_set == "mnist":
  32. self.train_set = dset.MNIST(root = data_folder,
  33. train = True,
  34. transform = transforms.ToTensor(),
  35. download = True)
  36. self.test_set = dset.MNIST(root = data_folder,
  37. train = False,
  38. transform = transforms.ToTensor())
  39. self.task = "classification"
  40. self.num_features = 28 * 28
  41. self.num_classes = 10
  42. elif data_set == "australian_presplit":
  43. self.train_set = AustralianPresplit(root = data_folder,
  44. train = True)
  45. self.test_set = AustralianPresplit(root = data_folder,
  46. train = False)
  47. self.task = "classification"
  48. self.num_features = 14
  49. self.num_classes = 2
  50. elif data_set == "breastcancer_presplit":
  51. self.train_set = BreastCancerPresplit(root = data_folder,
  52. train = True)
  53. self.test_set = BreastCancerPresplit(root = data_folder,
  54. train = False)
  55. self.task = "classification"
  56. self.num_features = 10
  57. self.num_classes = 2
  58. elif data_set in ["boston" + str(i) for i in range(20)]:
  59. self.train_set = BostonPresplit(root = data_folder,
  60. data_set = data_set,
  61. train = True)
  62. self.test_set = BostonPresplit(root = data_folder,
  63. data_set = data_set,
  64. train = False)
  65. self.task = "regression"
  66. self.num_features = 13
  67. self.num_classes = None
  68. elif data_set in ["concrete" + str(i) for i in range(20)]:
  69. self.train_set = ConcretePresplit(root = data_folder,
  70. data_set = data_set,
  71. train = True)
  72. self.test_set = ConcretePresplit(root = data_folder,
  73. data_set = data_set,
  74. train = False)
  75. self.task = "regression"
  76. self.num_features = 8
  77. self.num_classes = None
  78. elif data_set in ["energy" + str(i) for i in range(20)]:
  79. self.train_set = EnergyPresplit(root = data_folder,
  80. data_set = data_set,
  81. train = True)
  82. self.test_set = EnergyPresplit(root = data_folder,
  83. data_set = data_set,
  84. train = False)
  85. self.task = "regression"
  86. self.num_features = 8
  87. self.num_classes = None
  88. elif data_set in ["kin8nm" + str(i) for i in range(20)]:
  89. self.train_set = Kin8nmPresplit(root = data_folder,
  90. data_set = data_set,
  91. train = True)
  92. self.test_set = Kin8nmPresplit(root = data_folder,
  93. data_set = data_set,
  94. train = False)
  95. self.task = "regression"
  96. self.num_features = 8
  97. self.num_classes = None
  98. elif data_set in ["naval" + str(i) for i in range(20)]:
  99. self.train_set = NavalPresplit(root = data_folder,
  100. data_set = data_set,
  101. train = True)
  102. self.test_set = NavalPresplit(root = data_folder,
  103. data_set = data_set,
  104. train = False)
  105. self.task = "regression"
  106. self.num_features = 16
  107. self.num_classes = None
  108. elif data_set in ["powerplant" + str(i) for i in range(20)]:
  109. self.train_set = PowerplantPresplit(root = data_folder,
  110. data_set = data_set,
  111. train = True)
  112. self.test_set = PowerplantPresplit(root = data_folder,
  113. data_set = data_set,
  114. train = False)
  115. self.task = "regression"
  116. self.num_features = 4
  117. self.num_classes = None
  118. elif data_set in ["wine" + str(i) for i in range(20)]:
  119. self.train_set = WinePresplit(root = data_folder,
  120. data_set = data_set,
  121. train = True)
  122. self.test_set = WinePresplit(root = data_folder,
  123. data_set = data_set,
  124. train = False)
  125. self.task = "regression"
  126. self.num_features = 11
  127. self.num_classes = None
  128. elif data_set in ["yacht" + str(i) for i in range(20)]:
  129. self.train_set = YachtPresplit(root = data_folder,
  130. data_set = data_set,
  131. train = True)
  132. self.test_set = YachtPresplit(root = data_folder,
  133. data_set = data_set,
  134. train = False)
  135. self.task = "regression"
  136. self.num_features = 6
  137. self.num_classes = None
  138. else:
  139. RuntimeError("Unknown data set")
  140. def get_train_size(self):
  141. return len(self.train_set)
  142. def get_test_size(self):
  143. return len(self.test_set)
  144. def get_train_loader(self, batch_size, shuffle=True):
  145. train_loader = DataLoader(dataset = self.train_set,
  146. batch_size = batch_size,
  147. shuffle = shuffle)
  148. return train_loader
  149. def get_test_loader(self, batch_size, shuffle=False):
  150. test_loader = DataLoader(dataset = self.test_set,
  151. batch_size = batch_size,
  152. shuffle = shuffle)
  153. return test_loader
  154. def load_full_train_set(self, use_cuda=torch.cuda.is_available()):
  155. full_train_loader = DataLoader(dataset = self.train_set,
  156. batch_size = len(self.train_set),
  157. shuffle = False)
  158. x_train, y_train = next(iter(full_train_loader))
  159. if use_cuda:
  160. x_train, y_train = x_train.cuda(), y_train.cuda()
  161. return x_train, y_train
  162. def load_full_test_set(self, use_cuda=torch.cuda.is_available()):
  163. full_test_loader = DataLoader(dataset = self.test_set,
  164. batch_size = len(self.test_set),
  165. shuffle = False)
  166. x_test, y_test = next(iter(full_test_loader))
  167. if use_cuda:
  168. x_test, y_test = x_test.cuda(), y_test.cuda()
  169. return x_test, y_test
  170. #######################################################################
  171. ## Construct class for dealing with data sets using cross-validation ##
  172. #######################################################################
  173. class DatasetCV():
  174. def __init__(self, data_set, n_splits=3, seed=None, data_folder=DEFAULT_DATA_FOLDER):
  175. super(type(self), self).__init__()
  176. self.n_splits = n_splits
  177. self.seed = seed
  178. self.current_split = 0
  179. if data_set == "mnist":
  180. self.data = dset.MNIST(root = data_folder,
  181. train = True,
  182. transform = transforms.ToTensor(),
  183. download = True)
  184. self.task = "classification"
  185. self.num_features = 28 * 28
  186. self.num_classes = 10
  187. elif data_set == "australian_presplit":
  188. self.data = AustralianPresplit(root = data_folder,
  189. train = True)
  190. self.task = "classification"
  191. self.num_features = 14
  192. self.num_classes = 2
  193. elif data_set == "breastcancer_presplit":
  194. self.data = BreastCancerPresplit(root = data_folder,
  195. train = True)
  196. self.task = "classification"
  197. self.num_features = 10
  198. self.num_classes = 2
  199. elif data_set in ["boston" + str(i) for i in range(20)]:
  200. self.data = BostonPresplit(root = data_folder,
  201. data_set = data_set,
  202. train = True)
  203. self.task = "regression"
  204. self.num_features = 13
  205. self.num_classes = None
  206. elif data_set in ["concrete" + str(i) for i in range(20)]:
  207. self.data = ConcretePresplit(root = data_folder,
  208. data_set = data_set,
  209. train = True)
  210. self.task = "regression"
  211. self.num_features = 8
  212. self.num_classes = None
  213. elif data_set in ["energy" + str(i) for i in range(20)]:
  214. self.data = EnergyPresplit(root = data_folder,
  215. data_set = data_set,
  216. train = True)
  217. self.task = "regression"
  218. self.num_features = 8
  219. self.num_classes = None
  220. elif data_set in ["kin8nm" + str(i) for i in range(20)]:
  221. self.data = Kin8nmPresplit(root = data_folder,
  222. data_set = data_set,
  223. train = True)
  224. self.task = "regression"
  225. self.num_features = 8
  226. self.num_classes = None
  227. elif data_set in ["naval" + str(i) for i in range(20)]:
  228. self.data = NavalPresplit(root = data_folder,
  229. data_set = data_set,
  230. train = True)
  231. self.task = "regression"
  232. self.num_features = 16
  233. self.num_classes = None
  234. elif data_set in ["powerplant" + str(i) for i in range(20)]:
  235. self.data = PowerplantPresplit(root = data_folder,
  236. data_set = data_set,
  237. train = True)
  238. self.task = "regression"
  239. self.num_features = 4
  240. self.num_classes = None
  241. elif data_set in ["wine" + str(i) for i in range(20)]:
  242. self.data = WinePresplit(root = data_folder,
  243. data_set = data_set,
  244. train = True)
  245. self.task = "regression"
  246. self.num_features = 11
  247. self.num_classes = None
  248. elif data_set in ["yacht" + str(i) for i in range(20)]:
  249. self.data = YachtPresplit(root = data_folder,
  250. data_set = data_set,
  251. train = True)
  252. self.task = "regression"
  253. self.num_features = 6
  254. self.num_classes = None
  255. else:
  256. RuntimeError("Unknown data set")
  257. # Store CV splits
  258. cv = modsel.KFold(n_splits=n_splits, shuffle=True, random_state=seed)
  259. splits = cv.split(range(len(self.data)))
  260. self.split_idx_val = []
  261. for (_, idx_val) in splits:
  262. self.split_idx_val.append(idx_val)
  263. def get_full_data_size(self):
  264. return len(self.data)
  265. def get_current_split(self):
  266. return self.current_split
  267. def set_current_split(self, split):
  268. if split >= 0 and split <= self.n_splits-1:
  269. self.current_split = split
  270. else:
  271. RuntimeError("Split higher than number of splits")
  272. def _get_current_val_idx(self):
  273. return self.split_idx_val[self.current_split]
  274. def _get_current_train_idx(self):
  275. return np.setdiff1d(range(len(self.data)), self.split_idx_val[self.current_split])
  276. def get_current_val_size(self):
  277. return len(self._get_current_val_idx())
  278. def get_current_train_size(self):
  279. return len(self.data) - len(self._get_current_val_idx())
  280. def get_current_train_loader(self, batch_size, shuffle=True):
  281. train_set = Subset(self.data, self._get_current_train_idx())
  282. train_loader = DataLoader(dataset = train_set,
  283. batch_size = batch_size,
  284. shuffle = shuffle)
  285. return train_loader
  286. def get_current_val_loader(self, batch_size, shuffle=True):
  287. val_set = Subset(self.data, self._get_current_val_idx())
  288. val_loader = DataLoader(dataset = val_set,
  289. batch_size = batch_size,
  290. shuffle = shuffle)
  291. return val_loader
  292. def load_current_train_set(self, use_cuda=torch.cuda.is_available()):
  293. x_train, y_train = self.data.__getitem__(self._get_current_train_idx())
  294. if use_cuda:
  295. x_train, y_train = x_train.cuda(), y_train.cuda()
  296. return x_train, y_train
  297. def load_current_val_set(self, use_cuda=torch.cuda.is_available()):
  298. x_test, y_test = self.data.__getitem__(self._get_current_val_idx())
  299. if use_cuda:
  300. x_test, y_test = x_test.cuda(), y_test.cuda()
  301. return x_test, y_test
  302. class Subset(data.Dataset):
  303. def __init__(self, dataset, indices):
  304. self.dataset = dataset
  305. self.indices = indices
  306. def __getitem__(self, idx):
  307. return self.dataset[self.indices[idx]]
  308. def __len__(self):
  309. return len(self.indices)