/tests/testthat/test-embeddings.R

https://github.com/tidymodels/textrecipes · R · 352 lines · 300 code · 42 blank · 10 comment · 2 complexity · bbb0b39707258cd7be05df8db6306aa3 MD5 · raw file

  1. library(recipes)
  # Tolerance handed to expect_equal() further down: the usual
  # sqrt(double.eps) when this R build reports long-double support, and a much
  # looser 0.1 when it does not.
  eps <- if (capabilities("long.double")) sqrt(.Machine$double.eps) else 0.1
  # Shared fixtures for the tests in this file. ------------------------------
  # Four short sentences whose vocabularies overlap heavily.
  test_data <- tibble(
    text = c(
      "I would not eat them here or there.",
      "I would not eat them anywhere.",
      "I would not eat green eggs and ham.",
      "I do not like them, Sam-I-am."
    )
  )

  # Bare recipe over every column of test_data; steps are added on top of it
  # further down.
  rec_base <- recipe(~., data = test_data)
  # Hand-built data for the expected results: one row per sentence, pairing
  # the raw sentence (`text`) with its tokenized form (`tokens`).
  tokens <- rec_base %>%
    step_tokenize(text) %>%
    recipes::prep() %>%
    recipes::bake(new_data = NULL) %>%
    # The baked token column is still named `text`, so the raw sentences are
    # bound on under a temporary name and both columns are renamed while
    # selecting.
    vctrs::vec_cbind(dplyr::rename(test_data, text1 = text)) %>%
    dplyr::select(text = text1, tokens = text)
  # Fake embedding table: one row per distinct token (sorted), with an
  # arbitrary value in columns d1..d5. Real embeddings are doubles, so these
  # are doubles as well.
  embeddings <- tokens %>%
    dplyr::mutate(tokens = vctrs::field(tokens, "tokens")) %>%
    tidyr::unnest(tokens) %>%
    dplyr::distinct(tokens) %>%
    dplyr::arrange(tokens) %>%
    # There are 17 unique tokens. Each token's position in the sorted list is
    # written out as its five lowest bits, so every token gets a distinct 5-d
    # vector.
    dplyr::mutate(
      token_num_binary = lapply(
        seq_along(tokens),
        function(token_index) {
          low_bits <- intToBits(token_index)[seq_len(5)]
          tibble(
            dimension = paste0("d", seq_len(5)),
            score = as.double(low_bits)
          )
        }
      )
    ) %>%
    tidyr::unnest(token_num_binary) %>%
    # Spread the (dimension, score) pairs back out to one column per dimension.
    tidyr::pivot_wider(names_from = dimension, values_from = score)
  # Long form of the expected data: one row per token occurrence in each
  # sentence, with that token's d1..d5 scores joined on.
  sentence_embeddings_long <- dplyr::left_join(
    tokens %>%
      dplyr::mutate(tokens = vctrs::field(tokens, "tokens")) %>%
      tidyr::unnest(tokens),
    embeddings,
    by = "tokens"
  )
  # Expected output for aggregation = "sum": per-sentence column sums, with
  # the columns renamed to w_embed_sum_<dimension>. Joining onto test_data
  # restores the original sentence order, which group_by() does not keep.
  sentence_embeddings_sum <- dplyr::left_join(
    test_data,
    sentence_embeddings_long %>%
      dplyr::select(-tokens) %>%
      dplyr::group_by(text) %>%
      dplyr::summarize_all(sum) %>%
      dplyr::rename_if(
        is.numeric,
        function(dimension_name) {
          paste("w_embed", "sum", dimension_name, sep = "_")
        }
      ),
    by = "text"
  )
  # Expected output for aggregation = "mean": per-sentence column means, with
  # the columns renamed to w_embed_mean_<dimension>, put back into the
  # sentence order of test_data.
  sentence_embeddings_mean <- dplyr::left_join(
    test_data,
    sentence_embeddings_long %>%
      dplyr::select(-tokens) %>%
      dplyr::group_by(text) %>%
      dplyr::summarize_all(mean) %>%
      dplyr::rename_if(
        is.numeric,
        function(dimension_name) {
          paste("w_embed", "mean", dimension_name, sep = "_")
        }
      ),
    by = "text"
  )
  # Expected output for aggregation = "min": per-sentence column minima, with
  # the columns renamed to w_embed_min_<dimension>, put back into the
  # sentence order of test_data.
  sentence_embeddings_min <- dplyr::left_join(
    test_data,
    sentence_embeddings_long %>%
      dplyr::select(-tokens) %>%
      dplyr::group_by(text) %>%
      dplyr::summarize_all(min) %>%
      dplyr::rename_if(
        is.numeric,
        function(dimension_name) {
          paste("w_embed", "min", dimension_name, sep = "_")
        }
      ),
    by = "text"
  )
  # Expected output for aggregation = "max": per-sentence column maxima, with
  # the columns renamed to w_embed_max_<dimension>, put back into the
  # sentence order of test_data.
  sentence_embeddings_max <- dplyr::left_join(
    test_data,
    sentence_embeddings_long %>%
      dplyr::select(-tokens) %>%
      dplyr::group_by(text) %>%
      dplyr::summarize_all(max) %>%
      dplyr::rename_if(
        is.numeric,
        function(dimension_name) {
          paste("w_embed", "max", dimension_name, sep = "_")
        }
      ),
    by = "text"
  )
  # Recipe under test: tokenize `text`, then embed it using the fake embedding
  # table and no explicit `aggregation` argument.
  rec <- step_word_embeddings(
    step_tokenize(rec_base, text),
    text,
    embeddings = embeddings
  )

  # Trained recipe, and the training data run through it.
  obj <- prep(rec)
  juiced <- bake(obj, new_data = NULL)
  test_that("step_word_embeddings adds the appropriate number of columns.", {
    # `embeddings` holds one token column plus one column per dimension, so
    # the step should create one output column fewer than `embeddings` has.
    dimension_count <- ncol(embeddings) - 1L
    embedded_cols <- select(juiced, contains("w_embed_"))

    expect_identical(ncol(embedded_cols), dimension_count)
  })
  test_that("step_word_embeddings gives numeric output.", {
    # The generated columns are named `w_embed_<aggregation>_<dimension>`.
    # Selecting on "embedding" matched none of them, so the numeric check ran
    # over zero columns and could never fail.
    embedded_cols <- select(juiced, contains("w_embed_"))

    # all() over an empty set is TRUE, so first make sure there is something
    # to check.
    expect_gt(ncol(embedded_cols), 0L)
    expect_true(all(vapply(embedded_cols, is.numeric, logical(1))))
  })
  114. # Run the tests. ----------------------------------------------------------
  test_that("step_word_embeddings tidy method works.", {
    tidy_cols <- c("terms", "embeddings_rows", "aggregation", "id")

    # Step 2 is the embedding step; tidy it on both the untrained recipe and
    # the prepped one.
    untrained <- tidy(rec, 2)
    trained <- tidy(obj, 2)

    # One row, four columns, in a fixed order.
    expect_equal(dim(untrained), c(1, 4))
    expect_equal(dim(trained), c(1, 4))
    expect_identical(colnames(untrained), tidy_cols)
    expect_identical(colnames(trained), tidy_cols)

    # 17 rows in the embedding table, and "sum" as the aggregation.
    expect_identical(untrained$embeddings_rows, 17L)
    expect_identical(untrained$aggregation, "sum")
  })
  test_that("step_word_embeddings aggregates vectors as expected.", {
    # Bake the training sentences using the requested aggregation function.
    bake_with <- function(aggregation) {
      rec_base %>%
        step_tokenize(text) %>%
        step_word_embeddings(
          text,
          embeddings = embeddings, aggregation = aggregation
        ) %>%
        prep() %>%
        bake(new_data = NULL)
    }

    # Hand-computed expectation as a plain data frame without the sentence
    # column.
    expected_df <- function(expected) {
      as.data.frame(select(expected, -text))
    }

    # By default, step_word_embeddings sums the vectors of the tokens it is
    # given.
    expect_equal(
      as.data.frame(juiced),
      expected_df(sentence_embeddings_sum),
      tolerance = eps
    )

    # Also allow the user to choose an aggregation function.
    expect_equal(
      as.data.frame(bake_with("max")),
      expected_df(sentence_embeddings_max),
      tolerance = eps
    )
    expect_equal(
      as.data.frame(bake_with("min")),
      expected_df(sentence_embeddings_min),
      tolerance = eps
    )
    expect_equal(
      as.data.frame(bake_with("mean")),
      expected_df(sentence_embeddings_mean),
      tolerance = eps
    )
  })
  test_that("step_word_embeddings deals with missing words appropriately.", {
    # Sentences mixing tokens from the training data with tokens that have no
    # row in `embeddings` ("red", "beans", "rice", "they're", "nice").
    partly_unknown <- tibble(
      text = c(
        "I would not eat red beans and rice.",
        "I do not like them, they're not nice."
      )
    )
    # `regexp = NA` asserts that no warning is raised. This expectation used
    # to be written twice verbatim; the second copy added nothing.
    expect_warning(
      bake(obj, new_data = partly_unknown),
      NA
    )

    # Data in which every token is known must not warn either.
    expect_warning(
      bake(obj, new_data = test_data),
      NA
    )

    # A document in which no token is known at all must still bake without
    # error.
    all_unknown <- tibble(
      text = "aksjdf nagjli aslkfa"
    )
    expect_error(
      bake(obj, new_data = all_unknown),
      NA
    )
  })
  test_that("printing", {
    # The printed recipe must mention the embedding step.
    expect_output(print(rec), "Word embeddings aggregated from text")

    # Verbose prep only has to print something; its content is not checked.
    expect_output(prep(rec, verbose = TRUE))
  })
  test_that("NA tokens work.", {
    with_na <- tibble(text = c("am", "and", NA))

    baked_all <- bake(obj, new_data = with_na)

    # The NA document is expected to come out as an all-zero row appended to
    # the result for the two real documents.
    baked_expected <- rbind(
      bake(obj, new_data = with_na[1:2, ]),
      rep(0, 5)
    )

    expect_identical(baked_all, baked_expected)
  })
  test_that("Embeddings work with empty documents", {
    empty_data <- data.frame(text = "")

    # Embed the single empty document with the given aggregation and flatten
    # the one-row result to a plain numeric vector.
    embed_empty <- function(aggregation) {
      recipe(~text, data = empty_data) %>%
        step_tokenize(text) %>%
        step_word_embeddings(
          text,
          embeddings = embeddings,
          aggregation = aggregation
        ) %>%
        prep() %>%
        bake(new_data = NULL) %>%
        as.numeric()
    }

    # With no tokens to aggregate, every dimension is expected to be 0 for
    # each aggregation function.
    all_zero <- rep(0, 5)
    expect_equal(embed_empty("sum"), all_zero)
    expect_equal(embed_empty("mean"), all_zero)
    expect_equal(embed_empty("min"), all_zero)
    expect_equal(embed_empty("max"), all_zero)
  })
  test_that("aggregation_default argument works", {
    empty_data <- data.frame(text = "")

    # Embed the single empty document with the given aggregation and
    # aggregation_default = 3, flattened to a plain numeric vector.
    embed_empty <- function(aggregation) {
      recipe(~text, data = empty_data) %>%
        step_tokenize(text) %>%
        step_word_embeddings(
          text,
          embeddings = embeddings,
          aggregation = aggregation,
          aggregation_default = 3
        ) %>%
        prep() %>%
        bake(new_data = NULL) %>%
        as.numeric()
    }

    # With no tokens to aggregate, every dimension is expected to take the
    # supplied default for each aggregation function.
    all_default <- rep(3, 5)
    expect_equal(embed_empty("sum"), all_default)
    expect_equal(embed_empty("mean"), all_default)
    expect_equal(embed_empty("min"), all_default)
    expect_equal(embed_empty("max"), all_default)
  })