/driver-reactive-streams/src/test/tck/com/mongodb/reactivestreams/client/MongoFixture.java

http://github.com/mongodb/mongo-java-driver · Java · 221 lines · 169 code · 34 blank · 18 comment · 12 complexity · 9ac0047f3511ed32afacf3d615bb868e MD5 · raw file

  1. /*
  2. * Copyright 2008-present MongoDB, Inc.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. package com.mongodb.reactivestreams.client;
  17. import com.mongodb.ClusterFixture;
  18. import com.mongodb.MongoClientSettings;
  19. import com.mongodb.MongoCommandException;
  20. import com.mongodb.MongoException;
  21. import com.mongodb.MongoNamespace;
  22. import com.mongodb.MongoTimeoutException;
  23. import org.bson.Document;
  24. import org.reactivestreams.Publisher;
  25. import org.reactivestreams.Subscriber;
  26. import org.reactivestreams.Subscription;
  27. import java.util.ArrayList;
  28. import java.util.List;
  29. import java.util.concurrent.CountDownLatch;
  30. import java.util.concurrent.TimeUnit;
  31. import static java.util.concurrent.TimeUnit.SECONDS;
  32. /**
  33. * Helper class for asynchronous tests.
  34. */
  35. public final class MongoFixture {
  36. private static MongoClient mongoClient;
  37. private MongoFixture() {
  38. }
  39. public static final long DEFAULT_TIMEOUT_MILLIS = 5000L;
  40. public static final long PUBLISHER_REFERENCE_CLEANUP_TIMEOUT_MILLIS = 1000L;
  41. public static synchronized MongoClient getMongoClient() {
  42. if (mongoClient == null) {
  43. mongoClient = MongoClients.create(getMongoClientSettings());
  44. Runtime.getRuntime().addShutdownHook(new ShutdownHook());
  45. }
  46. return mongoClient;
  47. }
  48. public static MongoClientSettings getMongoClientSettings() {
  49. return getMongoClientSettingsBuilder().build();
  50. }
  51. public static MongoClientSettings.Builder getMongoClientSettingsBuilder() {
  52. return MongoClientSettings.builder().applyConnectionString(ClusterFixture.getConnectionString());
  53. }
  54. public static String getDefaultDatabaseName() {
  55. return ClusterFixture.getDefaultDatabaseName();
  56. }
  57. public static MongoDatabase getDefaultDatabase() {
  58. return getMongoClient().getDatabase(getDefaultDatabaseName());
  59. }
  60. public static void dropDatabase(final String name) {
  61. if (name == null) {
  62. return;
  63. }
  64. try {
  65. run(getMongoClient().getDatabase(name).runCommand(new Document("dropDatabase", 1)));
  66. } catch (MongoCommandException e) {
  67. if (!e.getErrorMessage().contains("ns not found")) {
  68. throw e;
  69. }
  70. }
  71. }
  72. public static void drop(final MongoNamespace namespace) {
  73. try {
  74. run(getMongoClient().getDatabase(namespace.getDatabaseName()).runCommand(new Document("drop", namespace.getCollectionName())));
  75. } catch (MongoCommandException e) {
  76. if (!e.getErrorMessage().contains("ns not found")) {
  77. throw e;
  78. }
  79. }
  80. }
  81. public static <T> List<T> run(final Publisher<T> publisher) {
  82. return run(publisher, () -> {});
  83. }
  84. public static <T> List<T> run(final Publisher<T> publisher, final Runnable onRequest) {
  85. try {
  86. ObservableSubscriber<T> subscriber = new ObservableSubscriber<>(onRequest);
  87. publisher.subscribe(subscriber);
  88. return subscriber.get();
  89. } catch (Throwable t) {
  90. if (t instanceof RuntimeException) {
  91. throw (RuntimeException) t;
  92. }
  93. throw new RuntimeException(t);
  94. }
  95. }
  96. public static void cleanDatabases() {
  97. List<String> dbNames = MongoFixture.run(getMongoClient().listDatabaseNames());
  98. for (String dbName : dbNames) {
  99. if (dbName.startsWith(getDefaultDatabaseName())) {
  100. dropDatabase(dbName);
  101. }
  102. }
  103. }
  104. static class ShutdownHook extends Thread {
  105. @Override
  106. public void run() {
  107. cleanDatabases();
  108. mongoClient.close();
  109. mongoClient = null;
  110. }
  111. }
  112. public static class ObservableSubscriber<T> implements Subscriber<T> {
  113. private final List<T> received;
  114. private final List<Throwable> errors;
  115. private final CountDownLatch latch;
  116. private final Runnable onRequest;
  117. private volatile boolean requested;
  118. private volatile Subscription subscription;
  119. private volatile boolean completed;
  120. public ObservableSubscriber() {
  121. this(() -> {});
  122. }
  123. public ObservableSubscriber(final Runnable onRequest) {
  124. this.received = new ArrayList<T>();
  125. this.errors = new ArrayList<Throwable>();
  126. this.latch = new CountDownLatch(1);
  127. this.onRequest = onRequest;
  128. }
  129. @Override
  130. public void onSubscribe(final Subscription s) {
  131. subscription = s;
  132. }
  133. @Override
  134. public void onNext(final T t) {
  135. received.add(t);
  136. }
  137. @Override
  138. public void onError(final Throwable t) {
  139. errors.add(t);
  140. onComplete();
  141. }
  142. @Override
  143. public void onComplete() {
  144. completed = true;
  145. latch.countDown();
  146. }
  147. public Subscription getSubscription() {
  148. return subscription;
  149. }
  150. public List<T> getReceived() {
  151. return received;
  152. }
  153. public List<Throwable> getErrors() {
  154. return errors;
  155. }
  156. public boolean isCompleted() {
  157. return completed;
  158. }
  159. public List<T> get() {
  160. return await(60, SECONDS).getReceived();
  161. }
  162. public List<T> get(final long timeout, final TimeUnit unit) {
  163. return await(timeout, unit).getReceived();
  164. }
  165. public ObservableSubscriber<T> await(final long timeout, final TimeUnit unit) {
  166. return await(Integer.MAX_VALUE, timeout, unit);
  167. }
  168. public ObservableSubscriber<T> await(final int request, final long timeout, final TimeUnit unit) {
  169. subscription.request(request);
  170. if (!requested) {
  171. requested = true;
  172. onRequest.run();
  173. }
  174. try {
  175. if (!latch.await(timeout, unit)) {
  176. throw new MongoTimeoutException("Publisher onComplete timed out");
  177. }
  178. } catch (InterruptedException e) {
  179. throw new MongoException("Await failed", e);
  180. }
  181. if (!errors.isEmpty()) {
  182. throw new MongoException("Await failed", errors.get(0));
  183. }
  184. return this;
  185. }
  186. }
  187. }