PageRenderTime 53ms CodeModel.GetById 18ms RepoModel.GetById 1ms app.codeStats 0ms

/lucene/src/test/org/apache/lucene/search/function/TestCustomScoreQuery.java

https://github.com/simplegeo/lucene-solr-3.1
Java | 348 lines | 241 code | 50 blank | 57 comment | 16 complexity | 876e1a65fc7f1cb58757a43d25a5e768 MD5 | raw file
  1. package org.apache.lucene.search.function;
  2. /**
  3. * Licensed to the Apache Software Foundation (ASF) under one or more
  4. * contributor license agreements. See the NOTICE file distributed with
  5. * this work for additional information regarding copyright ownership.
  6. * The ASF licenses this file to You under the Apache License, Version 2.0
  7. * (the "License"); you may not use this file except in compliance with
  8. * the License. You may obtain a copy of the License at
  9. *
  10. * http://www.apache.org/licenses/LICENSE-2.0
  11. *
  12. * Unless required by applicable law or agreed to in writing, software
  13. * distributed under the License is distributed on an "AS IS" BASIS,
  14. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. * See the License for the specific language governing permissions and
  16. * limitations under the License.
  17. */
  18. import org.apache.lucene.queryParser.QueryParser;
  19. import org.apache.lucene.queryParser.ParseException;
  20. import org.apache.lucene.search.*;
  21. import org.junit.Test;
  22. import java.io.IOException;
  23. import java.util.HashMap;
  24. import java.util.Map;
  25. import org.apache.lucene.index.IndexReader;
  26. import org.apache.lucene.index.Term;
  27. /**
  28. * Test CustomScoreQuery search.
  29. */
  30. public class TestCustomScoreQuery extends FunctionTestSetup {
  /**
   * Creates the test case, passing {@code true} to the FunctionTestSetup constructor.
   * NOTE(review): the meaning of that flag is defined by FunctionTestSetup and is
   * not visible from this file - confirm there.
   */
  public TestCustomScoreQuery() {
    super(true);
  }
  35. /**
  36. * Test that CustomScoreQuery of Type.BYTE returns the expected scores.
  37. */
  38. @Test
  39. public void testCustomScoreByte() throws Exception, ParseException {
  40. // INT field values are small enough to be parsed as byte
  41. doTestCustomScore(INT_FIELD, FieldScoreQuery.Type.BYTE, 1.0);
  42. doTestCustomScore(INT_FIELD, FieldScoreQuery.Type.BYTE, 2.0);
  43. }
  44. /**
  45. * Test that CustomScoreQuery of Type.SHORT returns the expected scores.
  46. */
  47. @Test
  48. public void testCustomScoreShort() throws Exception, ParseException {
  49. // INT field values are small enough to be parsed as short
  50. doTestCustomScore(INT_FIELD, FieldScoreQuery.Type.SHORT, 1.0);
  51. doTestCustomScore(INT_FIELD, FieldScoreQuery.Type.SHORT, 3.0);
  52. }
  53. /**
  54. * Test that CustomScoreQuery of Type.INT returns the expected scores.
  55. */
  56. @Test
  57. public void testCustomScoreInt() throws Exception, ParseException {
  58. doTestCustomScore(INT_FIELD, FieldScoreQuery.Type.INT, 1.0);
  59. doTestCustomScore(INT_FIELD, FieldScoreQuery.Type.INT, 4.0);
  60. }
  61. /**
  62. * Test that CustomScoreQuery of Type.FLOAT returns the expected scores.
  63. */
  64. @Test
  65. public void testCustomScoreFloat() throws Exception, ParseException {
  66. // INT field can be parsed as float
  67. doTestCustomScore(INT_FIELD, FieldScoreQuery.Type.FLOAT, 1.0);
  68. doTestCustomScore(INT_FIELD, FieldScoreQuery.Type.FLOAT, 5.0);
  69. // same values, but in float format
  70. doTestCustomScore(FLOAT_FIELD, FieldScoreQuery.Type.FLOAT, 1.0);
  71. doTestCustomScore(FLOAT_FIELD, FieldScoreQuery.Type.FLOAT, 6.0);
  72. }
  73. // must have static class otherwise serialization tests fail
  74. private static class CustomAddQuery extends CustomScoreQuery {
  75. // constructor
  76. CustomAddQuery(Query q, ValueSourceQuery qValSrc) {
  77. super(q, qValSrc);
  78. }
  79. /*(non-Javadoc) @see org.apache.lucene.search.function.CustomScoreQuery#name() */
  80. @Override
  81. public String name() {
  82. return "customAdd";
  83. }
  84. @Override
  85. protected CustomScoreProvider getCustomScoreProvider(IndexReader reader) {
  86. return new CustomScoreProvider(reader) {
  87. @Override
  88. public float customScore(int doc, float subQueryScore, float valSrcScore) {
  89. return subQueryScore + valSrcScore;
  90. }
  91. @Override
  92. public Explanation customExplain(int doc, Explanation subQueryExpl, Explanation valSrcExpl) {
  93. float valSrcScore = valSrcExpl == null ? 0 : valSrcExpl.getValue();
  94. Explanation exp = new Explanation(valSrcScore + subQueryExpl.getValue(), "custom score: sum of:");
  95. exp.addDetail(subQueryExpl);
  96. if (valSrcExpl != null) {
  97. exp.addDetail(valSrcExpl);
  98. }
  99. return exp;
  100. }
  101. };
  102. }
  103. }
  104. // must have static class otherwise serialization tests fail
  105. private static class CustomMulAddQuery extends CustomScoreQuery {
  106. // constructor
  107. CustomMulAddQuery(Query q, ValueSourceQuery qValSrc1, ValueSourceQuery qValSrc2) {
  108. super(q, new ValueSourceQuery[]{qValSrc1, qValSrc2});
  109. }
  110. /*(non-Javadoc) @see org.apache.lucene.search.function.CustomScoreQuery#name() */
  111. @Override
  112. public String name() {
  113. return "customMulAdd";
  114. }
  115. @Override
  116. protected CustomScoreProvider getCustomScoreProvider(IndexReader reader) {
  117. return new CustomScoreProvider(reader) {
  118. @Override
  119. public float customScore(int doc, float subQueryScore, float valSrcScores[]) {
  120. if (valSrcScores.length == 0) {
  121. return subQueryScore;
  122. }
  123. if (valSrcScores.length == 1) {
  124. return subQueryScore + valSrcScores[0];
  125. // confirm that skipping beyond the last doc, on the
  126. // previous reader, hits NO_MORE_DOCS
  127. }
  128. return (subQueryScore + valSrcScores[0]) * valSrcScores[1]; // we know there are two
  129. }
  130. @Override
  131. public Explanation customExplain(int doc, Explanation subQueryExpl, Explanation valSrcExpls[]) {
  132. if (valSrcExpls.length == 0) {
  133. return subQueryExpl;
  134. }
  135. Explanation exp = new Explanation(valSrcExpls[0].getValue() + subQueryExpl.getValue(), "sum of:");
  136. exp.addDetail(subQueryExpl);
  137. exp.addDetail(valSrcExpls[0]);
  138. if (valSrcExpls.length == 1) {
  139. exp.setDescription("CustomMulAdd, sum of:");
  140. return exp;
  141. }
  142. Explanation exp2 = new Explanation(valSrcExpls[1].getValue() * exp.getValue(), "custom score: product of:");
  143. exp2.addDetail(valSrcExpls[1]);
  144. exp2.addDetail(exp);
  145. return exp2;
  146. }
  147. };
  148. }
  149. }
  150. private final class CustomExternalQuery extends CustomScoreQuery {
  151. @Override
  152. protected CustomScoreProvider getCustomScoreProvider(IndexReader reader) throws IOException {
  153. final int[] values = FieldCache.DEFAULT.getInts(reader, INT_FIELD);
  154. return new CustomScoreProvider(reader) {
  155. @Override
  156. public float customScore(int doc, float subScore, float valSrcScore) throws IOException {
  157. assertTrue(doc <= reader.maxDoc());
  158. return values[doc];
  159. }
  160. };
  161. }
  162. public CustomExternalQuery(Query q) {
  163. super(q);
  164. }
  165. }
  166. @Test
  167. public void testCustomExternalQuery() throws Exception {
  168. QueryParser qp = new QueryParser(TEST_VERSION_CURRENT, TEXT_FIELD,anlzr);
  169. String qtxt = "first aid text"; // from the doc texts in FunctionQuerySetup.
  170. Query q1 = qp.parse(qtxt);
  171. final Query q = new CustomExternalQuery(q1);
  172. log(q);
  173. IndexSearcher s = new IndexSearcher(dir);
  174. TopDocs hits = s.search(q, 1000);
  175. assertEquals(N_DOCS, hits.totalHits);
  176. for(int i=0;i<N_DOCS;i++) {
  177. final int doc = hits.scoreDocs[i].doc;
  178. final float score = hits.scoreDocs[i].score;
  179. assertEquals("doc=" + doc, (float) 1+(4*doc) % N_DOCS, score, 0.0001);
  180. }
  181. s.close();
  182. }
  183. @Test
  184. public void testRewrite() throws Exception {
  185. final IndexSearcher s = new IndexSearcher(dir, true);
  186. Query q = new TermQuery(new Term(TEXT_FIELD, "first"));
  187. CustomScoreQuery original = new CustomScoreQuery(q);
  188. CustomScoreQuery rewritten = (CustomScoreQuery) original.rewrite(s.getIndexReader());
  189. assertTrue("rewritten query should be identical, as TermQuery does not rewrite", original == rewritten);
  190. assertTrue("no hits for query", s.search(rewritten,1).totalHits > 0);
  191. assertEquals(s.search(q,1).totalHits, s.search(rewritten,1).totalHits);
  192. q = new TermRangeQuery(TEXT_FIELD, null, null, true, true); // everything
  193. original = new CustomScoreQuery(q);
  194. rewritten = (CustomScoreQuery) original.rewrite(s.getIndexReader());
  195. assertTrue("rewritten query should not be identical, as TermRangeQuery rewrites", original != rewritten);
  196. assertTrue("no hits for query", s.search(rewritten,1).totalHits > 0);
  197. assertEquals(s.search(q,1).totalHits, s.search(original,1).totalHits);
  198. assertEquals(s.search(q,1).totalHits, s.search(rewritten,1).totalHits);
  199. s.close();
  200. }
  201. // Test that FieldScoreQuery returns docs with expected score.
  202. private void doTestCustomScore(String field, FieldScoreQuery.Type tp, double dboost) throws Exception, ParseException {
  203. float boost = (float) dboost;
  204. IndexSearcher s = new IndexSearcher(dir, true);
  205. FieldScoreQuery qValSrc = new FieldScoreQuery(field, tp); // a query that would score by the field
  206. QueryParser qp = new QueryParser(TEST_VERSION_CURRENT, TEXT_FIELD, anlzr);
  207. String qtxt = "first aid text"; // from the doc texts in FunctionQuerySetup.
  208. // regular (boolean) query.
  209. Query q1 = qp.parse(qtxt);
  210. log(q1);
  211. // custom query, that should score the same as q1.
  212. Query q2CustomNeutral = new CustomScoreQuery(q1);
  213. q2CustomNeutral.setBoost(boost);
  214. log(q2CustomNeutral);
  215. // custom query, that should (by default) multiply the scores of q1 by that of the field
  216. CustomScoreQuery q3CustomMul = new CustomScoreQuery(q1, qValSrc);
  217. q3CustomMul.setStrict(true);
  218. q3CustomMul.setBoost(boost);
  219. log(q3CustomMul);
  220. // custom query, that should add the scores of q1 to that of the field
  221. CustomScoreQuery q4CustomAdd = new CustomAddQuery(q1, qValSrc);
  222. q4CustomAdd.setStrict(true);
  223. q4CustomAdd.setBoost(boost);
  224. log(q4CustomAdd);
  225. // custom query, that multiplies and adds the field score to that of q1
  226. CustomScoreQuery q5CustomMulAdd = new CustomMulAddQuery(q1, qValSrc, qValSrc);
  227. q5CustomMulAdd.setStrict(true);
  228. q5CustomMulAdd.setBoost(boost);
  229. log(q5CustomMulAdd);
  230. // do al the searches
  231. TopDocs td1 = s.search(q1, null, 1000);
  232. TopDocs td2CustomNeutral = s.search(q2CustomNeutral, null, 1000);
  233. TopDocs td3CustomMul = s.search(q3CustomMul, null, 1000);
  234. TopDocs td4CustomAdd = s.search(q4CustomAdd, null, 1000);
  235. TopDocs td5CustomMulAdd = s.search(q5CustomMulAdd, null, 1000);
  236. // put results in map so we can verify the scores although they have changed
  237. Map<Integer,Float> h1 = topDocsToMap(td1);
  238. Map<Integer,Float> h2CustomNeutral = topDocsToMap(td2CustomNeutral);
  239. Map<Integer,Float> h3CustomMul = topDocsToMap(td3CustomMul);
  240. Map<Integer,Float> h4CustomAdd = topDocsToMap(td4CustomAdd);
  241. Map<Integer,Float> h5CustomMulAdd = topDocsToMap(td5CustomMulAdd);
  242. verifyResults(boost, s,
  243. h1, h2CustomNeutral, h3CustomMul, h4CustomAdd, h5CustomMulAdd,
  244. q1, q2CustomNeutral, q3CustomMul, q4CustomAdd, q5CustomMulAdd);
  245. s.close();
  246. }
  247. // verify results are as expected.
  248. private void verifyResults(float boost, IndexSearcher s,
  249. Map<Integer,Float> h1, Map<Integer,Float> h2customNeutral, Map<Integer,Float> h3CustomMul, Map<Integer,Float> h4CustomAdd, Map<Integer,Float> h5CustomMulAdd,
  250. Query q1, Query q2, Query q3, Query q4, Query q5) throws Exception {
  251. // verify numbers of matches
  252. log("#hits = "+h1.size());
  253. assertEquals("queries should have same #hits",h1.size(),h2customNeutral.size());
  254. assertEquals("queries should have same #hits",h1.size(),h3CustomMul.size());
  255. assertEquals("queries should have same #hits",h1.size(),h4CustomAdd.size());
  256. assertEquals("queries should have same #hits",h1.size(),h5CustomMulAdd.size());
  257. QueryUtils.check(random, q1,s);
  258. QueryUtils.check(random, q2,s);
  259. QueryUtils.check(random, q3,s);
  260. QueryUtils.check(random, q4,s);
  261. QueryUtils.check(random, q5,s);
  262. // verify scores ratios
  263. for (final Integer doc : h1.keySet()) {
  264. log("doc = "+doc);
  265. float fieldScore = expectedFieldScore(s.getIndexReader().document(doc).get(ID_FIELD));
  266. log("fieldScore = " + fieldScore);
  267. assertTrue("fieldScore should not be 0", fieldScore > 0);
  268. float score1 = h1.get(doc);
  269. logResult("score1=", s, q1, doc, score1);
  270. float score2 = h2customNeutral.get(doc);
  271. logResult("score2=", s, q2, doc, score2);
  272. assertEquals("same score (just boosted) for neutral", boost * score1, score2, TEST_SCORE_TOLERANCE_DELTA);
  273. float score3 = h3CustomMul.get(doc);
  274. logResult("score3=", s, q3, doc, score3);
  275. assertEquals("new score for custom mul", boost * fieldScore * score1, score3, TEST_SCORE_TOLERANCE_DELTA);
  276. float score4 = h4CustomAdd.get(doc);
  277. logResult("score4=", s, q4, doc, score4);
  278. assertEquals("new score for custom add", boost * (fieldScore + score1), score4, TEST_SCORE_TOLERANCE_DELTA);
  279. float score5 = h5CustomMulAdd.get(doc);
  280. logResult("score5=", s, q5, doc, score5);
  281. assertEquals("new score for custom mul add", boost * fieldScore * (score1 + fieldScore), score5, TEST_SCORE_TOLERANCE_DELTA);
  282. }
  283. }
  284. private void logResult(String msg, Searcher s, Query q, int doc, float score1) throws IOException {
  285. log(msg+" "+score1);
  286. log("Explain by: "+q);
  287. log(s.explain(q,doc));
  288. }
  289. // since custom scoring modifies the order of docs, map results
  290. // by doc ids so that we can later compare/verify them
  291. private Map<Integer,Float> topDocsToMap(TopDocs td) {
  292. Map<Integer,Float> h = new HashMap<Integer,Float>();
  293. for (int i=0; i<td.totalHits; i++) {
  294. h.put(td.scoreDocs[i].doc, td.scoreDocs[i].score);
  295. }
  296. return h;
  297. }
  298. }