/torch_geometric/datasets/tosca.py
Python | 97 lines | 85 code | 12 blank | 0 comment | 11 complexity | c0086af9e3a93a16cca4e698fadb7c8b MD5 | raw file
- import os
- import os.path as osp
- import glob
- import torch
- from torch_geometric.data import (InMemoryDataset, Data, download_url,
- extract_zip)
- from torch_geometric.io import read_txt_array
class TOSCA(InMemoryDataset):
    r"""The TOSCA dataset from the `"Numerical Geometry of Non-Rigid Shapes"
    <https://www.amazon.com/Numerical-Geometry-Non-Rigid-Monographs-Computer/
    dp/0387733000>`_ book, containing 80 meshes.
    Meshes within the same category have the same triangulation and an equal
    number of vertices numbered in a compatible way.

    .. note::

        Data objects hold mesh faces instead of edge indices.
        To convert the mesh to a graph, use the
        :obj:`torch_geometric.transforms.FaceToEdge` as :obj:`pre_transform`.
        To convert the mesh to a point cloud, use the
        :obj:`torch_geometric.transforms.SamplePoints` as :obj:`transform` to
        sample a fixed number of points on the mesh faces according to their
        face area.

    Args:
        root (string): Root directory where the dataset should be saved.
        categories (list, optional): List of categories to include in the
            dataset. Can include the categories :obj:`"Cat"`, :obj:`"Centaur"`,
            :obj:`"David"`, :obj:`"Dog"`, :obj:`"Gorilla"`, :obj:`"Horse"`,
            :obj:`"Michael"`, :obj:`"Victoria"`, :obj:`"Wolf"`. Names are
            matched case-insensitively. If set to :obj:`None`, the dataset
            will contain all categories. (default: :obj:`None`)
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
    """

    url = 'http://tosca.cs.technion.ac.il/data/toscahires-asci.zip'

    # All known category names (lower-case). An instance shadows this class
    # attribute with the subset of categories it was asked to load.
    categories = [
        'cat', 'centaur', 'david', 'dog', 'gorilla', 'horse', 'michael',
        'victoria', 'wolf'
    ]

    def __init__(self, root, categories=None, transform=None,
                 pre_transform=None, pre_filter=None):
        categories = self.categories if categories is None else categories
        categories = [cat.lower() for cat in categories]
        # `self.categories` still resolves to the class-level list here, so
        # this validates the request against every known category.
        for cat in categories:
            assert cat in self.categories, (
                'Unknown TOSCA category: {!r}'.format(cat))
        # Must be set before the base-class constructor runs, because
        # `processed_file_names` and `process` read it.
        self.categories = categories
        super(TOSCA, self).__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Names of the raw files reported to the base class."""
        return ['cat0.vert', 'cat0.tri']

    @property
    def processed_file_names(self):
        """Processed file name derived from the selected categories.

        The name joins the first two letters of every selected category with
        underscores, so different category selections map to different files.
        """
        return '{}.pt'.format('_'.join([cat[:2] for cat in self.categories]))

    def download(self):
        """Download the archive into the raw directory, extract it there and
        delete the downloaded archive afterwards."""
        path = download_url(self.url, self.raw_dir)
        extract_zip(path, self.raw_dir)
        os.unlink(path)

    def process(self):
        """Read every mesh of the selected categories, apply
        :obj:`pre_filter` and :obj:`pre_transform`, and save the collated
        result to the processed file."""
        data_list = []
        for cat in self.categories:
            paths = glob.glob(osp.join(self.raw_dir, '{}*.tri'.format(cat)))
            # Strip the ".tri" suffix to obtain the per-mesh path stem.
            paths = [path[:-4] for path in paths]
            # Shorter stems first, then lexicographic, so that e.g. "cat2"
            # is ordered before "cat10".
            paths = sorted(paths, key=lambda e: (len(e), e))

            for path in paths:
                pos = read_txt_array('{}.vert'.format(path))
                face = read_txt_array('{}.tri'.format(path), dtype=torch.long)
                # Ensure zero-based vertex indices. NOTE(review): the raw
                # ".tri" files are expected to be one-based; without the shift
                # faces would reference the wrong rows of `pos` -- confirm
                # against the raw data.
                face = face - face.min()
                data = Data(pos=pos, face=face.t().contiguous())
                if self.pre_filter is not None and not self.pre_filter(data):
                    continue
                if self.pre_transform is not None:
                    data = self.pre_transform(data)
                data_list.append(data)

        torch.save(self.collate(data_list), self.processed_paths[0])