/mltool/evaluate.py

https://bitbucket.org/duilio/mltool

#!/usr/bin/env python
"""
mltool.evaluate
~~~~~~~~~~~~~~~

Two metrics are computed during evaluation:

- `NDCG`_
- `RMSE`_

.. _`NDCG`: http://en.wikipedia.org/wiki/Discounted_cumulative_gain
.. _`RMSE`: http://en.wikipedia.org/wiki/Mean_squared_error
"""
import math
from itertools import groupby
from operator import itemgetter

import numpy as np

from mltool.predict import predict_all

def dcg(preds, labels):
    """Return the cumulative DCG@k for k = 1..len(labels).

    Results are ranked by descending predicted score; gains follow the
    (2**label - 1) / log2(rank + 1) formulation.
    """
    # Rank the labels by descending predicted score.
    order = np.argsort(preds)[::-1]
    labels = np.take(labels, order)

    # Discount each gain by log2 of (1-based rank + 1), then accumulate.
    D = np.log2(np.arange(2, len(labels) + 2))
    DCG = np.cumsum((2 ** labels - 1) / D)
    return DCG
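
# A worked example of dcg() above, hand-checked rather than generated
# (the arrays are illustrative, not taken from mltool's test data):
#
#     preds  = np.array([0.9, 0.3, 0.5])
#     labels = np.array([3, 0, 1])
#
# Ranking by descending prediction reorders the labels to [3, 1, 0], so
#
#     DCG = cumsum([(2**3 - 1) / log2(2),
#                   (2**1 - 1) / log2(3),
#                   (2**0 - 1) / log2(4)])
#         ~= [7.0, 7.63, 7.63]
#
# i.e. the running DCG@k for k = 1, 2, 3.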


def evaluate_preds(preds, dataset, ndcg_at=10):
    """Evaluate predicted values against a labelled dataset.

    :param preds: predicted values, in the same order as the samples in
                  the dataset
    :param dataset: a Dataset object with all labels set
    :param ndcg_at: position at which to evaluate NDCG
    :type preds: list-like
    :return: the pair of RMSE and NDCG scores
    """
    ndcg = 0.0
    rmse = 0.0
    nqueries = 0
    count = 0

    # Samples are grouped by query id; zip() pairs each query id with its
    # prediction and label.
    for _, resultset in groupby(zip(dataset.queries, preds, dataset.labels),
                                key=itemgetter(0)):
        resultset = list(resultset)
        qlabels = np.array([l for (_, _, l) in resultset])
        qpreds = np.array([p for (_, p, _) in resultset])

        # RMSE accumulates over all samples, regardless of query.
        rmse += np.sum((qlabels - qpreds) ** 2)
        count += len(qlabels)

        if qlabels.any():
            # Normalize the DCG of the predicted ranking by the DCG of the
            # ideal ranking (the labels ranked by themselves).
            ideal_DCG = dcg(qlabels, qlabels)
            DCG = dcg(qpreds, qlabels)
            k = min(ndcg_at, len(DCG)) - 1
            ndcg += DCG[k] / ideal_DCG[k]
            nqueries += 1
        else:
            # All labels are zero: there is no ideal ranking to normalize
            # by, so score the query with a fixed 0.5.
            ndcg += 0.5
            nqueries += 1

    rmse = math.sqrt(rmse / count)
    ndcg /= nqueries
    return rmse, ndcg
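
# How the per-query NDCG term above is formed (notation assumed for this
# sketch, not taken from mltool's docs): for a query with n results and
# k = min(ndcg_at, n) - 1,
#
#     NDCG@k = DCG[k] / ideal_DCG[k]
#
# where ideal_DCG is the DCG of the labels ranked by themselves, i.e. the
# best achievable ordering, so a perfect ranking scores exactly 1.0. A
# query whose labels are all zero has ideal_DCG == 0, so it contributes a
# fixed 0.5 instead of dividing by zero.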


def evaluate_model(model, dataset, ndcg_at=10, return_preds=False):
    """Evaluate a model against a labelled dataset.

    :param model: the model to evaluate
    :param dataset: a Dataset object with all labels set
    :param ndcg_at: position at which to evaluate NDCG
    :param return_preds: if true, also return the predicted values
    :return: the pair of RMSE and NDCG scores, plus the predictions if
             `return_preds` is true
    """
    preds = predict_all(model, dataset)
    rmse, ndcg = evaluate_preds(preds, dataset, ndcg_at)
    if return_preds:
        return rmse, ndcg, preds
    return rmse, ndcg
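

# A minimal smoke test, runnable as `python evaluate.py`. The namedtuple is
# a hypothetical stand-in for mltool's Dataset class: it assumes that
# evaluate_preds() only reads the `queries` and `labels` attributes, which
# holds for the code above but is not checked against the rest of mltool.
if __name__ == '__main__':
    from collections import namedtuple

    FakeDataset = namedtuple('FakeDataset', 'queries labels')

    # Two queries with three results each; labels are graded relevance.
    # Samples must already be grouped by query id, since groupby() does
    # not sort its input.
    dataset = FakeDataset(queries=[1, 1, 1, 2, 2, 2],
                          labels=[3, 0, 1, 2, 2, 0])
    preds = [0.9, 0.3, 0.5, 0.8, 0.4, 0.1]

    rmse, ndcg = evaluate_preds(preds, dataset)
    print('RMSE: %.4f  NDCG@10: %.4f' % (rmse, ndcg))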