# test.py
# Source: https://github.com/hbutsuak95/monolayout
# (scraped copy: Python, 158 lines; original listing included code-host
# stats and an MD5 checksum, preserved here as a comment so the file parses)
  1. import argparse
  2. import glob
  3. import os
  4. import PIL.Image as pil
  5. import cv2
  6. from monolayout import model
  7. import numpy as np
  8. import torch
  9. from torchvision import transforms
  10. def get_args():
  11. parser = argparse.ArgumentParser(
  12. description="Testing arguments for MonoLayout")
  13. parser.add_argument("--image_path", type=str,
  14. help="path to folder of images", required=True)
  15. parser.add_argument("--model_path", type=str,
  16. help="path to MonoLayout model", required=True)
  17. parser.add_argument(
  18. "--ext",
  19. type=str,
  20. default="png",
  21. help="extension of images in the folder")
  22. parser.add_argument("--out_dir", type=str,
  23. default="output directory to save topviews")
  24. parser.add_argument("--type", type=str,
  25. default="static/dynamic/both")
  26. return parser.parse_args()
  27. def save_topview(idx, tv, name_dest_im):
  28. tv_np = tv.squeeze().cpu().numpy()
  29. true_top_view = np.zeros((tv_np.shape[1], tv_np.shape[2]))
  30. true_top_view[tv_np[1] > tv_np[0]] = 255
  31. dir_name = os.path.dirname(name_dest_im)
  32. if not os.path.exists(dir_name):
  33. os.makedirs(dir_name)
  34. cv2.imwrite(name_dest_im, true_top_view)
  35. print("Saved prediction to {}".format(name_dest_im))
  36. def test(args):
  37. models = {}
  38. device = torch.device("cuda")
  39. encoder_path = os.path.join(args.model_path, "encoder.pth")
  40. encoder_dict = torch.load(encoder_path, map_location=device)
  41. feed_height = encoder_dict["height"]
  42. feed_width = encoder_dict["width"]
  43. models["encoder"] = model.Encoder(18, feed_width, feed_height, False)
  44. filtered_dict_enc = {
  45. k: v for k,
  46. v in encoder_dict.items() if k in models["encoder"].state_dict()}
  47. models["encoder"].load_state_dict(filtered_dict_enc)
  48. if args.type == "both":
  49. static_decoder_path = os.path.join(
  50. args.model_path, "static_decoder.pth")
  51. dynamic_decoder_path = os.path.join(
  52. args.model_path, "dynamic_decoder.pth")
  53. models["static_decoder"] = model.Decoder(
  54. models["encoder"].resnet_encoder.num_ch_enc)
  55. models["static_decoder"].load_state_dict(
  56. torch.load(static_decoder_path, map_location=device))
  57. models["dynamic_decoder"] = model.Decoder(
  58. models["encoder"].resnet_encoder.num_ch_enc)
  59. models["dynamic_decoder"].load_state_dict(
  60. torch.load(dynamic_decoder_path, map_location=device))
  61. else:
  62. decoder_path = os.path.join(args.model_path, "decoder.pth")
  63. models["decoder"] = model.Decoder(
  64. models["encoder"].resnet_encoder.num_ch_enc)
  65. models["decoder"].load_state_dict(
  66. torch.load(decoder_path, map_location=device))
  67. for key in models.keys():
  68. models[key].to(device)
  69. models[key].eval()
  70. if os.path.isfile(args.image_path):
  71. # Only testing on a single image
  72. paths = [args.image_path]
  73. output_directory = os.path.dirname(args.image_path)
  74. elif os.path.isdir(args.image_path):
  75. # Searching folder for images
  76. paths = glob.glob(os.path.join(
  77. args.image_path, '*.{}'.format(args.ext)))
  78. output_directory = args.out_dir
  79. try:
  80. os.mkdir(output_directory)
  81. except BaseException:
  82. pass
  83. else:
  84. raise Exception(
  85. "Can not find args.image_path: {}".format(
  86. args.image_path))
  87. print("-> Predicting on {:d} test images".format(len(paths)))
  88. # PREDICTING ON EACH IMAGE IN TURN
  89. with torch.no_grad():
  90. for idx, image_path in enumerate(paths):
  91. # Load image and preprocess
  92. input_image = pil.open(image_path).convert('RGB')
  93. original_width, original_height = input_image.size
  94. input_image = input_image.resize(
  95. (feed_width, feed_height), pil.LANCZOS)
  96. input_image = transforms.ToTensor()(input_image).unsqueeze(0)
  97. # PREDICTION
  98. input_image = input_image.to(device)
  99. features = models["encoder"](input_image)
  100. output_name = os.path.splitext(os.path.basename(image_path))[0]
  101. print(
  102. "Processing {:d} of {:d} images- ".format(idx + 1, len(paths)))
  103. if args.type == "both":
  104. static_tv = models["static_decoder"](
  105. features, is_training=False)
  106. dynamic_tv = models["dynamic_decoder"](
  107. features, is_training=False)
  108. save_topview(
  109. idx,
  110. static_tv,
  111. os.path.join(
  112. args.out_dir,
  113. "static",
  114. "{}.png".format(output_name)))
  115. save_topview(
  116. idx,
  117. dynamic_tv,
  118. os.path.join(
  119. args.out_dir,
  120. "dynamic",
  121. "{}.png".format(output_name)))
  122. else:
  123. tv = models["decoder"](features, is_training=False)
  124. save_topview(
  125. idx,
  126. tv,
  127. os.path.join(
  128. args.out_dir,
  129. args.type,
  130. "{}.png".format(output_name)))
  131. print('-> Done!')
  132. if __name__ == "__main__":
  133. args = get_args()
  134. test(args)