diff --git a/NEWS.md b/NEWS.md index ae6353e2..bf99a9fa 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,6 @@ # bayesplot (development version) +* Validate empty list and zero-row matrix inputs in `nuts_params.list()`. * Validate user-provided `pit` values in `ppc_loo_pit_data()` and `ppc_loo_pit_qq()`, rejecting non-numeric inputs, missing values, and values outside `[0, 1]`. * New `show_marginal` argument to `ppd_*()` functions to show the PPD - the marginal predictive distribution by @mattansb (#425) * `ppc_ecdf_overlay()`, `ppc_ecdf_overlay_grouped()`, and `ppd_ecdf_overlay()` now always use `geom_step()`. The `discrete` argument is deprecated. diff --git a/R/bayesplot-extractors.R b/R/bayesplot-extractors.R index 79f7966e..fcaead66 100644 --- a/R/bayesplot-extractors.R +++ b/R/bayesplot-extractors.R @@ -145,10 +145,18 @@ nuts_params.stanreg <- #' @export #' @method nuts_params list nuts_params.list <- function(object, pars = NULL, ...) { + if (length(object) == 0) { + abort("'object' must be a non-empty list.") + } + if (!all(sapply(object, is.matrix))) { abort("All list elements should be matrices.") } + if (any(vapply(object, nrow, integer(1)) == 0)) { + abort("All matrices in the list must have at least one row.") + } + dd <- lapply(object, dim) if (length(unique(dd)) != 1) { abort("All matrices in the list must have the same dimensions.") diff --git a/tests/testthat/test-extractors.R b/tests/testthat/test-extractors.R index 355e1418..1116be67 100644 --- a/tests/testthat/test-extractors.R +++ b/tests/testthat/test-extractors.R @@ -9,6 +9,8 @@ x <- list(cbind(a = 1:3, b = rnorm(3)), cbind(a = 1:3, b = rnorm(3))) # nuts_params and log_posterior methods ----------------------------------- test_that("nuts_params.list throws errors", { + expect_error(nuts_params.list(list()), "non-empty list") + x[[3]] <- c(a = 1:3, b = rnorm(3)) expect_error(nuts_params.list(x), "list elements should be matrices") @@ -17,6 +19,20 @@ test_that("nuts_params.list throws errors", { x[[3]] <- cbind(a = 1:4, b = rnorm(4)) expect_error(nuts_params.list(x), "same dimensions") + + zero_row <- list(cbind(a = numeric(0), b = numeric(0))) + expect_error(nuts_params.list(zero_row), "at least one row") + + zero_row_nonfirst <- list(cbind(a = 1:3, b = rnorm(3)), cbind(a = numeric(0), b = numeric(0))) + expect_error(nuts_params.list(zero_row_nonfirst), "at least one row") +}) + +test_that("nuts_params.list works with single-chain list", { + single <- list(cbind(a = 1:3, b = rnorm(3))) + np <- nuts_params.list(single) + expect_identical(colnames(np), c("Chain", "Iteration", "Parameter", "Value")) + expect_true(all(np$Chain == 1L)) + expect_equal(nrow(np), 6L) }) test_that("nuts_params.list parameter selection ok", {