/workspace/data/sample-modelnet40.py

https://github.com/AlexGeControl/3D-Point-Cloud-Analytics
Python | 82 lines | 54 code | 13 blank | 15 comment | 5 complexity | f0625bdcc5183599e1e400ba5004c5f1 MD5 | raw file
  1. #!/opt/conda/envs/point-cloud/bin/python
  2. import argparse
  3. import os
  4. import pathlib
  5. import shutil
  6. import glob
  7. from random import sample
  8. def get_arguments():
  9. """ gets command line arguments.
  10. :return:
  11. """
  12. # init parser:
  13. parser = argparse.ArgumentParser("Downsample ModelNet40 by category.")
  14. # add required and optional groups:
  15. required = parser.add_argument_group('Required')
  16. optional = parser.add_argument_group('Optional')
  17. # add required:
  18. required.add_argument(
  19. "-n", dest="num_per_category", help="The number of samples per category.",
  20. required=True, type=int
  21. )
  22. # add optional:
  23. optional.add_argument(
  24. "-i", dest="input", help="Input path of original ModelNet 40 dataset. Defaults to $PWD/ModelNet40",
  25. default='./ModelNet40'
  26. )
  27. optional.add_argument(
  28. "-t", dest="type", help="Which subset to dowmsample from. Defaults to train",
  29. default='train'
  30. )
  31. optional.add_argument(
  32. "-o", dest="output", help="Output path of downsampled ModelNet 40 dataset. Defaults to $PWD/ModelNet40Downsampled",
  33. default='./ModelNet40Downsampled'
  34. )
  35. # parse arguments:
  36. return parser.parse_args()
  37. if __name__ == '__main__':
  38. # parse arguments:
  39. arguments = get_arguments()
  40. print(f'Downsample ModelNet40 -- {arguments.num_per_category} per category: Start ...')
  41. # create output dir:
  42. try:
  43. pathlib.Path(arguments.output).mkdir(parents=True, exist_ok=False)
  44. except FileExistsError:
  45. shutil.rmtree(arguments.output)
  46. pathlib.Path(
  47. os.path.join(arguments.output, 'off')
  48. ).mkdir(parents=True, exist_ok=False)
  49. pathlib.Path(
  50. os.path.join(arguments.output, 'ply')
  51. ).mkdir(parents=True, exist_ok=False)
  52. # sample from input dir:
  53. for category in os.listdir(arguments.input):
  54. # identify dataset root dir of given category:
  55. input_dir = os.path.join(arguments.input, category, arguments.type)
  56. # find all *.off and downsample:
  57. pattern = os.path.join(input_dir, '*.off')
  58. samples = sample(glob.glob(pattern), arguments.num_per_category)
  59. # move the samples into output dir:
  60. print(f'\t{category}')
  61. for src in samples:
  62. dst = os.path.join(
  63. arguments.output,
  64. 'off',
  65. os.path.basename(src)
  66. )
  67. shutil.copyfile(src, dst)
  68. print(f'\t\t{src} --> {dst}')
  69. print(f'Downsample ModelNet40 -- {arguments.num_per_category} per category: Done.')