/baleen-entity-linking/src/main/java/uk/gov/dstl/baleen/entity/linking/supplier/mongo/MongoCandidateSupplier.java

https://github.com/dstl/baleen · Java · 128 lines · 92 code · 26 blank · 10 comment · 7 complexity · 118738e5aa1934afcc21998b23c8a51c MD5 · raw file

  1. // Copyright (c) Committed Software 2018, opensource@committed.io
  2. package uk.gov.dstl.baleen.entity.linking.supplier.mongo;
import static com.mongodb.client.model.Filters.or;
import static com.mongodb.client.model.Filters.regex;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.regex.Pattern;

import org.bson.Document;
import org.bson.conversions.Bson;

import com.mongodb.client.FindIterable;
import com.mongodb.client.MongoCollection;
import com.mongodb.client.MongoDatabase;

import uk.gov.dstl.baleen.entity.linking.Candidate;
import uk.gov.dstl.baleen.entity.linking.CandidateSupplier;
import uk.gov.dstl.baleen.entity.linking.EntityInformation;
import uk.gov.dstl.baleen.entity.linking.util.DefaultCandidate;
import uk.gov.dstl.baleen.entity.linking.util.StringArgumentsHandler;
import uk.gov.dstl.baleen.exceptions.BaleenException;
import uk.gov.dstl.baleen.types.semantic.Entity;
  26. /**
  27. * Candidate Supplier for retrieving candidates from Mongo
  28. *
  29. * @param <T> The type of Entity the Candidates relate to
  30. */
  31. public class MongoCandidateSupplier<T extends Entity> implements CandidateSupplier<T> {
  32. private static final String DEFAULT_MONGO_ID = "_id";
  33. /** The Mongo Collection name */
  34. public static final String PARAM_COLLECTION = "collection";
  35. /** The Mongo field to search against */
  36. public static final String PARAM_SEARCH_FIELD = "searchField";
  37. /** Thew document ID field */
  38. public static final String PARAM_ID_FIELD = "idField";
  39. private Map<String, String> argumentsMap = new HashMap<>();
  40. private final Function<Map<String, String>, MongoFactory> factorySupplier;
  41. private MongoFactory mongoFactory;
  42. private MongoDatabase mongoDatabase;
  43. /** Default constructor */
  44. public MongoCandidateSupplier() {
  45. this(RealMongoFactory::new);
  46. }
  47. MongoCandidateSupplier(Function<Map<String, String>, MongoFactory> factorySupplier) {
  48. this.factorySupplier = factorySupplier;
  49. }
  50. @Override
  51. public Collection<Candidate> getCandidates(EntityInformation<T> entityInformation) {
  52. Collection<Candidate> candidates = new HashSet<>();
  53. MongoCollection<Document> collection =
  54. mongoDatabase.getCollection(argumentsMap.get(PARAM_COLLECTION));
  55. Optional<Bson> buildQuery = buildQuery(entityInformation);
  56. if (buildQuery.isPresent()) {
  57. FindIterable<Document> documents = collection.find(buildQuery.get());
  58. for (Document document : documents) {
  59. Map<String, String> map = new MongoDocumentFlattener(document).flatten();
  60. String candidateID =
  61. document.get(argumentsMap.getOrDefault(PARAM_ID_FIELD, DEFAULT_MONGO_ID)).toString();
  62. String candidateName = document.get(argumentsMap.get(PARAM_SEARCH_FIELD)).toString();
  63. candidates.add(new DefaultCandidate(candidateID, candidateName, map));
  64. }
  65. }
  66. return candidates;
  67. }
  68. @Override
  69. public void configure(String[] argumentPairs) throws BaleenException {
  70. argumentsMap = new StringArgumentsHandler(argumentPairs).createStringsMap();
  71. mongoFactory = factorySupplier.apply(argumentsMap);
  72. mongoDatabase = mongoFactory.createDatabase();
  73. }
  74. private Optional<Bson> buildQuery(EntityInformation<T> entityInformation) {
  75. List<String> searchValues = new ArrayList<>();
  76. entityInformation.getMentions().stream()
  77. .filter(m -> m.getValue() != null)
  78. .forEach(
  79. mention -> {
  80. String[] mentionSearchTerms = mention.getValue().split(" ");
  81. searchValues.addAll(Arrays.asList(mentionSearchTerms));
  82. });
  83. if (searchValues.isEmpty()) {
  84. return Optional.empty();
  85. }
  86. List<Bson> bsonList = new ArrayList<>();
  87. for (String partialSearchTerm : searchValues) {
  88. bsonList.add(regex(argumentsMap.get(PARAM_SEARCH_FIELD), partialSearchTerm, "i"));
  89. }
  90. return Optional.of(or(bsonList));
  91. }
  92. @Override
  93. public void close() throws BaleenException {
  94. if (mongoFactory != null) {
  95. try {
  96. mongoFactory.close();
  97. } catch (Exception e) {
  98. throw new BaleenException(e);
  99. }
  100. }
  101. }
  102. }