/examples/experimental/java-local-helloworld/MyAlgorithm.java

https://gitlab.com/ggsaavedra/PredictionIO · Java · 50 lines · 38 code · 10 blank · 2 comment · 5 complexity · 26ac5577ab7e82b22efea79b3024a4b9 MD5 · raw file

  1. package org.sample.java.helloworld;
  2. import io.prediction.controller.java.*;
  3. import java.util.Map;
  4. import java.util.HashMap;
  5. import org.slf4j.Logger;
  6. import org.slf4j.LoggerFactory;
  7. public class MyAlgorithm extends LJavaAlgorithm<
  8. EmptyAlgorithmParams, MyTrainingData, MyModel, MyQuery, MyPredictedResult> {
  9. final static Logger logger = LoggerFactory.getLogger(MyAlgorithm.class);
  10. @Override
  11. public MyModel train(MyTrainingData data) {
  12. Map<String, Double> sumMap = new HashMap<String, Double>();
  13. Map<String, Integer> countMap = new HashMap<String, Integer>();
  14. // calculate sum and count for each day
  15. for (MyTrainingData.DayTemperature temp : data.temperatures) {
  16. Double sum = sumMap.get(temp.day);
  17. Integer count = countMap.get(temp.day);
  18. if (sum == null) {
  19. sumMap.put(temp.day, temp.temperature);
  20. countMap.put(temp.day, 1);
  21. } else {
  22. sumMap.put(temp.day, sum + temp.temperature);
  23. countMap.put(temp.day, count + 1);
  24. }
  25. }
  26. // calculate the average
  27. Map<String, Double> averageMap = new HashMap<String, Double>();
  28. for (Map.Entry<String, Double> entry : sumMap.entrySet()) {
  29. String day = entry.getKey();
  30. Double average = entry.getValue() / countMap.get(day);
  31. averageMap.put(day, average);
  32. }
  33. return new MyModel(averageMap);
  34. }
  35. @Override
  36. public MyPredictedResult predict(MyModel model, MyQuery query) {
  37. Double temp = model.temperatures.get(query.day);
  38. return new MyPredictedResult(temp);
  39. }
  40. }