/data.py

https://github.com/iwtw/pytorch-TP-GAN
Python | 114 lines | 91 code | 17 blank | 6 comment | 5 complexity | 4a72d36e830106586234fb52105fa34d MD5 | raw file
  1. from torch.utils.data import Dataset
  2. import torchvision.transforms as transforms
  3. from PIL import Image
  4. import numpy as np
  5. from math import floor
  6. from utils import landmarks_68_to_5
  7. def process(img , landmarks_5pts):
  8. batch = {}
  9. name = ['left_eye','right_eye','nose','mouth']
  10. patch_size = {
  11. 'left_eye':(40,40),
  12. 'right_eye':(40,40),
  13. 'nose':(40,32),
  14. 'mouth':(48,32),
  15. }
  16. landmarks_5pts[3,0] = (landmarks_5pts[3,0] + landmarks_5pts[4,0]) / 2.0
  17. landmarks_5pts[3,1] = (landmarks_5pts[3,1] + landmarks_5pts[4,1]) / 2.0
  18. # crops
  19. for i in range(4):
  20. x = floor(landmarks_5pts[i,0])
  21. y = floor(landmarks_5pts[i,1])
  22. batch[ name[i] ] = img.crop( (x - patch_size[ name[i] ][0]//2 + 1 , y - patch_size[ name[i] ][1]//2 + 1 , x + patch_size[ name[i] ][0]//2 + 1 , y + patch_size[ name[i] ][1]//2 + 1 ) )
  23. return batch
  24. class TrainDataset( Dataset):
  25. def __init__( self , img_list ):
  26. super(type(self),self).__init__()
  27. self.img_list = img_list
  28. def __len__( self ):
  29. return len(self.img_list)
  30. def __getitem__( self , idx ):
  31. #filename processing
  32. batch = {}
  33. img_name = self.img_list[idx].split('/')
  34. img_frontal_name = self.img_list[idx].split('_')
  35. img_frontal_name[-2] = '051'
  36. img_frontal_name = '_'.join( img_frontal_name ).split('/')
  37. batch['img'] = Image.open( '/'.join( img_name ) )
  38. batch['img32'] = Image.open( '/'.join( img_name[:-2] + ['32x32' , img_name[-1] ] ) )
  39. batch['img64'] = Image.open( '/'.join( img_name[:-2] + ['64x64' , img_name[-1] ] ) )
  40. batch['img_frontal'] = Image.open( '/'.join(img_frontal_name) )
  41. batch['img32_frontal'] = Image.open( '/'.join( img_frontal_name[:-2] + ['32x32' , img_frontal_name[-1] ] ) )
  42. batch['img64_frontal'] = Image.open( '/'.join( img_frontal_name[:-2] + ['64x64' , img_frontal_name[-1] ] ) )
  43. patch_name_list = ['left_eye','right_eye','nose','mouth']
  44. for patch_name in patch_name_list:
  45. batch[patch_name] = Image.open( '/'.join(img_name[:-2] + ['patch' , patch_name , img_name[-1] ]) )
  46. batch[patch_name+'_frontal'] = Image.open( '/'.join(img_frontal_name[:-2] + ['patch' , patch_name , img_frontal_name[-1] ]) )
  47. totensor = transforms.ToTensor()
  48. for k in batch:
  49. batch[k] = totensor( batch[k] )
  50. batch[k] = batch[k] *2.0 -1.0
  51. #if batch[k].max() <= 0.9:
  52. # print( "{} {} {}".format( batch[k].max(), self.img_list[idx] , k ))
  53. #if batch[k].min() >= -0.9:
  54. # print( "{} {} {}".format( batch[k].min() , self.img_list[idx] , k ) )
  55. batch['label'] = int( self.img_list[idx].split('/')[-1].split('_')[0] )
  56. return batch
  57. class PretrainDataset( Dataset):
  58. def __init__( self , img_list ):
  59. super(type(self),self).__init__()
  60. self.img_list = img_list
  61. def __len__( self):
  62. return len(self.img_list)
  63. def __getitem__(self,idx):
  64. batch = {}
  65. totensor = transforms.ToTensor()
  66. img = Image.open( self.img_list[idx] )
  67. img = totensor( img )
  68. img = img*2.0 - 1.0
  69. batch['img'] = img
  70. batch['label'] = int( self.img_list[idx].split('/')[-1].split('_')[0] )
  71. return batch
  72. class TestDataset( Dataset):
  73. def __init__( self , img_list , lm_list):
  74. super(type(self),self).__init__()
  75. self.img_list = img_list
  76. self.lm_list = lm_list
  77. assert len(img_list) == len(lm_list)
  78. def __len__(self):
  79. return len(self.img_list)
  80. def __getitem__(self,idx):
  81. img_name = self.img_list[idx]
  82. img = Image.open( img_name )
  83. lm = np.array( self.lm_list[idx].split(' ') , np.float32 ).reshape(-1,2)
  84. lm = landmarks_68_to_5( lm )
  85. for i in range(5):
  86. lm[i][0] *= 128/img.width
  87. lm[i][1] *= 128/img.height
  88. img = img.resize( (128,128) , Image.LANCZOS)
  89. batch = process( img , lm )
  90. batch['img'] = img
  91. batch['img64'] = img.resize( (64,64) , Image.LANCZOS )
  92. batch['img32'] = batch['img64'].resize( (32,32) , Image.LANCZOS )
  93. to_tensor = transforms.ToTensor()
  94. for k in batch:
  95. batch[k] = to_tensor( batch[k] )
  96. batch[k] = batch[k] * 2.0 - 1.0
  97. return batch