/tools/tf2openapi/types/tfsignaturedef_test.go

https://github.com/kubeflow/kfserving · Go · 534 lines · 515 code · 16 blank · 3 comment · 1 complexity · 8bfa05291ca191f11318fcb4b11cc11a MD5 · raw file

  1. package types
  2. import (
  3. "fmt"
  4. "testing"
  5. "github.com/getkin/kin-openapi/openapi3"
  6. "github.com/onsi/gomega"
  7. "github.com/kubeflow/kfserving/pkg/utils"
  8. "github.com/kubeflow/kfserving/tools/tf2openapi/generated/framework"
  9. pb "github.com/kubeflow/kfserving/tools/tf2openapi/generated/protobuf"
  10. )
  11. /* Expected values */
  12. func expectedTFSignatureDef() TFSignatureDef {
  13. return TFSignatureDef{
  14. Key: "Signature Def Key",
  15. Method: Predict,
  16. Inputs: []TFTensor{
  17. {
  18. Name: "input",
  19. DType: DtInt8,
  20. Shape: TFShape{-1, 3},
  21. Rank: 2,
  22. },
  23. },
  24. Outputs: outputTensors(),
  25. }
  26. }
  27. func outputTensors() []TFTensor {
  28. return []TFTensor{
  29. {
  30. Name: "output",
  31. DType: DtInt8,
  32. Shape: TFShape{-1, 3},
  33. Rank: 2,
  34. },
  35. }
  36. }
  37. // corresponding response schemas in row and col fmt for SignatureDef above
  38. func expectedResponseSchemaRowFmt() *openapi3.Schema {
  39. return &openapi3.Schema{
  40. Type: "object",
  41. Properties: map[string]*openapi3.SchemaRef{
  42. "predictions": {
  43. Value: &openapi3.Schema{
  44. Type: "array",
  45. Items: &openapi3.SchemaRef{
  46. Value: &openapi3.Schema{
  47. Type: "array",
  48. MaxItems: utils.UInt64(3),
  49. MinItems: 3,
  50. Items: &openapi3.SchemaRef{
  51. Value: &openapi3.Schema{
  52. Type: "number",
  53. },
  54. },
  55. },
  56. },
  57. },
  58. },
  59. },
  60. Required: []string{"predictions"},
  61. AdditionalPropertiesAllowed: utils.Bool(false),
  62. }
  63. }
  64. func expectedResponseSchemaColFmt() *openapi3.Schema {
  65. return &openapi3.Schema{
  66. Type: "object",
  67. Properties: map[string]*openapi3.SchemaRef{
  68. "outputs": {
  69. Value: &openapi3.Schema{
  70. Type: "array",
  71. Items: &openapi3.SchemaRef{
  72. Value: &openapi3.Schema{
  73. Type: "array",
  74. MaxItems: utils.UInt64(3),
  75. MinItems: 3,
  76. Items: &openapi3.SchemaRef{
  77. Value: &openapi3.Schema{
  78. Type: "number",
  79. },
  80. },
  81. },
  82. },
  83. },
  84. },
  85. },
  86. Required: []string{"outputs"},
  87. AdditionalPropertiesAllowed: utils.Bool(false),
  88. }
  89. }
  90. /* Fake protobuf structs to use as test inputs */
  91. func goodTensorsPb(name string) map[string]*pb.TensorInfo {
  92. return map[string]*pb.TensorInfo{
  93. name: {
  94. Dtype: framework.DataType_DT_INT8,
  95. TensorShape: &framework.TensorShapeProto{
  96. Dim: []*framework.TensorShapeProto_Dim{
  97. {Size: -1},
  98. {Size: 3},
  99. },
  100. UnknownRank: false,
  101. },
  102. },
  103. }
  104. }
  105. func badTensorsPb(name string) map[string]*pb.TensorInfo {
  106. return map[string]*pb.TensorInfo{
  107. name: {
  108. Dtype: framework.DataType_DT_COMPLEX128,
  109. TensorShape: &framework.TensorShapeProto{
  110. Dim: []*framework.TensorShapeProto_Dim{
  111. {Size: -1},
  112. {Size: 3},
  113. },
  114. UnknownRank: false,
  115. },
  116. },
  117. }
  118. }
  119. func TestCreateTFSignatureDefTypical(t *testing.T) {
  120. g := gomega.NewGomegaWithT(t)
  121. tfSignatureDef, err := NewTFSignatureDef("Signature Def Key", "tensorflow/serving/predict",
  122. goodTensorsPb("input"),
  123. goodTensorsPb("output"))
  124. expectedSignatureDef := expectedTFSignatureDef()
  125. g.Expect(tfSignatureDef).Should(gomega.Equal(expectedSignatureDef))
  126. g.Expect(err).Should(gomega.BeNil())
  127. }
  128. func TestCreateTFSignatureDefWithErrInputs(t *testing.T) {
  129. g := gomega.NewGomegaWithT(t)
  130. inputTensors := badTensorsPb("input")
  131. outputTensors := goodTensorsPb("output")
  132. _, err := NewTFSignatureDef("Signature Def Key", "tensorflow/serving/predict", inputTensors, outputTensors)
  133. expectedErr := fmt.Sprintf(UnsupportedDataTypeError, "input", "DT_COMPLEX128")
  134. g.Expect(err).Should(gomega.MatchError(expectedErr))
  135. }
  136. func TestCreateTFSignatureDefWithErrOutputs(t *testing.T) {
  137. g := gomega.NewGomegaWithT(t)
  138. inputTensors := goodTensorsPb("input")
  139. outputTensors := badTensorsPb("output")
  140. _, err := NewTFSignatureDef("Signature Def Key", "tensorflow/serving/predict", inputTensors, outputTensors)
  141. expectedErr := fmt.Sprintf(UnsupportedDataTypeError, "output", "DT_COMPLEX128")
  142. g.Expect(err).Should(gomega.MatchError(expectedErr))
  143. }
  144. func TestCreateTFSignatureDefWithErrMethod(t *testing.T) {
  145. g := gomega.NewGomegaWithT(t)
  146. inputTensors := goodTensorsPb("input")
  147. outputTensors := goodTensorsPb("output")
  148. _, err := NewTFSignatureDef("Signature Def Key", "tensorflow/serving/bad", inputTensors, outputTensors)
  149. expectedErr := fmt.Sprintf(UnsupportedSignatureMethodError, "Signature Def Key", "tensorflow/serving/bad")
  150. g.Expect(err).Should(gomega.MatchError(expectedErr))
  151. }
  152. func TestTFSignatureDefVariousFmt(t *testing.T) {
  153. g := gomega.NewGomegaWithT(t)
  154. scenarios := map[string]struct {
  155. tfSigDef TFSignatureDef
  156. expectedRequestSchema *openapi3.Schema
  157. expectedResponseSchema *openapi3.Schema
  158. }{
  159. "RowSchemaMultipleTensors": {
  160. tfSigDef: TFSignatureDef{
  161. Key: "Signature Def Key",
  162. Inputs: []TFTensor{
  163. {
  164. Name: "signal",
  165. DType: DtInt8,
  166. Shape: TFShape{-1, 5},
  167. Rank: 2,
  168. },
  169. {
  170. Name: "sensor",
  171. DType: DtInt8,
  172. Shape: TFShape{-1, 2, 2},
  173. Rank: 3,
  174. },
  175. },
  176. Outputs: outputTensors(),
  177. },
  178. expectedRequestSchema: &openapi3.Schema{
  179. Type: "object",
  180. Properties: map[string]*openapi3.SchemaRef{
  181. "instances": {
  182. Value: &openapi3.Schema{
  183. Type: "array",
  184. Items: &openapi3.SchemaRef{
  185. Value: &openapi3.Schema{
  186. Type: "object",
  187. Properties: map[string]*openapi3.SchemaRef{
  188. "signal": {
  189. Value: &openapi3.Schema{
  190. Type: "array",
  191. MaxItems: utils.UInt64(5),
  192. MinItems: 5,
  193. Items: &openapi3.SchemaRef{
  194. Value: &openapi3.Schema{
  195. Type: "number",
  196. },
  197. },
  198. },
  199. },
  200. "sensor": {
  201. Value: &openapi3.Schema{
  202. Type: "array",
  203. MaxItems: utils.UInt64(2),
  204. MinItems: 2,
  205. Items: &openapi3.SchemaRef{
  206. Value: &openapi3.Schema{
  207. Type: "array",
  208. MaxItems: utils.UInt64(2),
  209. MinItems: 2,
  210. Items: &openapi3.SchemaRef{
  211. Value: &openapi3.Schema{
  212. Type: "number",
  213. },
  214. },
  215. },
  216. },
  217. },
  218. },
  219. },
  220. Required: []string{"signal", "sensor"},
  221. AdditionalPropertiesAllowed: utils.Bool(false),
  222. },
  223. },
  224. },
  225. },
  226. },
  227. Required: []string{"instances"},
  228. AdditionalPropertiesAllowed: utils.Bool(false),
  229. },
  230. expectedResponseSchema: expectedResponseSchemaRowFmt(),
  231. },
  232. "RowSchemaSingleTensor": {
  233. tfSigDef: expectedTFSignatureDef(),
  234. expectedRequestSchema: &openapi3.Schema{
  235. Type: "object",
  236. Properties: map[string]*openapi3.SchemaRef{
  237. "instances": {
  238. Value: &openapi3.Schema{
  239. Type: "array",
  240. Items: &openapi3.SchemaRef{
  241. Value: &openapi3.Schema{
  242. Type: "array",
  243. MaxItems: utils.UInt64(3),
  244. MinItems: 3,
  245. Items: &openapi3.SchemaRef{
  246. Value: &openapi3.Schema{
  247. Type: "number",
  248. },
  249. },
  250. },
  251. },
  252. },
  253. },
  254. },
  255. Required: []string{"instances"},
  256. AdditionalPropertiesAllowed: utils.Bool(false),
  257. },
  258. expectedResponseSchema: expectedResponseSchemaRowFmt(),
  259. },
  260. "ColSchemaMultipleTensors": {
  261. tfSigDef: TFSignatureDef{
  262. Key: "Signature Def Key",
  263. Inputs: []TFTensor{
  264. {
  265. Name: "signal",
  266. DType: DtInt8,
  267. Shape: TFShape{2, 5},
  268. Rank: 2,
  269. },
  270. {
  271. Name: "sensor",
  272. DType: DtInt8,
  273. Shape: TFShape{2, 2, 2},
  274. Rank: 3,
  275. },
  276. },
  277. Outputs: outputTensors(),
  278. },
  279. expectedRequestSchema: &openapi3.Schema{
  280. Type: "object",
  281. Properties: map[string]*openapi3.SchemaRef{
  282. "inputs": {
  283. Value: &openapi3.Schema{
  284. Type: "object",
  285. Properties: map[string]*openapi3.SchemaRef{
  286. "signal": {
  287. Value: &openapi3.Schema{
  288. Type: "array",
  289. MaxItems: utils.UInt64(2),
  290. MinItems: 2,
  291. Items: &openapi3.SchemaRef{
  292. Value: &openapi3.Schema{
  293. Type: "array",
  294. MaxItems: utils.UInt64(5),
  295. MinItems: 5,
  296. Items: &openapi3.SchemaRef{
  297. Value: &openapi3.Schema{
  298. Type: "number",
  299. },
  300. },
  301. },
  302. },
  303. },
  304. },
  305. "sensor": {
  306. Value: &openapi3.Schema{
  307. Type: "array",
  308. MaxItems: utils.UInt64(2),
  309. MinItems: 2,
  310. Items: &openapi3.SchemaRef{
  311. Value: &openapi3.Schema{
  312. Type: "array",
  313. MaxItems: utils.UInt64(2),
  314. MinItems: 2,
  315. Items: &openapi3.SchemaRef{
  316. Value: &openapi3.Schema{
  317. Type: "array",
  318. MaxItems: utils.UInt64(2),
  319. MinItems: 2,
  320. Items: &openapi3.SchemaRef{
  321. Value: &openapi3.Schema{
  322. Type: "number",
  323. },
  324. },
  325. },
  326. },
  327. },
  328. },
  329. },
  330. },
  331. },
  332. Required: []string{"signal", "sensor"},
  333. AdditionalPropertiesAllowed: utils.Bool(false),
  334. },
  335. },
  336. },
  337. Required: []string{"inputs"},
  338. AdditionalPropertiesAllowed: utils.Bool(false),
  339. },
  340. expectedResponseSchema: expectedResponseSchemaColFmt(),
  341. },
  342. "ColSchemaSingleTensor": {
  343. tfSigDef: TFSignatureDef{
  344. Key: "Signature Def Key",
  345. Inputs: []TFTensor{
  346. {
  347. Name: "signal",
  348. DType: DtInt8,
  349. Shape: TFShape{2, 5},
  350. Rank: 2,
  351. },
  352. },
  353. Outputs: outputTensors(),
  354. }, expectedRequestSchema: &openapi3.Schema{
  355. Type: "object",
  356. Properties: map[string]*openapi3.SchemaRef{
  357. "inputs": {
  358. Value: &openapi3.Schema{
  359. Type: "array",
  360. MaxItems: utils.UInt64(2),
  361. MinItems: 2,
  362. Items: &openapi3.SchemaRef{
  363. Value: &openapi3.Schema{
  364. Type: "array",
  365. MaxItems: utils.UInt64(5),
  366. MinItems: 5,
  367. Items: &openapi3.SchemaRef{
  368. Value: &openapi3.Schema{
  369. Type: "number",
  370. },
  371. },
  372. },
  373. },
  374. },
  375. },
  376. },
  377. Required: []string{"inputs"},
  378. AdditionalPropertiesAllowed: utils.Bool(false),
  379. }, expectedResponseSchema: expectedResponseSchemaColFmt(),
  380. },
  381. "UnknownRank": {
  382. tfSigDef: TFSignatureDef{
  383. Key: "Signature Def Key",
  384. Inputs: []TFTensor{
  385. {
  386. Name: "signal",
  387. DType: DtInt8,
  388. Rank: -1,
  389. },
  390. },
  391. Outputs: outputTensors(),
  392. },
  393. expectedRequestSchema: &openapi3.Schema{
  394. Type: "object",
  395. Properties: map[string]*openapi3.SchemaRef{
  396. "inputs": {
  397. Value: &openapi3.Schema{},
  398. },
  399. },
  400. Required: []string{"inputs"},
  401. AdditionalPropertiesAllowed: utils.Bool(false),
  402. },
  403. expectedResponseSchema: expectedResponseSchemaColFmt(),
  404. },
  405. "Scalar": {
  406. tfSigDef: TFSignatureDef{
  407. Key: "Signature Def Key",
  408. Inputs: []TFTensor{
  409. {
  410. Name: "signal",
  411. DType: DtInt8,
  412. Shape: TFShape{},
  413. Rank: 0,
  414. },
  415. },
  416. Outputs: outputTensors(),
  417. },
  418. expectedRequestSchema: &openapi3.Schema{
  419. Type: "object",
  420. Properties: map[string]*openapi3.SchemaRef{
  421. "inputs": {
  422. Value: &openapi3.Schema{
  423. Type: "number",
  424. },
  425. },
  426. },
  427. Required: []string{"inputs"},
  428. AdditionalPropertiesAllowed: utils.Bool(false),
  429. },
  430. expectedResponseSchema: expectedResponseSchemaColFmt(),
  431. },
  432. "MultipleScalar": {
  433. tfSigDef: TFSignatureDef{
  434. Key: "Signature Def Key",
  435. Inputs: []TFTensor{
  436. {
  437. Name: "signal",
  438. DType: DtInt8,
  439. Shape: TFShape{},
  440. Rank: 0,
  441. },
  442. {
  443. Name: "sensor",
  444. DType: DtInt8,
  445. Shape: TFShape{},
  446. Rank: 0,
  447. },
  448. },
  449. Outputs: outputTensors(),
  450. },
  451. expectedRequestSchema: &openapi3.Schema{
  452. Type: "object",
  453. Properties: map[string]*openapi3.SchemaRef{
  454. "inputs": {
  455. Value: &openapi3.Schema{
  456. Type: "object",
  457. Properties: map[string]*openapi3.SchemaRef{
  458. "signal": {
  459. Value: &openapi3.Schema{
  460. Type: "number",
  461. },
  462. },
  463. "sensor": {
  464. Value: &openapi3.Schema{
  465. Type: "number",
  466. },
  467. },
  468. },
  469. Required: []string{"signal", "sensor"},
  470. AdditionalPropertiesAllowed: utils.Bool(false),
  471. },
  472. },
  473. },
  474. Required: []string{"inputs"},
  475. AdditionalPropertiesAllowed: utils.Bool(false),
  476. },
  477. expectedResponseSchema: expectedResponseSchemaColFmt(),
  478. },
  479. }
  480. for name, scenario := range scenarios {
  481. t.Logf("Running %s ...", name)
  482. requestSchema, responseSchema, err := scenario.tfSigDef.Schema()
  483. g.Expect(requestSchema).Should(gomega.Equal(scenario.expectedRequestSchema))
  484. g.Expect(responseSchema).Should(gomega.Equal(scenario.expectedResponseSchema))
  485. g.Expect(err).To(gomega.BeNil())
  486. }
  487. }
  488. func TestTFSignatureDefNonPredict(t *testing.T) {
  489. g := gomega.NewGomegaWithT(t)
  490. tfSigDef := expectedTFSignatureDef()
  491. tfSigDef.Method = Classify
  492. _, _, err := tfSigDef.Schema()
  493. g.Expect(err).To(gomega.MatchError(UnsupportedAPISchemaError))
  494. }
  495. func TestTFSignatureDefInconsistentInputOutputFormatError(t *testing.T) {
  496. g := gomega.NewGomegaWithT(t)
  497. tfSigDef := TFSignatureDef{
  498. Key: "Signature Def Key",
  499. Method: Predict,
  500. Inputs: []TFTensor{
  501. {
  502. Name: "input",
  503. DType: DtInt8,
  504. Shape: TFShape{-1, 3},
  505. Rank: 2,
  506. },
  507. },
  508. Outputs: []TFTensor{
  509. {
  510. Name: "output",
  511. DType: DtInt8,
  512. Rank: -1,
  513. },
  514. },
  515. }
  516. _, _, err := tfSigDef.Schema()
  517. g.Expect(err).Should(gomega.MatchError(InconsistentInputOutputFormatError))
  518. }