/learn2learn/vision/benchmarks/__init__.py

https://github.com/learnables/learn2learn · Python · 128 lines · 75 code · 12 blank · 41 comment · 0 complexity · 008cb130d5515843ff55689f0b3414dd MD5 · raw file

  1. #!/usr/bin/env python3
  2. """
The benchmarks module provides a convenient interface to standardized benchmarks from the literature.
  4. It provides train/validation/test TaskDatasets and TaskTransforms for pre-defined datasets.
  5. This utility is useful for researchers to compare new algorithms against existing benchmarks.
  6. For a more fine-grained control over tasks and data, we recommend directly using `l2l.data.TaskDataset` and `l2l.data.TaskTransforms`.
  7. """
  8. import os
  9. import learn2learn as l2l
  10. from collections import namedtuple
  11. from .omniglot_benchmark import omniglot_tasksets
  12. from .mini_imagenet_benchmark import mini_imagenet_tasksets
  13. from .tiered_imagenet_benchmark import tiered_imagenet_tasksets
  14. from .fc100_benchmark import fc100_tasksets
  15. from .cifarfs_benchmark import cifarfs_tasksets
  16. __all__ = ['list_tasksets', 'get_tasksets']
  17. BenchmarkTasksets = namedtuple('BenchmarkTasksets', ('train', 'validation', 'test'))
  18. _TASKSETS = {
  19. 'omniglot': omniglot_tasksets,
  20. 'mini-imagenet': mini_imagenet_tasksets,
  21. 'tiered-imagenet': tiered_imagenet_tasksets,
  22. 'fc100': fc100_tasksets,
  23. 'cifarfs': cifarfs_tasksets,
  24. }
  25. def list_tasksets():
  26. """
  27. [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/vision/benchmarks/)
  28. **Description**
  29. Returns a list of all available benchmarks.
  30. **Example**
  31. ~~~python
  32. for name in l2l.vision.benchmarks.list_tasksets():
  33. print(name)
  34. tasksets = l2l.vision.benchmarks.get_tasksets(name)
  35. ~~~
  36. """
  37. return _TASKSETS.keys()
  38. def get_tasksets(
  39. name,
  40. train_ways=5,
  41. train_samples=10,
  42. test_ways=5,
  43. test_samples=10,
  44. num_tasks=-1,
  45. root='~/data',
  46. device=None,
  47. **kwargs,
  48. ):
  49. """
  50. [[Source]](https://github.com/learnables/learn2learn/blob/master/learn2learn/vision/benchmarks/)
  51. **Description**
  52. Returns the tasksets for a particular benchmark, using literature standard data and task transformations.
  53. The returned object is a namedtuple with attributes `train`, `validation`, `test` which
  54. correspond to their respective TaskDatasets.
  55. See `examples/vision/maml_miniimagenet.py` for an example.
  56. **Arguments**
  57. * **name** (str) - The name of the benchmark. Full list in `list_tasksets()`.
  58. * **train_ways** (int, *optional*, default=5) - The number of classes per train tasks.
  59. * **train_samples** (int, *optional*, default=10) - The number of samples per train tasks.
  60. * **test_ways** (int, *optional*, default=5) - The number of classes per test tasks. Also used for validation tasks.
  61. * **test_samples** (int, *optional*, default=10) - The number of samples per test tasks. Also used for validation tasks.
  62. * **num_tasks** (int, *optional*, default=-1) - The number of tasks in each TaskDataset.
  63. * **root** (str, *optional*, default='~/data') - Where the data is stored.
  64. **Example**
  65. ~~~python
  66. train_tasks, validation_tasks, test_tasks = l2l.vision.benchmarks.get_tasksets('omniglot')
  67. batch = train_tasks.sample()
  68. or:
  69. tasksets = l2l.vision.benchmarks.get_tasksets('omniglot')
  70. batch = tasksets.train.sample()
  71. ~~~
  72. """
  73. root = os.path.expanduser(root)
  74. if device is not None:
  75. raise NotImplementedError('Device other than None not implemented. (yet)')
  76. # Load task-specific data and transforms
  77. datasets, transforms = _TASKSETS[name](train_ways=train_ways,
  78. train_samples=train_samples,
  79. test_ways=test_ways,
  80. test_samples=test_samples,
  81. root=root,
  82. **kwargs)
  83. train_dataset, validation_dataset, test_dataset = datasets
  84. train_transforms, validation_transforms, test_transforms = transforms
  85. # Instantiate the tasksets
  86. train_tasks = l2l.data.TaskDataset(
  87. dataset=train_dataset,
  88. task_transforms=train_transforms,
  89. num_tasks=num_tasks,
  90. )
  91. validation_tasks = l2l.data.TaskDataset(
  92. dataset=validation_dataset,
  93. task_transforms=validation_transforms,
  94. num_tasks=num_tasks,
  95. )
  96. test_tasks = l2l.data.TaskDataset(
  97. dataset=test_dataset,
  98. task_transforms=test_transforms,
  99. num_tasks=num_tasks,
  100. )
  101. return BenchmarkTasksets(train_tasks, validation_tasks, test_tasks)