/inst/examples/Chicago_corr_lm.R

https://github.com/tidymodels/tune · R · 92 lines · 64 code · 25 blank · 3 comment · 2 complexity · c2f4227f209dd999783d0c3c18acfe1a MD5 · raw file

  1. library(tidymodels)
  2. library(tune)
  3. library(ggforce)
  # ------------------------------------------------------------------------------
  # Resampling scheme: rolling-origin splits over `Chicago`. Each split uses
  # 364 * 15 rows for analysis and 7 * 4 rows for assessment, and
  # `cumulative = FALSE` keeps the analysis window from growing across splits.
  set.seed(7898)
  data_folds <- rolling_origin(
    Chicago,
    initial = 364 * 15,
    assess = 7 * 4,
    skip = 13,
    cumulative = FALSE
  )
  # ------------------------------------------------------------------------------
  # Columns 2 through 21 of `Chicago` are the ones handed to the correlation
  # filter below (presumably the per-station ridership predictors -- confirm
  # against the data set's documentation).
  stations <- names(Chicago)[2:21]

  # Preprocessing recipe, built in two stages.
  chi_rec <- recipe(ridership ~ ., data = Chicago)

  # Stage 1: derive holiday and date features, then drop the raw date column.
  chi_rec <- chi_rec %>%
    step_holiday(date) %>%
    step_date(date) %>%
    step_rm(date)

  # Stage 2: dummy-encode nominal columns, remove zero-variance predictors, and
  # filter correlated station columns. The correlation threshold is left as a
  # tuning placeholder via `tune()`.
  chi_rec <- chi_rec %>%
    step_dummy(all_nominal()) %>%
    step_zv(all_predictors()) %>%
    step_corr(one_of(!!stations), threshold = tune())
  # Linear regression specification fitted with the "lm" engine.
  lm_model <- set_engine(linear_reg(mode = "regression"), "lm")

  # Workflow bundling the recipe and the model specification.
  chi_wflow <- workflow()
  chi_wflow <- add_recipe(chi_wflow, chi_rec)
  chi_wflow <- add_model(chi_wflow, lm_model)
  # Regular grid with 10 levels for the workflow's tuning parameters, with the
  # `threshold` parameter's range set to [0.8, 0.99].
  chi_grid <- grid_regular(
    update(parameters(chi_wflow), threshold = threshold(c(0.8, 0.99))),
    levels = 10
  )
  # Extraction hook for `control_grid(extract = )`: returns the
  # `broom::glance()` summary of the underlying fitted model.
  #
  # `x` is whatever tune hands to the extract function. When that is a fitted
  # workflow, `x$model` is NULL and `glance()` would fail, so the engine-level
  # fit is pulled out with `extract_fit_engine()`. Any other input keeps the
  # original `x$model` lookup, so list-like inputs with a `model` element
  # behave as before.
  ext <- function(x) {
    fit <- if (inherits(x, "workflow")) {
      workflows::extract_fit_engine(x)
    } else {
      x$model
    }
    broom::glance(fit)
  }
  # Grid search over the correlation threshold using the bundled workflow.
  # `extract = ext` stores the glance() summary of every fitted model.
  res <- tune_grid(
    chi_wflow,
    resamples = data_folds,
    grid = chi_grid,
    control = control_grid(verbose = TRUE, extract = ext)
  )

  # The same search with the model and recipe supplied separately. tune_grid()
  # expects a model specification (or workflow) as its first argument and the
  # preprocessor second; passing the recipe first is not accepted.
  res_2 <- tune_grid(
    lm_model,
    preprocessor = chi_rec,
    resamples = data_folds,
    grid = chi_grid,
    control = control_grid(verbose = TRUE, extract = ext)
  )
  # Average the extracted lm fit statistics for each threshold value.
  # The list column `.extracts` is unnested twice because the extracted results
  # sit in a nested column of the same name.
  lm_stats <-
    res %>%
    select(id, .extracts) %>%
    unnest(cols = .extracts) %>%
    unnest(cols = .extracts) %>%
    select(id, threshold, adj.r.squared, sigma, AIC, BIC) %>%
    group_by(threshold) %>%
    summarize(
      across(
        c(adj.r.squared, sigma, AIC, BIC),
        ~ mean(.x, na.rm = TRUE)
      )
    )
  # Resampled performance estimates, reshaped to one row per threshold with one
  # column per metric. `collect_metrics()` is tune's accessor for the summarized
  # metrics of a tuning result.
  rs_stats <-
    collect_metrics(res) %>%
    select(threshold, .metric, mean) %>%
    pivot_wider(id_cols = threshold, names_from = .metric, values_from = mean)

  # Combine the model-based and resampling-based statistics. The join key is
  # stated explicitly instead of relying on the implicit natural join (and its
  # message); `threshold` is the only column the two tables share.
  all_stats <- full_join(lm_stats, rs_stats, by = "threshold")
  # Scatterplot matrix of every statistic in `all_stats` against every other
  # (all columns except `threshold`), with points colored by threshold.
  all_stats %>%
    ggplot(aes(x = .panel_x, y = .panel_y, colour = threshold)) +
    geom_point() +
    facet_matrix(vars(-threshold)) +
    theme_bw()
  # Profile of the resampled RMSE across threshold values.
  # `collect_metrics()` supplies the summarized metrics for the tuning result;
  # the helper columns are dropped so only `threshold` and `mean` remain.
  collect_metrics(res) %>%
    dplyr::filter(.metric == "rmse") %>%
    select(-n, -std_err, -.estimator, -.metric) %>%
    ggplot(aes(x = threshold, y = mean)) +
    geom_point() +
    geom_line()
  # Threshold with the smallest resampled RMSE: sort the summarized metrics
  # (from `collect_metrics()`) by the mean RMSE and keep the first row.
  collect_metrics(res) %>%
    dplyr::filter(.metric == "rmse") %>%
    arrange(mean) %>%
    slice(1)