/mltool/evaluate.py
Python | 95 lines | 45 code | 15 blank | 35 comment | 7 complexity | 3b062590ba2af36b2c6bd2a8abb14ff2 MD5 | raw file
1#!/usr/bin/env python
2"""
3mltool.evaluate
4~~~~~~~~~~~~~~~
5
6Two metrics are considered for the evaluation:
7
8- `NDCG`_
9- `RMSE`_
10
11.. _`NDCG`: http://en.wikipedia.org/wiki/Discounted_cumulative_gain
12.. _`RMSE`: http://en.wikipedia.org/wiki/Mean_squared_error
13
14"""
15import math
16from itertools import izip, groupby
17from operator import itemgetter
18
19import numpy as np
20
21from mltool.predict import predict_all
22
23
def dcg(preds, labels):
    """Discounted cumulative gain at every rank position.

    Items are ranked by descending predicted score; the gain contributed
    by the item at 1-based rank ``i`` is ``(2**label - 1) / log2(i + 1)``.

    :param preds: predicted scores, one per item
    :param labels: relevance labels, aligned with ``preds``
    :return: array of cumulative DCG values, one per rank position
    """
    ranking = np.argsort(preds)[::-1]
    ranked_labels = np.take(labels, ranking)

    # Logarithmic position discounts: log2(2), log2(3), ...
    discounts = np.log2(np.arange(2, len(ranked_labels) + 2))
    return np.cumsum((2 ** ranked_labels - 1) / discounts)
32
33
def evaluate_preds(preds, dataset, ndcg_at=10):
    """Evaluate predicted values against a labelled dataset.

    Samples belonging to the same query must be contiguous in the
    dataset, since results are grouped with ``itertools.groupby``.

    :param preds: predicted values, in the same order as the samples in
                  the dataset
    :param dataset: a Dataset object with all labels set
    :param ndcg_at: position at which evaluate NDCG
    :type preds: list-like

    :return: Return the pair RMSE and NDCG scores.  An empty dataset
             yields ``(0.0, 0.0)``.

    """
    ndcg = 0.0
    rmse = 0.0
    nqueries = 0
    count = 0

    # Builtin zip (rather than Python-2-only izip) works on both Python 2
    # and 3; each triple is (query id, prediction, label).
    for _, resultset in groupby(zip(dataset.queries, preds, dataset.labels),
                                key=itemgetter(0)):
        resultset = list(resultset)
        # Distinct names so the loop does not rebind (shadow) the
        # `preds` parameter, as the original code did.
        query_labels = np.array([l for (_, _, l) in resultset])
        query_preds = np.array([p for (_, p, _) in resultset])

        rmse += np.sum((query_labels - query_preds) ** 2)
        count += len(query_labels)

        if query_labels.any():
            ideal_DCG = dcg(query_labels, query_labels)
            DCG = dcg(query_preds, query_labels)
            # Clamp the cutoff to the result-set size.
            k = min(ndcg_at, len(DCG)) - 1
            ndcg += DCG[k] / ideal_DCG[k]
        else:
            # No relevant result for this query: ranking carries no
            # signal, so score it with a neutral 0.5.
            ndcg += 0.5
        nqueries += 1

    if count == 0:
        # Empty dataset: avoid a ZeroDivisionError, report zero scores.
        return 0.0, 0.0

    rmse = math.sqrt(rmse / count)
    ndcg /= nqueries

    return rmse, ndcg
76
77
def evaluate_model(model, dataset, ndcg_at=10, return_preds=False):
    """Evaluate a model against a labelled dataset.

    :param model: the model to evaluate
    :param dataset: a Dataset object with all labels set
    :param ndcg_at: position at which evaluate NDCG
    :param return_preds: when true, also return the raw predictions

    :return: the pair (RMSE, NDCG), or the triple (RMSE, NDCG, preds)
             when *return_preds* is true.

    """
    preds = predict_all(model, dataset)
    scores = evaluate_preds(preds, dataset, ndcg_at)
    return scores + (preds,) if return_preds else scores