/test/test_image.py

https://github.com/pytorch/vision
Python | 159 lines | 123 code | 32 blank | 4 comment | 30 complexity | c5e6ea83b5b3fb3d2df30e9eb921ea5a MD5 | raw file
  1. import os
  2. import io
  3. import glob
  4. import unittest
  5. import sys
  6. import torch
  7. import torchvision
  8. from PIL import Image
  9. from torchvision.io.image import (
  10. read_png, decode_png, read_jpeg, decode_jpeg, encode_jpeg, write_jpeg)
  11. import numpy as np
  12. IMAGE_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
  13. IMAGE_DIR = os.path.join(IMAGE_ROOT, "fakedata", "imagefolder")
  14. DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg')
  15. def get_images(directory, img_ext):
  16. assert os.path.isdir(directory)
  17. for root, _, files in os.walk(directory):
  18. if os.path.basename(root) in {'damaged_jpeg', 'jpeg_write'}:
  19. continue
  20. for fl in files:
  21. _, ext = os.path.splitext(fl)
  22. if ext == img_ext:
  23. yield os.path.join(root, fl)
  24. class ImageTester(unittest.TestCase):
  25. def test_read_jpeg(self):
  26. for img_path in get_images(IMAGE_ROOT, ".jpg"):
  27. img_pil = torch.load(img_path.replace('jpg', 'pth'))
  28. img_pil = img_pil.permute(2, 0, 1)
  29. img_ljpeg = read_jpeg(img_path)
  30. self.assertTrue(img_ljpeg.equal(img_pil))
  31. def test_decode_jpeg(self):
  32. for img_path in get_images(IMAGE_ROOT, ".jpg"):
  33. img_pil = torch.load(img_path.replace('jpg', 'pth'))
  34. img_pil = img_pil.permute(2, 0, 1)
  35. size = os.path.getsize(img_path)
  36. img_ljpeg = decode_jpeg(torch.from_file(img_path, dtype=torch.uint8, size=size))
  37. self.assertTrue(img_ljpeg.equal(img_pil))
  38. with self.assertRaisesRegex(ValueError, "Expected a non empty 1-dimensional tensor."):
  39. decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))
  40. with self.assertRaisesRegex(ValueError, "Expected a torch.uint8 tensor."):
  41. decode_jpeg(torch.empty((100, ), dtype=torch.float16))
  42. with self.assertRaises(RuntimeError):
  43. decode_jpeg(torch.empty((100), dtype=torch.uint8))
  44. def test_damaged_images(self):
  45. # Test image with bad Huffman encoding (should not raise)
  46. bad_huff = os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg')
  47. try:
  48. _ = read_jpeg(bad_huff)
  49. except RuntimeError:
  50. self.assertTrue(False)
  51. # Truncated images should raise an exception
  52. truncated_images = glob.glob(
  53. os.path.join(DAMAGED_JPEG, 'corrupt*.jpg'))
  54. for image_path in truncated_images:
  55. with self.assertRaises(RuntimeError):
  56. read_jpeg(image_path)
  57. def test_encode_jpeg(self):
  58. for img_path in get_images(IMAGE_ROOT, ".jpg"):
  59. dirname = os.path.dirname(img_path)
  60. filename, _ = os.path.splitext(os.path.basename(img_path))
  61. write_folder = os.path.join(dirname, 'jpeg_write')
  62. expected_file = os.path.join(
  63. write_folder, '{0}_pil.jpg'.format(filename))
  64. img = read_jpeg(img_path)
  65. with open(expected_file, 'rb') as f:
  66. pil_bytes = f.read()
  67. pil_bytes = torch.as_tensor(list(pil_bytes), dtype=torch.uint8)
  68. for src_img in [img, img.contiguous()]:
  69. # PIL sets jpeg quality to 75 by default
  70. jpeg_bytes = encode_jpeg(src_img, quality=75)
  71. self.assertTrue(jpeg_bytes.equal(pil_bytes))
  72. with self.assertRaisesRegex(
  73. RuntimeError, "Input tensor dtype should be uint8"):
  74. encode_jpeg(torch.empty((3, 100, 100), dtype=torch.float32))
  75. with self.assertRaisesRegex(
  76. ValueError, "Image quality should be a positive number "
  77. "between 1 and 100"):
  78. encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=-1)
  79. with self.assertRaisesRegex(
  80. ValueError, "Image quality should be a positive number "
  81. "between 1 and 100"):
  82. encode_jpeg(torch.empty((3, 100, 100), dtype=torch.uint8), quality=101)
  83. with self.assertRaisesRegex(
  84. RuntimeError, "The number of channels should be 1 or 3, got: 5"):
  85. encode_jpeg(torch.empty((5, 100, 100), dtype=torch.uint8))
  86. with self.assertRaisesRegex(
  87. RuntimeError, "Input data should be a 3-dimensional tensor"):
  88. encode_jpeg(torch.empty((1, 3, 100, 100), dtype=torch.uint8))
  89. with self.assertRaisesRegex(
  90. RuntimeError, "Input data should be a 3-dimensional tensor"):
  91. encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))
  92. def test_write_jpeg(self):
  93. for img_path in get_images(IMAGE_ROOT, ".jpg"):
  94. img = read_jpeg(img_path)
  95. basedir = os.path.dirname(img_path)
  96. filename, _ = os.path.splitext(os.path.basename(img_path))
  97. torch_jpeg = os.path.join(
  98. basedir, '{0}_torch.jpg'.format(filename))
  99. pil_jpeg = os.path.join(
  100. basedir, 'jpeg_write', '{0}_pil.jpg'.format(filename))
  101. write_jpeg(img, torch_jpeg, quality=75)
  102. with open(torch_jpeg, 'rb') as f:
  103. torch_bytes = f.read()
  104. with open(pil_jpeg, 'rb') as f:
  105. pil_bytes = f.read()
  106. os.remove(torch_jpeg)
  107. self.assertEqual(torch_bytes, pil_bytes)
  108. def test_read_png(self):
  109. # Check across .png
  110. for img_path in get_images(IMAGE_DIR, ".png"):
  111. img_pil = torch.from_numpy(np.array(Image.open(img_path)))
  112. img_pil = img_pil.permute(2, 0, 1)
  113. img_lpng = read_png(img_path)
  114. self.assertTrue(img_lpng.equal(img_pil))
  115. def test_decode_png(self):
  116. for img_path in get_images(IMAGE_DIR, ".png"):
  117. img_pil = torch.from_numpy(np.array(Image.open(img_path)))
  118. img_pil = img_pil.permute(2, 0, 1)
  119. size = os.path.getsize(img_path)
  120. img_lpng = decode_png(torch.from_file(img_path, dtype=torch.uint8, size=size))
  121. self.assertTrue(img_lpng.equal(img_pil))
  122. with self.assertRaises(ValueError):
  123. decode_png(torch.empty((), dtype=torch.uint8))
  124. with self.assertRaises(RuntimeError):
  125. decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
  126. if __name__ == '__main__':
  127. unittest.main()