/drbayes/tests/lyric-tests/tests.rkt

http://github.com/ntoronto/plt-stuff · Racket · 266 lines · 135 code · 38 blank · 93 comment · 10 complexity · e3565c3beecddaa18e29a6d9087b7c7c MD5 · raw file

  1. #lang typed/racket
  2. (require plot/typed
  3. math/distributions
  4. math/statistics
  5. math/flonum
  6. "../../main.rkt"
  7. "../test-utils.rkt"
  8. "../profile.rkt"
  9. "../normal-normal.rkt")
  10. (printf "starting...~n~n")
  11. (error-print-width 1024)
  12. (interval-max-splits 5)
  13. (define n 1000)
  14. #;; Test: Normal-Normal model
  15. ;; Preimage should be a banana shape
  16. (begin
  17. (interval-max-splits 2)
  18. ;(interval-min-length (expt 0.5 1.0))
  19. (define/drbayes e
  20. (let* ([x (normal 0 1)]
  21. [y (normal x 1)])
  22. (list x y)))
  23. (define B (set-list reals (real-set 0.9 1.1)))
  24. (normal-normal/lw 0 1 '(1.0) '(1.0)))
  25. #;; Test: Normal-Normal model with circular condition
  26. ;; Preimage should look like a football set up for a field goal
  27. (begin
  28. (interval-max-splits 0)
  29. (define/drbayes (hypot x y)
  30. (sqrt (+ (sqr x) (sqr y))))
  31. (define/drbayes e
  32. (let* ([x0 (normal 0 1)]
  33. [x1 (normal x0 1)])
  34. (list x0 x1 (hypot x0 x1))))
  35. (define B (set-list reals reals (real-set 0.99 1.01))))
  36. #;; Test: thermometer that goes to 100
  37. (begin
  38. (interval-max-splits 4)
  39. (define e
  40. (drbayes
  41. (let* ([x (normal 90 10)]
  42. [y (normal x 1)])
  43. (list x (if (y . > . 100) 100 y)))))
  44. (define B (set-list reals (real-set 99.0 100.0))))
  45. (define ε 0.5)
  46. (: interval-near (Flonum -> Set))
  47. (define (interval-near x)
  48. (real-set (- x ε) (+ x ε)))
  49. #;; Test: Normal-Normal model with more observations
  50. ;; Density plot, mean, and stddev should be similar to those produced by `normal-normal/lw'
  51. (begin
  52. (interval-max-splits 2)
  53. ;(interval-min-length (flexpt 0.5 5.0))
  54. (define/drbayes e
  55. (let ([x (normal 0 1)])
  56. (list x
  57. (normal x 1)
  58. (normal x 1)
  59. (normal x 1)
  60. (normal x 1)
  61. (normal x 1)
  62. (normal x 1))))
  63. (define B
  64. (set-list reals
  65. (interval-near 3.3)
  66. (interval-near 2.0)
  67. (interval-near 1.0)
  68. (interval-near 0.2)
  69. (interval-near 1.5)
  70. (interval-near 2.4)))
  71. (normal-normal/lw 0 1 '(3.3 2.0 1.0 0.2 1.5 2.4) '(1.0 1.0 1.0 1.0 1.0 1.0)))
  72. ;; ===================================================================================================
  73. (define-values (f h idxs)
  74. (match-let ([(meaning _ f h k) e])
  75. (values (run/bot* f '()) (run/pre* h '()) (k '()))))
  76. (define (empty-set-error)
  77. (error 'drbayes-sample "cannot sample from the empty set"))
  78. (define refine
  79. (if (empty-set? B) (empty-set-error) (preimage-refiner h B)))
  80. (define S
  81. (let ([S (refine (cons omegas traces))])
  82. (if (empty-set? S) (empty-set-error) S)))
  83. (match-define (cons R T) S)
  84. (printf "idxs = ~v~n" idxs)
  85. (printf "R = ~v~n" R)
  86. (printf "T = ~v~n" T)
  87. (newline)
  88. (struct: domain-sample ([S : Nonempty-Store-Rect]
  89. [s : Store]
  90. [b : Maybe-Value]
  91. [measure : Flonum]
  92. [prob : Flonum]
  93. [point-prob : Flonum]
  94. [weight : Flonum])
  95. #:transparent)
  96. (: accept-sample? (domain-sample -> Boolean))
  97. (define (accept-sample? s)
  98. (define b (domain-sample-b s))
  99. (and (not (bottom? b))
  100. (set-member? B b)))
  101. (: orig-samples (Listof store-rect-sample))
  102. (define orig-samples
  103. (time
  104. ;profile-expr
  105. (refinement-sample* S idxs refine n)))
  106. (: all-samples (Listof domain-sample))
  107. (define all-samples
  108. (time
  109. ;profile-expr
  110. (let: loop : (Listof domain-sample) ([orig-samples : (Listof store-rect-sample) orig-samples])
  111. (cond
  112. [(empty? orig-samples) empty]
  113. [else
  114. (define s (first orig-samples))
  115. (match-define (store-rect-sample S m p) s)
  116. (define pt (refinement-sample-point S idxs refine))
  117. ;(match-define (cons R T) S)
  118. ;(define r (omega-set-sample-point R))
  119. ;(define t (trace-set-sample-point T))
  120. ;(define pt (store-sample (cons r t) m))
  121. (match pt
  122. [(store-sample s q)
  123. (define b (f (cons s null)))
  124. (cons (domain-sample S s b m p q (/ q p)) (loop (rest orig-samples)))]
  125. [_
  126. (define r (omega-set-sample-point R))
  127. (define t (trace-set-sample-point T))
  128. (define s (cons r t))
  129. (define b (bottom (delay "refinement-sample-point failed")))
  130. (cons (domain-sample S s b m p m (/ m p)) (loop (rest orig-samples)))])]))))
  131. (newline)
  132. (define samples (filter accept-sample? all-samples))
  133. (define ws (map domain-sample-weight samples))
  134. (define ps (map domain-sample-prob samples))
  135. (define ms (map domain-sample-measure samples))
  136. (define not-samples (filter (compose not accept-sample?) all-samples))
  137. (define num-all-samples (length all-samples))
  138. (define num-samples (length samples))
  139. (define num-not-samples (length not-samples))
  140. (define accept-prob (fl (/ num-samples num-all-samples)))
  141. (printf "search stats:~n")
  142. (get-search-stats)
  143. (newline)
  144. #|
  145. (printf "cache stats:~n")
  146. (get-cache-stats)
  147. (newline)
  148. |#
  149. (printf "unique numbers of primitive rvs: ~v~n"
  150. (sort
  151. (remove-duplicates
  152. (map (λ: ([d : domain-sample])
  153. (length (omega-set->list (car (domain-sample-S d)))))
  154. all-samples))
  155. <))
  156. (newline)
  157. (printf "accepted samples: ~v (~v%)~n" (length samples) (* 100.0 accept-prob))
  158. (newline)
  159. (define all-alpha (min 1.0 (/ 250.0 (fl num-all-samples))))
  160. (define alpha (min 1.0 (/ 250.0 (fl num-samples))))
  161. (plot-z-ticks no-ticks)
  162. (plot3d (list (rectangles3d (append*
  163. (map (λ: ([d : domain-sample])
  164. (omega-rect->plot-rects (car (domain-sample-S d))))
  165. not-samples))
  166. #:alpha all-alpha #:color 1 #:line-color 1)
  167. (rectangles3d (append*
  168. (map (λ: ([d : domain-sample])
  169. (omega-rect->plot-rects (car (domain-sample-S d))))
  170. samples))
  171. #:alpha all-alpha #:color 3 #:line-color 3))
  172. #:x-min 0 #:x-max 1 #:y-min 0 #:y-max 1 #:z-min 0 #:z-max 1)
  173. (: domain-sample->omega-point (domain-sample -> (Listof Flonum)))
  174. (define (domain-sample->omega-point d)
  175. (omega->point (car (domain-sample-s d))))
  176. (plot3d (list (points3d (map domain-sample->omega-point not-samples)
  177. #:sym 'dot #:size 12 #:alpha all-alpha #:color 1 #:fill-color 1)
  178. (points3d (map domain-sample->omega-point samples)
  179. #:sym 'dot #:size 12 #:alpha all-alpha #:color 3 #:fill-color 3))
  180. #:x-min 0 #:x-max 1 #:y-min 0 #:y-max 1 #:z-min 0 #:z-max 1
  181. #:x-label "x1" #:y-label "x2" #:z-label "x3")
  182. (plot3d (points3d (sample (discrete-dist (map domain-sample->omega-point samples) ws)
  183. num-samples)
  184. #:sym 'dot #:size 12 #:alpha alpha)
  185. #:x-min 0 #:x-max 1 #:y-min 0 #:y-max 1 #:z-min 0 #:z-max 1
  186. #:x-label "x1" #:y-label "x2" #:z-label "x3")
  187. (: xss (Listof (Listof Flonum)))
  188. (define xss
  189. (map (λ: ([d : domain-sample])
  190. (define lst (value->listof-flonum (cast (domain-sample-b d) Value)))
  191. (maybe-pad-list lst 3 random))
  192. samples))
  193. (with-handlers ([exn? (λ (_) (printf "image points scatter plot failed~n"))])
  194. (plot3d (points3d xss #:sym 'dot #:size 12 #:alpha alpha)
  195. #:x-label "x1" #:y-label "x2" #:z-label "x3"))
  196. (with-handlers ([exn? (λ (_) (printf "resampled image points scatter plot failed~n"))])
  197. (plot3d (points3d (sample (discrete-dist xss ws) num-samples)
  198. #:sym 'dot #:size 12 #:alpha alpha)
  199. #:x-label "x1" #:y-label "x2" #:z-label "x3"))
  200. (define x0s (map (inst first Flonum Flonum) xss))
  201. (with-handlers ([exn? (λ (_) (printf "weight density plot failed~n"))])
  202. (plot (density ws) #:x-label "weight" #:y-label "density"))
  203. (with-handlers ([exn? (λ (_) (printf "weight/measure scatter plot failed~n"))])
  204. (plot (points (map (λ: ([w : Flonum] [m : Flonum]) (list w m)) ws ms)
  205. #:sym 'dot #:size 12 #:alpha alpha)
  206. #:x-label "weight" #:y-label "measure"))
  207. (printf "Corr(W,M) = ~v~n" (correlation ws ms))
  208. (with-handlers ([exn? (λ (_) (printf "density plot failed~n"))])
  209. (plot (density (sample (discrete-dist x0s ws) num-samples) 2)
  210. #:x-label "x0" #:y-label "density"))
  211. (printf "E[x0] = ~v~n" (mean x0s (ann ws (Sequenceof Real))))
  212. (printf "sd[x0] = ~v~n" (stddev x0s (ann ws (Sequenceof Real))))