/tests/testthat/test-rvar-.R

https://github.com/jgabry/posterior · R · 237 lines · 171 code · 45 blank · 21 comment · 1 complexity · 15c36559470fc6c211e92f7f3c84e1b5 MD5 · raw file

  # Build an rvar from an array whose LAST index is the draws index (for
  # testing, so that when rvar's internal array structure changes the tests
  # don't have to). A plain vector (no dim attribute) is treated as a 1-d
  # array of draws instead of erroring.
  rvar_from_array <- function(x) {
    if (is.null(dim(x))) {
      x <- array(x)
    }
    n_dims <- length(dim(x))
    # move the last (draws) dimension to the front, keeping the others in
    # their original order, since new_rvar() is given draws-first arrays here
    new_rvar(aperm(x, c(n_dims, seq_len(n_dims - 1))))
  }
  # creating rvars ----------------------------------------------------------
  test_that("rvar creation with custom dim works", {
    # 2 draws (first dimension) of a 3 x 4 array ...
    draws_3d <- array(1:24, dim = c(2, 3, 4))
    # ... must equal 2 draws of 12 values reshaped to 3 x 4 via `dim`
    draws_flat <- array(1:24, dim = c(2, 12))
    reshaped <- rvar(draws_flat, dim = c(3, 4))
    expect_equal(reshaped, rvar(draws_3d))
  })
  test_that("rvar can be created with specified number of chains", {
    # 4 draws (first dimension) of 5 values
    draws <- array(1:20, dim = c(4, 5))
    # zero chains is invalid
    expect_error(rvar(draws, nchains = 0))
    # a single chain is the same as not specifying chains
    expect_equal(rvar(draws, nchains = 1), rvar(draws))
    two_chains <- rvar(draws, nchains = 2)
    expect_equal(nchains(two_chains), 2)
    # 3 chains cannot evenly split 4 draws
    expect_error(
      rvar(draws, nchains = 3),
      "Number of chains does not divide the number of draws"
    )
  })
  test_that("rvar constructor using with_chains works", {
    ab_names <- list(A = c("a1", "a2"), B = c("b1", "b2"))

    # multidimensional rvar: 6 draws in 2 chains == 3 iterations x 2 chains
    flat_draws <- array(1:24, dim = c(6, 2, 2), dimnames = c(list(NULL), ab_names))
    chained_draws <- array(
      1:24, dim = c(3, 2, 2, 2), dimnames = c(list(NULL, NULL), ab_names)
    )
    expect_equal(
      rvar(chained_draws, with_chains = TRUE),
      rvar(flat_draws, nchains = 2)
    )

    # scalar rvar: 24 draws in 4 chains == 6 iterations x 4 chains
    expect_equal(
      rvar(array(1:24, dim = c(6, 4)), with_chains = TRUE),
      rvar(1:24, nchains = 4)
    )

    # NULL rvar
    expect_equal(rvar(with_chains = TRUE), rvar())

    # can't use with_chains when no chain dimension information provided
    expect_error(rvar(1, with_chains = TRUE))
  })
  # draws_of ----------------------------------------------------------------
  test_that("draws_of using with_chains works", {
    # retrieving a multidimensional rvar with draws_of using with_chains:
    # 6 draws in 2 chains come back as 3 iterations x 2 chains x A x B
    x_array_nochains <- array(1:24, dim = c(6,2,2), dimnames = list(
      NULL, A = c("a1", "a2"), B = c("b1", "b2")
    ))
    x_array_chains <- array(1:24, dim = c(3,2,2,2), dimnames = list(
      NULL, NULL, A = c("a1", "a2"), B = c("b1", "b2")
    ))
    x <- rvar(x_array_nochains, nchains = 2)
    expect_equal(draws_of(x, with_chains = TRUE), x_array_chains)

    # setting a multidimensional rvar with draws_of using with_chains; the new
    # draws are laid out as 2 iterations x 3 chains, so nchains becomes 3
    x2_array_nochains <- x_array_nochains + 2
    x2_array_chains <- array(1:24 + 2, dim = c(2,3,2,2), dimnames = list(
      NULL, NULL, A = c("a1", "a2"), B = c("b1", "b2")
    ))
    x2 <- x
    draws_of(x2, with_chains = TRUE) <- x2_array_chains
    expect_equal(x2, rvar(x2_array_nochains, nchains = 3))

    # retrieving a scalar rvar with draws_of using with_chains
    s_array_nochains <- 1:24
    s_array_chains <- array(1:24, dim = c(6,4,1), dimnames = list(NULL))
    s <- rvar(s_array_nochains, nchains = 4)
    expect_equal(draws_of(s, with_chains = TRUE), s_array_chains)

    # setting a scalar rvar with draws_of using with_chains
    s2_array_nochains <- 1:24 + 2
    s2_array_chains <- array(1:24 + 2, dim = c(12,2), dimnames = list(NULL))
    s2 <- s
    draws_of(s2, with_chains = TRUE) <- s2_array_chains
    expect_equal(s2, rvar(s2_array_nochains, nchains = 2))

    # NULL rvar
    expect_equal(
      draws_of(rvar(), with_chains = TRUE),
      array(numeric(), dim = c(1,1,0), dimnames = list(NULL))
    )
    x_null <- x
    draws_of(x_null, with_chains = TRUE) <- numeric()
    expect_equal(x_null, rvar())

    # can't use with_chains when no chain dimension information provided
    expect_error(draws_of(x, with_chains = TRUE) <- 1)
  })
  # unique, duplicated, etc -------------------------------------------------
  test_that("unique.rvar and duplicated.rvar work", {
    # 3 elements (rows) x 3 draws (columns): draws are (1,1,3), (2,2,3),
    # (1,1,3), so the third element repeats the first
    vec_draws <- matrix(c(1,2,1, 1,2,1, 3,3,3), nrow = 3)
    x <- rvar_from_array(vec_draws)
    expect_equal(
      unique(x),
      rvar_from_array(matrix(c(1,2, 1,2, 3,3), nrow = 2))
    )
    expect_equal(as.vector(duplicated(x)), c(FALSE, FALSE, TRUE))
    expect_equal(anyDuplicated(x), 3)

    # 2 x 3 rvar with 2 draws: all rows are distinct, but column 3 repeats
    # column 1
    mat_x <- rvar(array(c(1,2, 2,3, 1,2, 3,3, 1,2, 2,3), dim = c(2, 2, 3)))
    expect_equal(unique(mat_x), mat_x)
    expect_equal(
      unique(mat_x, MARGIN = 2),
      rvar(array(c(1,2, 2,3, 1,2, 3,3), dim = c(2, 2, 2)))
    )
  })
  # tibbles -----------------------------------------------------------------
  test_that("rvars work in tibbles", {
    skip_if_not_installed("dplyr")
    skip_if_not_installed("tidyr")

    # 4 rvar elements (rows of x_array) with 5 draws each (last index)
    x_array <- array(1:20, dim = c(4,5))
    x <- rvar_from_array(x_array)

    # rvar columns survive tibble construction and dplyr::mutate
    df <- tibble::tibble(x, y = x + 1)
    expect_equal(df$x, x)
    expect_equal(df$y, rvar_from_array(x_array + 1))
    expect_equal(dplyr::mutate(df, z = x)$z, x)
    expect_equal(dplyr::mutate(df, z = x * 2)$z, rvar_from_array(x_array * 2))
    expect_equal(
      dplyr::mutate(dplyr::group_by(df, 1:4), z = x * 2)$z,
      rvar_from_array(x_array * 2)
    )

    # pivoting to one rvar column per group label, and back again
    df <- tibble::tibble(g = letters[1:4], x)
    ref <- tibble::tibble(
      a = rvar_from_array(x_array[1,, drop = FALSE]),
      b = rvar_from_array(x_array[2,, drop = FALSE]),
      c = rvar_from_array(x_array[3,, drop = FALSE]),
      d = rvar_from_array(x_array[4,, drop = FALSE])
    )
    expect_equal(tidyr::pivot_wider(df, names_from = g, values_from = x), ref)
    expect_equal(tidyr::pivot_longer(ref, a:d, names_to = "g", values_to = "x"), df)

    # with an extra id column, pivot_wider fills the missing cells with NA;
    # c() dispatches on its first argument, so columns b-d start with an
    # rvar(NA) to make the result an rvar
    df$y <- df$x + 1
    ref2 <- tibble::tibble(
      y = df$y,
      a = c(df$x[[1]], NA, NA, NA),
      b = c(rvar(NA), df$x[[2]], NA, NA),
      c = c(rvar(NA), NA, df$x[[3]], NA),
      d = c(rvar(NA), NA, NA, df$x[[4]])
    )
    expect_equal(tidyr::pivot_wider(df, names_from = g, values_from = x), ref2)
  })
  # broadcasting ------------------------------------------------------------
  test_that("broadcast_array works", {
    # a scalar fills the whole target shape
    target <- c(1,2,3,1)
    expect_equal(broadcast_array(5, target), array(rep(5, 6), dim = target))

    # a length-1 dimension is recycled; the expected result drops that
    # dimension's names and keeps the others
    row_arr <- array(1:4, c(1,4), dimnames = list("x", letters[1:4]))
    expect_equal(
      broadcast_array(row_arr, c(2,4)),
      array(rep(1:4, each = 2), c(2,4), dimnames = list(NULL, letters[1:4]))
    )
    col_arr <- array(1:4, c(4,1))
    expect_equal(broadcast_array(col_arr, c(4,2)), array(c(1:4, 1:4), c(4,2)))

    # trailing length-1 dimensions can be appended to a named 1-d array
    named_1d <- array(1:2, dimnames = list(c("a","b")))
    expect_equal(
      broadcast_array(named_1d, c(2,1,1,1)),
      array(1:2, c(2,1,1,1), dimnames = list(c("a","b"), NULL, NULL, NULL))
    )

    # incompatible target shapes are rejected
    square <- array(1:9, dim = c(3,3))
    expect_error(broadcast_array(square, c(1,9)))
    expect_error(broadcast_array(square, c(9)))
  })
  # conforming chains / draws -----------------------------------------------
  test_that("warnings for unequal draws/chains are correct", {
    # Enable the merge-chains warning only for this expectation. options()
    # returns the previous values; on.exit() inside local() restores them when
    # the local() block exits, even on error, so the setting never leaks into
    # other tests.
    local({
      old_opts <- options(posterior.warn_on_merge_chains = TRUE)
      on.exit(options(old_opts), add = TRUE)
      expect_warning(
        expect_equal(rvar(1:10) + rvar(1:10, nchains = 2), rvar(1:10 + 1:10)),
        "Chains were dropped due to chain information not matching"
      )
    })
    expect_error(
      draws_rvars(x = rvar(1:10), y = rvar(1:11)),
      "variables have different number of draws"
    )
    expect_error(
      rvar(1:10, nchains = 0),
      "chains must be >= 1"
    )
  })
  # rep ---------------------------------------------------------------------
  test_that("rep works", {
    # 5 draws (rows) of a length-2 rvar (columns)
    x_array <- array(1:10, dim = c(5,2))
    x <- rvar(x_array)

    # expected draws arrays, built column-wise from x_array
    times_3 <- cbind(x_array, x_array, x_array)
    each_twice <- cbind(x_array[,1], x_array[,1], x_array[,2], x_array[,2])
    first_3 <- cbind(x_array, x_array[,1])

    expect_equal(rep(x, times = 3), new_rvar(times_3))
    expect_equal(rep.int(x, 3), new_rvar(times_3))
    expect_equal(rep(x, each = 2), new_rvar(each_twice))
    expect_equal(
      rep(x, each = 2, times = 3),
      new_rvar(cbind(each_twice, each_twice, each_twice))
    )
    expect_equal(rep(x, length.out = 3), new_rvar(first_3))
    expect_equal(rep_len(x, 3), new_rvar(first_3))
  })
  # all.equal ---------------------------------------------------------------------
  test_that("all.equal works", {
    x_array <- array(1:10, dim = c(5,2))
    x <- rvar(x_array)
    expect_true(all.equal(x, x))
    # all.equal() conventionally reports a mismatch as a character description
    # rather than FALSE, hence the isTRUE() wrapper
    expect_false(isTRUE(all.equal(x, x + 1)))
    expect_false(isTRUE(all.equal(x, "a")))
  })
  # apply functions ---------------------------------------------------------
  test_that("apply family functions work", {
    # 2 draws of a 3 x 4 rvar
    x_array <- array(1:24, dim = c(2,3,4))
    x <- rvar(x_array)

    # sum of every draw of an rvar slice (named argument avoids shadowing `x`)
    sum_draws <- function(slice) sum(draws_of(slice))
    # expected per-row totals: array dim 1 holds draws, dim 2 is the rvar's
    # first dimension
    row_totals <- apply(draws_of(x), 2, sum)

    expect_equal(lapply(x, sum_draws), as.list(row_totals))
    expect_equal(sapply(x, sum_draws), row_totals)
    expect_equal(vapply(x, sum_draws, numeric(1)), row_totals)
    expect_equal(apply(x, 1, sum_draws), row_totals)
    expect_equal(apply(x, 1:2, sum_draws), apply(draws_of(x), 2:3, sum))
  })