/aten/src/ATen/test/type_test.cpp

https://github.com/ROCmSoftwarePlatform/pytorch · C++ · 209 lines · 167 code · 27 blank · 15 comment · 4 complexity · a5339ed426d255a52c05d2295b9c076c MD5 · raw file

  1. #include <ATen/ATen.h>
  2. #include <gtest/gtest.h>
  3. #include <torch/torch.h>
  4. #include <ATen/core/jit_type.h>
  5. #include <torch/csrc/jit/frontend/resolver.h>
  6. #include <torch/csrc/jit/serialization/import_source.h>
  7. namespace c10 {
  8. TEST(TypeCustomPrinter, Basic) {
  9. TypePrinter printer =
  10. [](const ConstTypePtr& t) -> c10::optional<std::string> {
  11. if (auto tensorType = t->cast<TensorType>()) {
  12. return "CustomTensor";
  13. }
  14. return c10::nullopt;
  15. };
  16. // Tensor types should be rewritten
  17. torch::Tensor iv = torch::rand({2, 3});
  18. const auto type = TensorType::create(iv);
  19. EXPECT_EQ(type->annotation_str(), "Tensor");
  20. EXPECT_EQ(type->annotation_str(printer), "CustomTensor");
  21. // Unrelated types shoudl not be affected
  22. const auto intType = IntType::get();
  23. EXPECT_EQ(intType->annotation_str(printer), intType->annotation_str());
  24. }
  25. TEST(TypeCustomPrinter, ContainedTypes) {
  26. TypePrinter printer =
  27. [](const ConstTypePtr& t) -> c10::optional<std::string> {
  28. if (auto tensorType = t->cast<TensorType>()) {
  29. return "CustomTensor";
  30. }
  31. return c10::nullopt;
  32. };
  33. torch::Tensor iv = torch::rand({2, 3});
  34. const auto type = TensorType::create(iv);
  35. // Contained types should work
  36. const auto tupleType = TupleType::create({type, IntType::get(), type});
  37. EXPECT_EQ(tupleType->annotation_str(), "Tuple[Tensor, int, Tensor]");
  38. EXPECT_EQ(
  39. tupleType->annotation_str(printer), "Tuple[CustomTensor, int, CustomTensor]");
  40. const auto dictType = DictType::create(IntType::get(), type);
  41. EXPECT_EQ(dictType->annotation_str(printer), "Dict[int, CustomTensor]");
  42. const auto listType = ListType::create(tupleType);
  43. EXPECT_EQ(
  44. listType->annotation_str(printer),
  45. "List[Tuple[CustomTensor, int, CustomTensor]]");
  46. }
  47. TEST(TypeCustomPrinter, NamedTuples) {
  48. TypePrinter printer =
  49. [](const ConstTypePtr& t) -> c10::optional<std::string> {
  50. if (auto tupleType = t->cast<TupleType>()) {
  51. // Rewrite only NamedTuples
  52. if (tupleType->name()) {
  53. return "Rewritten";
  54. }
  55. }
  56. return c10::nullopt;
  57. };
  58. torch::Tensor iv = torch::rand({2, 3});
  59. const auto type = TensorType::create(iv);
  60. std::vector<std::string> field_names = {"foo", "bar"};
  61. const auto namedTupleType = TupleType::createNamed(
  62. "my.named.tuple", field_names, {type, IntType::get()});
  63. EXPECT_EQ(namedTupleType->annotation_str(printer), "Rewritten");
  64. // Put it inside another tuple, should still work
  65. const auto outerTupleType = TupleType::create({IntType::get(), namedTupleType});
  66. EXPECT_EQ(outerTupleType->annotation_str(printer), "Tuple[int, Rewritten]");
  67. }
  68. static TypePtr importType(
  69. std::shared_ptr<CompilationUnit> cu,
  70. const std::string& qual_name,
  71. const std::string& src) {
  72. std::vector<at::IValue> constantTable;
  73. auto source = std::make_shared<torch::jit::Source>(src);
  74. torch::jit::SourceImporter si(
  75. cu,
  76. &constantTable,
  77. [&](const std::string& name) -> std::shared_ptr<torch::jit::Source> {
  78. return source;
  79. },
  80. /*version=*/2);
  81. return si.loadType(qual_name);
  82. }
  83. TEST(TypeEquality, ClassBasic) {
  84. // Even if classes have the same name across two compilation units, they
  85. // should not compare equal.
  86. auto cu = std::make_shared<CompilationUnit>();
  87. const auto src = R"JIT(
  88. class First:
  89. def one(self, x: Tensor, y: Tensor) -> Tensor:
  90. return x
  91. )JIT";
  92. auto classType = importType(cu, "__torch__.First", src);
  93. auto classType2 = cu->get_type("__torch__.First");
  94. // Trivially these should be equal
  95. EXPECT_EQ(*classType, *classType2);
  96. }
  97. TEST(TypeEquality, ClassInequality) {
  98. // Even if classes have the same name across two compilation units, they
  99. // should not compare equal.
  100. auto cu = std::make_shared<CompilationUnit>();
  101. const auto src = R"JIT(
  102. class First:
  103. def one(self, x: Tensor, y: Tensor) -> Tensor:
  104. return x
  105. )JIT";
  106. auto classType = importType(cu, "__torch__.First", src);
  107. auto cu2 = std::make_shared<CompilationUnit>();
  108. const auto src2 = R"JIT(
  109. class First:
  110. def one(self, x: Tensor, y: Tensor) -> Tensor:
  111. return y
  112. )JIT";
  113. auto classType2 = importType(cu2, "__torch__.First", src2);
  114. EXPECT_NE(*classType, *classType2);
  115. }
  116. TEST(TypeEquality, InterfaceEquality) {
  117. // Interfaces defined anywhere should compare equal, provided they share a
  118. // name and interface
  119. auto cu = std::make_shared<CompilationUnit>();
  120. const auto interfaceSrc = R"JIT(
  121. class OneForward(Interface):
  122. def one(self, x: Tensor, y: Tensor) -> Tensor:
  123. pass
  124. def forward(self, x: Tensor) -> Tensor:
  125. pass
  126. )JIT";
  127. auto interfaceType = importType(cu, "__torch__.OneForward", interfaceSrc);
  128. auto cu2 = std::make_shared<CompilationUnit>();
  129. auto interfaceType2 = importType(cu2, "__torch__.OneForward", interfaceSrc);
  130. EXPECT_EQ(*interfaceType, *interfaceType2);
  131. }
  132. TEST(TypeEquality, InterfaceInequality) {
  133. // Interfaces must match for them to compare equal, even if they share a name
  134. auto cu = std::make_shared<CompilationUnit>();
  135. const auto interfaceSrc = R"JIT(
  136. class OneForward(Interface):
  137. def one(self, x: Tensor, y: Tensor) -> Tensor:
  138. pass
  139. def forward(self, x: Tensor) -> Tensor:
  140. pass
  141. )JIT";
  142. auto interfaceType = importType(cu, "__torch__.OneForward", interfaceSrc);
  143. auto cu2 = std::make_shared<CompilationUnit>();
  144. const auto interfaceSrc2 = R"JIT(
  145. class OneForward(Interface):
  146. def two(self, x: Tensor, y: Tensor) -> Tensor:
  147. pass
  148. def forward(self, x: Tensor) -> Tensor:
  149. pass
  150. )JIT";
  151. auto interfaceType2 = importType(cu2, "__torch__.OneForward", interfaceSrc2);
  152. EXPECT_NE(*interfaceType, *interfaceType2);
  153. }
  154. TEST(TypeEquality, TupleEquality) {
  155. // Tuples should be structurally typed
  156. auto type = TupleType::create({IntType::get(), TensorType::get(), FloatType::get(), ComplexType::get()});
  157. auto type2 = TupleType::create({IntType::get(), TensorType::get(), FloatType::get(), ComplexType::get()});
  158. EXPECT_EQ(*type, *type2);
  159. }
  160. TEST(TypeEquality, NamedTupleEquality) {
  161. // Named tuples should compare equal if they share a name and field names
  162. auto type = TupleType::createNamed(
  163. "MyNamedTuple",
  164. {"a", "b", "c", "d"},
  165. {IntType::get(), TensorType::get(), FloatType::get(), ComplexType::get()});
  166. auto type2 = TupleType::createNamed(
  167. "MyNamedTuple",
  168. {"a", "b", "c", "d"},
  169. {IntType::get(), TensorType::get(), FloatType::get(), ComplexType::get()});
  170. EXPECT_EQ(*type, *type2);
  171. auto differentName = TupleType::createNamed(
  172. "WowSoDifferent",
  173. {"a", "b", "c", "d"},
  174. {IntType::get(), TensorType::get(), FloatType::get(), ComplexType::get()});
  175. EXPECT_NE(*type, *differentName);
  176. auto differentField = TupleType::createNamed(
  177. "MyNamedTuple",
  178. {"wow", "so", "very", "different"},
  179. {IntType::get(), TensorType::get(), FloatType::get(), ComplexType::get()});
  180. EXPECT_NE(*type, *differentField);
  181. }
  182. } // namespace c10