/R/yardstick-metric-sets.R

https://github.com/business-science/modeltime · R · 287 lines · 93 code · 45 blank · 149 comment · 5 complexity · c294776f4d2c8f1dd602a417f35d9bd7 MD5 · raw file

  # DEFAULT FORECAST ACCURACY METRIC SET ----
#' Forecast Accuracy Metric Sets
#'
#'
#' This is a wrapper for [metric_set()] with several common forecast / regression
#' accuracy metrics included. These are the default time series accuracy
#' metrics used with [modeltime_accuracy()].
#'
#' @param ... Add additional `yardstick` metrics
#'
#' @return A `yardstick` metric set: a function that is called as
#'   `f(data, truth, estimate)` and returns the computed metrics
#'   (see Examples).
#'
#' @details
#'
#' # Default Forecast Accuracy Metric Set
#'
#' The primary purpose is to use the default accuracy metrics to calculate the following
#' forecast accuracy metrics using [modeltime_accuracy()]:
#'
#' - MAE - Mean absolute error, [mae()]
#' - MAPE - Mean absolute percentage error, [mape()]
#' - MASE - Mean absolute scaled error, [mase()]
#' - SMAPE - Symmetric mean absolute percentage error, [smape()]
#' - RMSE - Root mean squared error, [rmse()]
#' - RSQ - R-squared, [rsq()]
#'
#' Adding additional metrics is possible via `...`.
#'
#' # Extended Forecast Accuracy Metric Set
#'
#' Extends the default metric set by adding:
#'
#' - MAAPE - Mean Arctangent Absolute Percentage Error, [maape()].
#' MAAPE is designed for intermittent data where MAPE returns `Inf`.
#'
#'
#'
#' @seealso
#' - [yardstick::metric_tweak()] - For modifying `yardstick` metrics
#'
#' @examples
#' library(tibble)
#' library(dplyr)
#' library(timetk)
#' library(yardstick)
#'
#' fake_data <- tibble(
#' y = c(1:12, 2*1:12),
#' yhat = c(1 + 1:12, 2*1:12 - 1)
#' )
#'
#' # ---- HOW IT WORKS ----
#'
#' # Default Forecast Accuracy Metric Specification
#' default_forecast_accuracy_metric_set()
#'
#' # Create a metric summarizer function from the metric set
#' calc_default_metrics <- default_forecast_accuracy_metric_set()
#'
#' # Apply the metric summarizer to new data
#' calc_default_metrics(fake_data, y, yhat)
#'
#' # ---- ADD MORE PARAMETERS ----
#'
#' # Can create a version of mase() with seasonality = 12 (monthly)
#' mase12 <- metric_tweak(.name = "mase12", .fn = mase, m = 12)
#'
#' # Add it to the default metric set
#' my_metric_set <- default_forecast_accuracy_metric_set(mase12)
#' my_metric_set
#'
#' # Apply the newly created metric set
#' my_metric_set(fake_data, y, yhat)
#'
#'
#' @name metric_sets
#' @importFrom yardstick mae mape mase smape rmse rsq metric_tweak
#' @export
#' @rdname metric_sets
default_forecast_accuracy_metric_set <- function(...) {
  # The six built-in metrics come first; anything supplied through `...`
  # is passed to yardstick::metric_set() after them.
  yardstick::metric_set(
    yardstick::mae,
    yardstick::mape,
    yardstick::mase,
    yardstick::smape,
    yardstick::rmse,
    yardstick::rsq,
    ...
  )
}
  # EXTENDED FORECAST ACCURACY METRIC SET ----
# Same metrics as default_forecast_accuracy_metric_set(), plus `maape`
# (defined later in this file), inserted directly after `mape`. It shares the
# `metric_sets` help page through @rdname.
#' @importFrom yardstick mae mape mase smape rmse rsq metric_tweak
#' @export
#' @rdname metric_sets
extended_forecast_accuracy_metric_set <- function(...) {
  yardstick::metric_set(
    yardstick::mae,
    yardstick::mape,
    # `maape` is this package's own metric, hence no namespace prefix.
    maape,
    yardstick::mase,
    yardstick::smape,
    yardstick::rmse,
    yardstick::rsq,
    # User-supplied metrics are passed after the built-in ones.
    ...
  )
}
  # SUMMARIZE ACCURACY ----
