/examples/android/app/src/main/java/com/tensorspeech/tensorflowtts/module/FastSpeech2.java

https://github.com/dathudeptrai/TensorflowTTS · Java · 82 lines · 64 code · 13 blank · 5 comment · 2 complexity · bdebae2687b83da0dac84eae411940e5 MD5 · raw file

package com.tensorspeech.tensorflowtts.module;

import android.annotation.SuppressLint;
import android.util.Log;

import org.tensorflow.lite.DataType;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.Tensor;
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

import java.io.File;
import java.nio.FloatBuffer;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
  13. /**
  14. * @author {@link "mailto:xuefeng.ding@outlook.com" "Xuefeng Ding"}
  15. * Created 2020-07-20 17:26
  16. *
  17. */
  18. public class FastSpeech2 extends AbstractModule {
  19. private static final String TAG = "FastSpeech2";
  20. private Interpreter mModule;
  21. public FastSpeech2(String modulePath) {
  22. try {
  23. mModule = new Interpreter(new File(modulePath), getOption());
  24. int input = mModule.getInputTensorCount();
  25. for (int i = 0; i < input; i++) {
  26. Tensor inputTensor = mModule.getInputTensor(i);
  27. Log.d(TAG, "input:" + i +
  28. " name:" + inputTensor.name() +
  29. " shape:" + Arrays.toString(inputTensor.shape()) +
  30. " dtype:" + inputTensor.dataType());
  31. }
  32. int output = mModule.getOutputTensorCount();
  33. for (int i = 0; i < output; i++) {
  34. Tensor outputTensor = mModule.getOutputTensor(i);
  35. Log.d(TAG, "output:" + i +
  36. " name:" + outputTensor.name() +
  37. " shape:" + Arrays.toString(outputTensor.shape()) +
  38. " dtype:" + outputTensor.dataType());
  39. }
  40. Log.d(TAG, "successfully init");
  41. } catch (Exception e) {
  42. e.printStackTrace();
  43. }
  44. }
  45. public TensorBuffer getMelSpectrogram(int[] inputIds, float speed) {
  46. Log.d(TAG, "input id length: " + inputIds.length);
  47. mModule.resizeInput(0, new int[]{1, inputIds.length});
  48. mModule.allocateTensors();
  49. @SuppressLint("UseSparseArrays")
  50. Map<Integer, Object> outputMap = new HashMap<>();
  51. FloatBuffer outputBuffer = FloatBuffer.allocate(350000);
  52. outputMap.put(0, outputBuffer);
  53. int[][] inputs = new int[1][inputIds.length];
  54. inputs[0] = inputIds;
  55. long time = System.currentTimeMillis();
  56. mModule.runForMultipleInputsOutputs(
  57. new Object[]{inputs, new int[1][1], new int[]{0}, new float[]{speed}, new float[]{1F}, new float[]{1F}},
  58. outputMap);
  59. Log.d(TAG, "time cost: " + (System.currentTimeMillis() - time));
  60. int size = mModule.getOutputTensor(0).shape()[2];
  61. int[] shape = {1, outputBuffer.position() / size, size};
  62. TensorBuffer spectrogram = TensorBuffer.createFixedSize(shape, DataType.FLOAT32);
  63. float[] outputArray = new float[outputBuffer.position()];
  64. outputBuffer.rewind();
  65. outputBuffer.get(outputArray);
  66. spectrogram.loadArray(outputArray);
  67. return spectrogram;
  68. }
  69. }