PageRenderTime 87ms CodeModel.GetById 17ms RepoModel.GetById 1ms app.codeStats 0ms

/aima-core/src/test/java/aima/test/core/unit/probability/CommonProbabilityModelTests.java

http://aima-java.googlecode.com/
Java | 379 lines | 278 code | 40 blank | 61 comment | 3 complexity | 5d4557b896a6fc5fc5ea0ee4e06f5bda MD5 | raw file
Possible License(s): GPL-3.0, Apache-2.0
  1. package aima.test.core.unit.probability;
  2. import org.junit.Assert;
  3. import aima.core.probability.ProbabilityModel;
  4. import aima.core.probability.domain.FiniteIntegerDomain;
  5. import aima.core.probability.example.ExampleRV;
  6. import aima.core.probability.proposition.AssignmentProposition;
  7. import aima.core.probability.proposition.ConjunctiveProposition;
  8. import aima.core.probability.proposition.DisjunctiveProposition;
  9. import aima.core.probability.proposition.EquivalentProposition;
  10. import aima.core.probability.proposition.IntegerSumProposition;
  11. import aima.core.probability.proposition.SubsetProposition;
  12. /**
  13. * @author Ciaran O'Reilly
  14. *
  15. */
  16. public abstract class CommonProbabilityModelTests {
  17. public static final double DELTA_THRESHOLD = ProbabilityModel.DEFAULT_ROUNDING_THRESHOLD;
  18. //
  19. // PROTECTED METHODS
  20. //
  21. protected void test_RollingPairFairDiceModel(ProbabilityModel model) {
  22. Assert.assertTrue(model.isValid());
  23. // Ensure each dice has 1/6 probability
  24. for (int d = 1; d <= 6; d++) {
  25. AssignmentProposition ad1 = new AssignmentProposition(
  26. ExampleRV.DICE_1_RV, d);
  27. AssignmentProposition ad2 = new AssignmentProposition(
  28. ExampleRV.DICE_2_RV, d);
  29. Assert.assertEquals(1.0 / 6.0, model.prior(ad1), DELTA_THRESHOLD);
  30. Assert.assertEquals(1.0 / 6.0, model.prior(ad2), DELTA_THRESHOLD);
  31. }
  32. // Ensure each combination is 1/36
  33. for (int d1 = 1; d1 <= 6; d1++) {
  34. for (int d2 = 1; d2 <= 6; d2++) {
  35. AssignmentProposition ad1 = new AssignmentProposition(
  36. ExampleRV.DICE_1_RV, d1);
  37. AssignmentProposition ad2 = new AssignmentProposition(
  38. ExampleRV.DICE_2_RV, d2);
  39. ConjunctiveProposition d1AndD2 = new ConjunctiveProposition(
  40. ad1, ad2);
  41. Assert.assertEquals(1.0 / 6.0, model.prior(ad1),
  42. DELTA_THRESHOLD);
  43. Assert.assertEquals(1.0 / 6.0, model.prior(ad2),
  44. DELTA_THRESHOLD);
  45. // pg. 485 AIMA3e
  46. Assert.assertEquals(1.0 / 36.0, model.prior(ad1, ad2),
  47. DELTA_THRESHOLD);
  48. Assert.assertEquals(1.0 / 36.0, model.prior(d1AndD2),
  49. DELTA_THRESHOLD);
  50. Assert.assertEquals(1.0 / 6.0, model.posterior(ad1, ad2),
  51. DELTA_THRESHOLD);
  52. Assert.assertEquals(1.0 / 6.0, model.posterior(ad2, ad1),
  53. DELTA_THRESHOLD);
  54. }
  55. }
  56. // Test Sets of events defined via constraint propositions
  57. IntegerSumProposition total11 = new IntegerSumProposition("Total11",
  58. new FiniteIntegerDomain(11), ExampleRV.DICE_1_RV,
  59. ExampleRV.DICE_2_RV);
  60. Assert.assertEquals(2.0 / 36.0, model.prior(total11), DELTA_THRESHOLD);
  61. EquivalentProposition doubles = new EquivalentProposition("Doubles",
  62. ExampleRV.DICE_1_RV, ExampleRV.DICE_2_RV);
  63. Assert.assertEquals(1.0 / 6.0, model.prior(doubles), DELTA_THRESHOLD);
  64. SubsetProposition evenDice1 = new SubsetProposition("EvenDice1",
  65. new FiniteIntegerDomain(2, 4, 6), ExampleRV.DICE_1_RV);
  66. Assert.assertEquals(0.5, model.prior(evenDice1), DELTA_THRESHOLD);
  67. SubsetProposition oddDice2 = new SubsetProposition("OddDice2",
  68. new FiniteIntegerDomain(1, 3, 5), ExampleRV.DICE_2_RV);
  69. Assert.assertEquals(0.5, model.prior(oddDice2), DELTA_THRESHOLD);
  70. // pg. 485 AIMA3e
  71. AssignmentProposition dice1Is5 = new AssignmentProposition(
  72. ExampleRV.DICE_1_RV, 5);
  73. Assert.assertEquals(1.0 / 6.0, model.posterior(doubles, dice1Is5),
  74. DELTA_THRESHOLD);
  75. Assert.assertEquals(1.0, model.prior(ExampleRV.DICE_1_RV),
  76. DELTA_THRESHOLD);
  77. Assert.assertEquals(1.0, model.prior(ExampleRV.DICE_2_RV),
  78. DELTA_THRESHOLD);
  79. Assert.assertEquals(1.0,
  80. model.posterior(ExampleRV.DICE_1_RV, ExampleRV.DICE_2_RV),
  81. DELTA_THRESHOLD);
  82. Assert.assertEquals(1.0,
  83. model.posterior(ExampleRV.DICE_2_RV, ExampleRV.DICE_1_RV),
  84. DELTA_THRESHOLD);
  85. // Test a disjunctive proposition pg.489
  86. // P(a OR b) = P(a) + P(b) - P(a AND b)
  87. // = 1/6 + 1/6 - 1/36
  88. AssignmentProposition dice2Is5 = new AssignmentProposition(
  89. ExampleRV.DICE_2_RV, 5);
  90. DisjunctiveProposition dice1Is5OrDice2Is5 = new DisjunctiveProposition(
  91. dice1Is5, dice2Is5);
  92. Assert.assertEquals(1.0 / 6.0 + 1.0 / 6.0 - 1.0 / 36.0,
  93. model.prior(dice1Is5OrDice2Is5), DELTA_THRESHOLD);
  94. }
  95. protected void test_ToothacheCavityCatchModel(ProbabilityModel model) {
  96. Assert.assertTrue(model.isValid());
  97. AssignmentProposition atoothache = new AssignmentProposition(
  98. ExampleRV.TOOTHACHE_RV, Boolean.TRUE);
  99. AssignmentProposition anottoothache = new AssignmentProposition(
  100. ExampleRV.TOOTHACHE_RV, Boolean.FALSE);
  101. AssignmentProposition acavity = new AssignmentProposition(
  102. ExampleRV.CAVITY_RV, Boolean.TRUE);
  103. AssignmentProposition anotcavity = new AssignmentProposition(
  104. ExampleRV.CAVITY_RV, Boolean.FALSE);
  105. AssignmentProposition acatch = new AssignmentProposition(
  106. ExampleRV.CATCH_RV, Boolean.TRUE);
  107. AssignmentProposition anotcatch = new AssignmentProposition(
  108. ExampleRV.CATCH_RV, Boolean.FALSE);
  109. // AIMA3e pg. 485
  110. Assert.assertEquals(0.2, model.prior(acavity), DELTA_THRESHOLD);
  111. Assert.assertEquals(0.6, model.posterior(acavity, atoothache),
  112. DELTA_THRESHOLD);
  113. ConjunctiveProposition toothacheAndNotCavity = new ConjunctiveProposition(
  114. atoothache, anotcavity);
  115. Assert.assertEquals(0.0,
  116. model.posterior(acavity, toothacheAndNotCavity),
  117. DELTA_THRESHOLD);
  118. Assert.assertEquals(0.0,
  119. model.posterior(acavity, atoothache, anotcavity),
  120. DELTA_THRESHOLD);
  121. // AIMA3e pg. 492
  122. DisjunctiveProposition cavityOrToothache = new DisjunctiveProposition(
  123. acavity, atoothache);
  124. Assert.assertEquals(0.28, model.prior(cavityOrToothache),
  125. DELTA_THRESHOLD);
  126. // AIMA3e pg. 493
  127. Assert.assertEquals(0.4, model.posterior(anotcavity, atoothache),
  128. DELTA_THRESHOLD);
  129. Assert.assertEquals(1.0, model.prior(ExampleRV.TOOTHACHE_RV),
  130. DELTA_THRESHOLD);
  131. Assert.assertEquals(1.0, model.prior(ExampleRV.CAVITY_RV),
  132. DELTA_THRESHOLD);
  133. Assert.assertEquals(1.0, model.prior(ExampleRV.CATCH_RV),
  134. DELTA_THRESHOLD);
  135. Assert.assertEquals(1.0,
  136. model.posterior(ExampleRV.TOOTHACHE_RV, ExampleRV.CAVITY_RV),
  137. DELTA_THRESHOLD);
  138. Assert.assertEquals(1.0,
  139. model.posterior(ExampleRV.TOOTHACHE_RV, ExampleRV.CATCH_RV),
  140. DELTA_THRESHOLD);
  141. Assert.assertEquals(1.0, model.posterior(ExampleRV.TOOTHACHE_RV,
  142. ExampleRV.CAVITY_RV, ExampleRV.CATCH_RV), DELTA_THRESHOLD);
  143. Assert.assertEquals(1.0,
  144. model.posterior(ExampleRV.CAVITY_RV, ExampleRV.TOOTHACHE_RV),
  145. DELTA_THRESHOLD);
  146. Assert.assertEquals(1.0,
  147. model.posterior(ExampleRV.CAVITY_RV, ExampleRV.CATCH_RV),
  148. DELTA_THRESHOLD);
  149. Assert.assertEquals(1.0, model.posterior(ExampleRV.CAVITY_RV,
  150. ExampleRV.TOOTHACHE_RV, ExampleRV.CATCH_RV), DELTA_THRESHOLD);
  151. Assert.assertEquals(1.0,
  152. model.posterior(ExampleRV.CATCH_RV, ExampleRV.CAVITY_RV),
  153. DELTA_THRESHOLD);
  154. Assert.assertEquals(1.0,
  155. model.posterior(ExampleRV.CATCH_RV, ExampleRV.TOOTHACHE_RV),
  156. DELTA_THRESHOLD);
  157. Assert.assertEquals(1.0, model.posterior(ExampleRV.CATCH_RV,
  158. ExampleRV.CAVITY_RV, ExampleRV.TOOTHACHE_RV), DELTA_THRESHOLD);
  159. // AIMA3e pg. 495 - Bayes' Rule
  160. // P(b|a) = P(a|b)P(b)/P(a)
  161. Assert.assertEquals(model.posterior(acavity, atoothache),
  162. (model.posterior(atoothache, acavity) * model.prior(acavity))
  163. / model.prior(atoothache), DELTA_THRESHOLD);
  164. Assert.assertEquals(
  165. model.posterior(acavity, anottoothache),
  166. (model.posterior(anottoothache, acavity) * model.prior(acavity))
  167. / model.prior(anottoothache), DELTA_THRESHOLD);
  168. Assert.assertEquals(
  169. model.posterior(anotcavity, atoothache),
  170. (model.posterior(atoothache, anotcavity) * model
  171. .prior(anotcavity)) / model.prior(atoothache),
  172. DELTA_THRESHOLD);
  173. Assert.assertEquals(
  174. model.posterior(anotcavity, anottoothache),
  175. (model.posterior(anottoothache, anotcavity) * model
  176. .prior(anotcavity)) / model.prior(anottoothache),
  177. DELTA_THRESHOLD);
  178. //
  179. Assert.assertEquals(model.posterior(acavity, acatch),
  180. (model.posterior(acatch, acavity) * model.prior(acavity))
  181. / model.prior(acatch), DELTA_THRESHOLD);
  182. Assert.assertEquals(model.posterior(acavity, anotcatch),
  183. (model.posterior(anotcatch, acavity) * model.prior(acavity))
  184. / model.prior(anotcatch), DELTA_THRESHOLD);
  185. Assert.assertEquals(model.posterior(anotcavity, acatch),
  186. (model.posterior(acatch, anotcavity) * model.prior(anotcavity))
  187. / model.prior(acatch), DELTA_THRESHOLD);
  188. Assert.assertEquals(
  189. model.posterior(anotcavity, anotcatch),
  190. (model.posterior(anotcatch, anotcavity) * model
  191. .prior(anotcavity)) / model.prior(anotcatch),
  192. DELTA_THRESHOLD);
  193. }
  194. // AIMA3e pg. 488, 494
  195. protected void test_ToothacheCavityCatchWeatherModel(ProbabilityModel model) {
  196. // Should be able to run all the same queries for this independent
  197. // sub model.
  198. test_ToothacheCavityCatchModel(model);
  199. // AIMA3e pg. 486
  200. AssignmentProposition asunny = new AssignmentProposition(
  201. ExampleRV.WEATHER_RV, "sunny");
  202. AssignmentProposition arain = new AssignmentProposition(
  203. ExampleRV.WEATHER_RV, "rain");
  204. AssignmentProposition acloudy = new AssignmentProposition(
  205. ExampleRV.WEATHER_RV, "cloudy");
  206. AssignmentProposition asnow = new AssignmentProposition(
  207. ExampleRV.WEATHER_RV, "snow");
  208. Assert.assertEquals(0.6, model.prior(asunny), DELTA_THRESHOLD);
  209. Assert.assertEquals(0.1, model.prior(arain), DELTA_THRESHOLD);
  210. Assert.assertEquals(0.29, model.prior(acloudy), DELTA_THRESHOLD);
  211. Assert.assertEquals(0.01, model.prior(asnow), DELTA_THRESHOLD);
  212. // AIMA3e pg. 488
  213. // P(sunny, cavity)
  214. // P(sunny AND cavity)
  215. AssignmentProposition atoothache = new AssignmentProposition(
  216. ExampleRV.TOOTHACHE_RV, Boolean.TRUE);
  217. AssignmentProposition acatch = new AssignmentProposition(
  218. ExampleRV.CATCH_RV, Boolean.TRUE);
  219. AssignmentProposition acavity = new AssignmentProposition(
  220. ExampleRV.CAVITY_RV, Boolean.TRUE);
  221. ConjunctiveProposition sunnyAndCavity = new ConjunctiveProposition(
  222. asunny, acavity);
  223. // 0.6 (sunny) * 0.2 (cavity) = 0.12
  224. Assert.assertEquals(0.12, model.prior(asunny, acavity), DELTA_THRESHOLD);
  225. Assert.assertEquals(0.12, model.prior(sunnyAndCavity), DELTA_THRESHOLD);
  226. // AIMA3e pg. 494
  227. // P(toothache, catch, cavity, cloudy) =
  228. // P(cloudy | toothache, catch, cavity)P(toothache, catch, cavity)
  229. Assert.assertEquals(
  230. model.prior(atoothache, acatch, acavity, acloudy),
  231. model.posterior(acloudy, atoothache, acatch, acavity)
  232. * model.prior(atoothache, acatch, acavity),
  233. DELTA_THRESHOLD);
  234. ConjunctiveProposition toothacheAndCatchAndCavityAndCloudy = new ConjunctiveProposition(
  235. new ConjunctiveProposition(atoothache, acatch),
  236. new ConjunctiveProposition(acavity, acloudy));
  237. ConjunctiveProposition toothacheAndCatchAndCavity = new ConjunctiveProposition(
  238. new ConjunctiveProposition(atoothache, acatch), acavity);
  239. Assert.assertEquals(
  240. model.prior(toothacheAndCatchAndCavityAndCloudy),
  241. model.posterior(acloudy, atoothache, acatch, acavity)
  242. * model.prior(toothacheAndCatchAndCavity),
  243. DELTA_THRESHOLD);
  244. // P(cloudy | toothache, catch, cavity) = P(cloudy)
  245. // (13.10)
  246. Assert.assertEquals(
  247. model.posterior(acloudy, atoothache, acatch, acavity),
  248. model.prior(acloudy), DELTA_THRESHOLD);
  249. // P(toothache, catch, cavity, cloudy) =
  250. // P(cloudy)P(tootache, catch, cavity)
  251. Assert.assertEquals(
  252. model.prior(atoothache, acatch, acavity, acloudy),
  253. model.prior(acloudy) * model.prior(atoothache, acatch, acavity),
  254. DELTA_THRESHOLD);
  255. // P(a | b) = P(a)
  256. Assert.assertEquals(model.posterior(acavity, acloudy),
  257. model.prior(acavity), DELTA_THRESHOLD);
  258. // P(b | a) = P(b)
  259. Assert.assertEquals(model.posterior(acloudy, acavity),
  260. model.prior(acloudy), DELTA_THRESHOLD);
  261. // P(a AND b) = P(a)P(b)
  262. Assert.assertEquals(model.prior(acavity, acloudy), model.prior(acavity)
  263. * model.prior(acloudy), DELTA_THRESHOLD);
  264. ConjunctiveProposition acavityAndacloudy = new ConjunctiveProposition(
  265. acavity, acloudy);
  266. Assert.assertEquals(model.prior(acavityAndacloudy),
  267. model.prior(acavity) * model.prior(acloudy), DELTA_THRESHOLD);
  268. }
  269. // AIMA3e pg. 496
  270. protected void test_MeningitisStiffNeckModel(ProbabilityModel model) {
  271. Assert.assertTrue(model.isValid());
  272. AssignmentProposition ameningitis = new AssignmentProposition(
  273. ExampleRV.MENINGITIS_RV, true);
  274. AssignmentProposition anotmeningitis = new AssignmentProposition(
  275. ExampleRV.MENINGITIS_RV, false);
  276. AssignmentProposition astiffNeck = new AssignmentProposition(
  277. ExampleRV.STIFF_NECK_RV, true);
  278. AssignmentProposition anotstiffNeck = new AssignmentProposition(
  279. ExampleRV.STIFF_NECK_RV, false);
  280. // P(stiffNeck | meningitis) = 0.7
  281. Assert.assertEquals(0.7, model.posterior(astiffNeck, ameningitis),
  282. DELTA_THRESHOLD);
  283. // P(meningitis) = 1/50000
  284. Assert.assertEquals(0.00002, model.prior(ameningitis), DELTA_THRESHOLD);
  285. // P(~meningitis) = 1-1/50000
  286. Assert.assertEquals(0.99998, model.prior(anotmeningitis),
  287. DELTA_THRESHOLD);
  288. // P(stiffNeck) = 0.01
  289. Assert.assertEquals(0.01, model.prior(astiffNeck), DELTA_THRESHOLD);
  290. // P(~stiffNeck) = 0.99
  291. Assert.assertEquals(0.99, model.prior(anotstiffNeck), DELTA_THRESHOLD);
  292. // P(meningitis | stiffneck)
  293. // = P(stiffneck | meningitis)P(meningitis)/P(stiffneck)
  294. // = (0.7 * 0.00002)/0.01
  295. // = 0.0014 (13.4)
  296. Assert.assertEquals(0.0014, model.posterior(ameningitis, astiffNeck),
  297. DELTA_THRESHOLD);
  298. // Assuming P(~stiffneck | meningitis) = 0.3 (pg. 497), i.e. CPT (row
  299. // must = 1)
  300. //
  301. // P(meningitis | ~stiffneck)
  302. // = P(~stiffneck | meningitis)P(meningitis)/P(~stiffneck)
  303. // = (0.3 * 0.00002)/0.99
  304. // = 0.000006060606
  305. Assert.assertEquals(0.000006060606,
  306. model.posterior(ameningitis, anotstiffNeck), DELTA_THRESHOLD);
  307. }
  308. // AIMA3e pg. 512
  309. protected void test_BurglaryAlarmModel(ProbabilityModel model) {
  310. Assert.assertTrue(model.isValid());
  311. AssignmentProposition aburglary = new AssignmentProposition(
  312. ExampleRV.BURGLARY_RV, Boolean.TRUE);
  313. AssignmentProposition anotburglary = new AssignmentProposition(
  314. ExampleRV.BURGLARY_RV, Boolean.FALSE);
  315. AssignmentProposition anotearthquake = new AssignmentProposition(
  316. ExampleRV.EARTHQUAKE_RV, Boolean.FALSE);
  317. AssignmentProposition aalarm = new AssignmentProposition(
  318. ExampleRV.ALARM_RV, Boolean.TRUE);
  319. AssignmentProposition anotalarm = new AssignmentProposition(
  320. ExampleRV.ALARM_RV, Boolean.FALSE);
  321. AssignmentProposition ajohnCalls = new AssignmentProposition(
  322. ExampleRV.JOHN_CALLS_RV, Boolean.TRUE);
  323. AssignmentProposition amaryCalls = new AssignmentProposition(
  324. ExampleRV.MARY_CALLS_RV, Boolean.TRUE);
  325. // AIMA3e pg. 514
  326. Assert.assertEquals(0.00062811126, model.prior(ajohnCalls, amaryCalls,
  327. aalarm, anotburglary, anotearthquake), DELTA_THRESHOLD);
  328. Assert.assertEquals(0.00049800249, model.prior(ajohnCalls, amaryCalls,
  329. anotalarm, anotburglary, anotearthquake), DELTA_THRESHOLD);
  330. // AIMA3e pg. 524
  331. // P(Burglary = true | JohnCalls = true, MaryCalls = true) = 0.00059224
  332. Assert.assertEquals(0.00059224,
  333. model.prior(aburglary, ajohnCalls, amaryCalls), DELTA_THRESHOLD);
  334. // P(Burglary = false | JohnCalls = true, MaryCalls = true) = 0.0014919
  335. Assert.assertEquals(0.00149185764899,
  336. model.prior(anotburglary, ajohnCalls, amaryCalls),
  337. DELTA_THRESHOLD);
  338. }
  339. }