#' Summarize Accuracy Metrics
#'
#' This is an internal function used by `modeltime_accuracy()`.
#' It applies a `yardstick` metric set to `data` and reshapes the long
#' `.metric` / `.estimate` output into one column per metric.
#'
#' @inheritParams modeltime_accuracy
#' @param data A `data.frame` containing the truth and estimate columns.
#'   If `data` is grouped, its grouping columns are kept in the result.
#' @param truth The column identifier for the true results (that is numeric).
#' @param estimate The column identifier for the predicted results (that is also numeric).
#'
#' @return An ungrouped tibble containing the grouping columns of `data`
#'   (if any) and one column per metric. Repeated metric names within a
#'   group are made unique with `make.unique(sep = "_")`.
#'
#' @examples
#' library(tibble)
#' library(dplyr)
#'
#' predictions_tbl <- tibble(
#' group = c("model 1", "model 1", "model 1",
#' "model 2", "model 2", "model 2"),
#' truth = c(1, 2, 3,
#' 1, 2, 3),
#' estimate = c(1.2, 2.0, 2.5,
#' 0.9, 1.9, 3.3)
#' )
#'
#' predictions_tbl %>%
#' group_by(group) %>%
#' summarize_accuracy_metrics(
#' truth, estimate,
#' metric_set = default_forecast_accuracy_metric_set()
#' )
#'
#' @export
summarize_accuracy_metrics <- function(data, truth, estimate, metric_set) {
  group_cols <- dplyr::group_vars(data)
  compute_metrics <- metric_set

  # Long format: one row per (group, metric); the estimator type is not needed.
  long_tbl <- compute_metrics(data, {{ truth }}, {{ estimate }})
  long_tbl <- dplyr::select(long_tbl, -dplyr::all_of(".estimator"))

  # Make metric names unique within each group so that pivot_wider()
  # produces distinct column names.
  long_tbl <- dplyr::group_by(long_tbl, !!!rlang::syms(group_cols))
  long_tbl <- dplyr::mutate(long_tbl, .metric = make.unique(.metric, sep = "_"))
  long_tbl <- dplyr::ungroup(long_tbl)

  tidyr::pivot_wider(long_tbl, names_from = .metric, values_from = .estimate)
}
  # UTILITIES ----
# Compute accuracy metrics on calibration (test) data.
#
# Arguments:
#   train_data - accepted for interface compatibility but not used; no
#                training metrics are computed, so that part of the result
#                is always empty.
#   test_data  - calibration data with `.actual` and `.prediction` columns,
#                or NULL to skip metric computation.
#   metric_set - a yardstick metric set passed to summarize_accuracy_metrics().
#   by_id      - if TRUE and `test_data` has exactly 5 columns, the 5th column
#                is treated as the id and metrics are computed per id.
#   ...        - currently unused.
#
# Returns an ungrouped tibble of metrics, or an empty tibble when
# `test_data` is NULL.
calc_accuracy_2 <- function(train_data = NULL, test_data = NULL, metric_set, by_id = FALSE, ...) {
  train_metrics_tbl <- tibble::tibble()
  test_metrics_tbl <- tibble::tibble()

  if (!is.null(test_data)) {
    if (by_id) {
      if (length(names(test_data)) == 5) {
        id_col_text <- names(test_data)[5]
        # `id_col_text` is a local string, not a function argument, so build the
        # symbol with rlang::sym() rather than rlang::ensym().
        test_data <- dplyr::group_by(test_data, !!rlang::sym(id_col_text))
      } else {
        rlang::warn("The 'id' column in calibration data was not detected. Global accuracy is being returned.")
      }
    }

    test_metrics_tbl <- dplyr::ungroup(
      summarize_accuracy_metrics(
        test_data,
        truth = .actual,
        estimate = .prediction,
        metric_set = metric_set
      )
    )
  }

  dplyr::bind_rows(train_metrics_tbl, test_metrics_tbl)
}
  # MAAPE ----
