/torch_geometric/datasets/tosca.py

https://github.com/rusty1s/pytorch_geometric
Python | 97 lines | 85 code | 12 blank | 0 comment | 11 complexity | c0086af9e3a93a16cca4e698fadb7c8b MD5 | raw file
  1. import os
  2. import os.path as osp
  3. import glob
  4. import torch
  5. from torch_geometric.data import (InMemoryDataset, Data, download_url,
  6. extract_zip)
  7. from torch_geometric.io import read_txt_array
  8. class TOSCA(InMemoryDataset):
  9. r"""The TOSCA dataset from the `"Numerical Geometry of Non-Ridig Shapes"
  10. <https://www.amazon.com/Numerical-Geometry-Non-Rigid-Monographs-Computer/
  11. dp/0387733000>`_ book, containing 80 meshes.
  12. Meshes within the same category have the same triangulation and an equal
  13. number of vertices numbered in a compatible way.
  14. .. note::
  15. Data objects hold mesh faces instead of edge indices.
  16. To convert the mesh to a graph, use the
  17. :obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`.
  18. To convert the mesh to a point cloud, use the
  19. :obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to
  20. sample a fixed number of points on the mesh faces according to their
  21. face area.
  22. Args:
  23. root (string): Root directory where the dataset should be saved.
  24. categories (list, optional): List of categories to include in the
  25. dataset. Can include the categories :obj:`"Cat"`, :obj:`"Centaur"`,
  26. :obj:`"David"`, :obj:`"Dog"`, :obj:`"Gorilla"`, :obj:`"Horse"`,
  27. :obj:`"Michael"`, :obj:`"Victoria"`, :obj:`"Wolf"`. If set to
  28. :obj:`None`, the dataset will contain all categories. (default:
  29. :obj:`None`)
  30. transform (callable, optional): A function/transform that takes in an
  31. :obj:`torch_geometric.data.Data` object and returns a transformed
  32. version. The data object will be transformed before every access.
  33. (default: :obj:`None`)
  34. pre_transform (callable, optional): A function/transform that takes in
  35. an :obj:`torch_geometric.data.Data` object and returns a
  36. transformed version. The data object will be transformed before
  37. being saved to disk. (default: :obj:`None`)
  38. pre_filter (callable, optional): A function that takes in an
  39. :obj:`torch_geometric.data.Data` object and returns a boolean
  40. value, indicating whether the data object should be included in the
  41. final dataset. (default: :obj:`None`)
  42. """
  43. url = 'http://tosca.cs.technion.ac.il/data/toscahires-asci.zip'
  44. categories = [
  45. 'cat', 'centaur', 'david', 'dog', 'gorilla', 'horse', 'michael',
  46. 'victoria', 'wolf'
  47. ]
  48. def __init__(self, root, categories=None, transform=None,
  49. pre_transform=None, pre_filter=None):
  50. categories = self.categories if categories is None else categories
  51. categories = [cat.lower() for cat in categories]
  52. for cat in categories:
  53. assert cat in self.categories
  54. self.categories = categories
  55. super(TOSCA, self).__init__(root, transform, pre_transform, pre_filter)
  56. self.data, self.slices = torch.load(self.processed_paths[0])
  57. @property
  58. def raw_file_names(self):
  59. return ['cat0.vert', 'cat0.tri']
  60. @property
  61. def processed_file_names(self):
  62. return '{}.pt'.format('_'.join([cat[:2] for cat in self.categories]))
  63. def download(self):
  64. path = download_url(self.url, self.raw_dir)
  65. extract_zip(path, self.raw_dir)
  66. os.unlink(path)
  67. def process(self):
  68. data_list = []
  69. for cat in self.categories:
  70. paths = glob.glob(osp.join(self.raw_dir, '{}*.tri'.format(cat)))
  71. paths = [path[:-4] for path in paths]
  72. paths = sorted(paths, key=lambda e: (len(e), e))
  73. for path in paths:
  74. pos = read_txt_array('{}.vert'.format(path))
  75. face = read_txt_array('{}.tri'.format(path), dtype=torch.long)
  76. data = Data(pos=pos, face=face.t().contiguous())
  77. if self.pre_filter is not None and not self.pre_filter(data):
  78. continue
  79. if self.pre_transform is not None:
  80. data = self.pre_transform(data)
  81. data_list.append(data)
  82. torch.save(self.collate(data_list), self.processed_paths[0])