#' Mean Arctangent Absolute Percentage Error
#'
#' Vector version of the MAAPE metric. This is a thin wrapper around
#' `TSrepr::maape()`, routed through `yardstick::metric_vec_template()` so that
#' it follows the usual `yardstick` `*_vec()` conventions.
#'
#' @param truth A numeric vector of true (observed) values.
#' @param estimate A numeric vector of predicted values, paired element-wise
#'   with `truth`.
#' @param na_rm A single logical, passed to `yardstick::metric_vec_template()`.
#'   When `TRUE` (the default), pairs in which either value is `NA` are removed
#'   before `TSrepr::maape()` is called. When `FALSE`, any `NA` makes the
#'   result `NA`.
#' @param ... Passed on to `yardstick::metric_vec_template()`.
#'
#' @return A single numeric value: the MAAPE of `estimate` against `truth`.
#'
#' @export
maape_vec <- function(truth, estimate, na_rm = TRUE, ...) {
  # Adapter: the template calls `metric_impl` with `truth` and `estimate`
  # arguments, while TSrepr::maape() takes the two vectors positionally
  # (actuals first).
  maape_impl <- function(truth, estimate) {
    TSrepr::maape(truth, estimate)
  }

  yardstick::metric_vec_template(
    metric_impl = maape_impl,
    truth = truth,
    estimate = estimate,
    na_rm = na_rm,
    cls = "numeric",
    ...
  )
}
  # MAAPE ----
#' Mean Arctangent Absolute Percentage Error
#'
#' Useful when MAPE returns `Inf`, typically because intermittent data contain zeros.
#' This is a wrapper to the function of `TSrepr::maape()`.
#'
#' @param data A `data.frame` containing the truth and estimate columns.
#' @param ... Not currently in use.
#'
#' @return For a data frame, a tibble with `.metric`, `.estimator`, and
#'   `.estimate` columns (see `maape.data.frame()`).
#'
#' @export
maape <- function(data, ...) {
  UseMethod("maape")
}
# Re-wrap the S3 generic as a yardstick numeric metric (lower is better,
# per `direction = "minimize"`) so it can be passed to yardstick::metric_set().
maape <- yardstick::new_numeric_metric(maape, direction = "minimize")
  # MAAPE ----
#' Mean Arctangent Absolute Percentage Error
#'
#' Data-frame method for `maape()`. This is a thin wrapper around
#' `TSrepr::maape()` (via `maape_vec()`), built with
#' `yardstick::metric_summarizer()`.
#'
#' @param data A `data.frame` containing the truth and estimate columns.
#' @param truth The column identifier for the true results (that is numeric).
#' @param estimate The column identifier for the predicted results (that is also numeric).
#' @param na_rm A single logical forwarded to `maape_vec()`. When `TRUE`
#'   (the default), rows where `truth` or `estimate` is `NA` are dropped
#'   before the metric is computed.
#' @param ... Passed on to `yardstick::metric_summarizer()`.
#'
#' @return A tibble with `.metric`, `.estimator`, and `.estimate` columns,
#'   where `.metric` is `"maape"`.
#'
#' @export
maape.data.frame <- function(data, truth, estimate, na_rm = TRUE, ...) {
  yardstick::metric_summarizer(
    metric_nm = "maape",
    metric_fn = maape_vec,
    data = data,
    # Namespace-qualified so this method does not depend on `enquo` being
    # imported into the package namespace.
    truth = !!rlang::enquo(truth),
    estimate = !!rlang::enquo(estimate),
    na_rm = na_rm,
    ...
  )
}