Skip to content

Conversation

@njtierney
Copy link
Collaborator

Resolves #765

Merge branch 'add-snaper-hmc' into adaptive-hmc-v2-i765

# Conflicts:
#	R/inference_class.R
#	tests/testthat/test_posteriors_geweke.R
…rror in trace_list_batches[[1]] : subscript out of bounds` - need to investigate further
@njtierney
Copy link
Collaborator Author

library(greta)
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply
library(tictoc)
x <- normal(0, c(0.1, 1, 10, 100))
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#> 
m <- model(x)
tic()
draws_hmc <- mcmc(
  model = m,
  sampler = hmc()
  )
#> running 4 chains simultaneously on up to 8 CPU cores
#> 
#>     warmup                                           0/1000 | eta:  ?s              warmup ==                                       50/1000 | eta: 10s              warmup ====                                    100/1000 | eta:  5s              warmup ======                                  150/1000 | eta:  4s              warmup ========                                200/1000 | eta:  3s              warmup ==========                              250/1000 | eta:  2s              warmup ===========                             300/1000 | eta:  2s              warmup =============                           350/1000 | eta:  2s              warmup ===============                         400/1000 | eta:  2s              warmup =================                       450/1000 | eta:  1s              warmup ===================                     500/1000 | eta:  1s              warmup =====================                   550/1000 | eta:  1s              warmup =======================                 600/1000 | eta:  1s              warmup =========================               650/1000 | eta:  1s              warmup ===========================             700/1000 | eta:  1s              warmup ============================            750/1000 | eta:  1s              warmup ==============================          800/1000 | eta:  0s              warmup ================================        850/1000 | eta:  0s              warmup ==================================      900/1000 | eta:  0s              warmup ====================================    950/1000 | eta:  0s              warmup ====================================== 1000/1000 | eta:  0s          
#>   sampling                                           0/1000 | eta:  ?s            sampling ==                                       50/1000 | eta:  1s            sampling ====                                    100/1000 | eta:  1s            sampling ======                                  150/1000 | eta:  1s            sampling ========                                200/1000 | eta:  1s            sampling ==========                              250/1000 | eta:  0s            sampling ===========                             300/1000 | eta:  0s            sampling =============                           350/1000 | eta:  0s            sampling ===============                         400/1000 | eta:  0s            sampling =================                       450/1000 | eta:  0s            sampling ===================                     500/1000 | eta:  0s            sampling =====================                   550/1000 | eta:  0s            sampling =======================                 600/1000 | eta:  0s            sampling =========================               650/1000 | eta:  0s            sampling ===========================             700/1000 | eta:  0s            sampling ============================            750/1000 | eta:  0s            sampling ==============================          800/1000 | eta:  0s            sampling ================================        850/1000 | eta:  0s            sampling ==================================      900/1000 | eta:  0s            sampling ====================================    950/1000 | eta:  0s            sampling ====================================== 1000/1000 | eta:  0s
toc()
#> 2.877 sec elapsed

par(mfrow = c(2, 4))
plot(draws_hmc, auto.layout = FALSE)

tic()
draws_hmc_adapt <- mcmc(
  model = m,
  sampler = adaptive_hmc()
  )
#> running 4 chains simultaneously on up to 8 CPU cores
#>   sampling                                           0/1000 | eta:  ?s            sampling ==                                       50/1000 | eta: 41s            sampling ====                                    100/1000 | eta: 20s            sampling ======                                  150/1000 | eta: 13s            sampling ========                                200/1000 | eta:  9s            sampling ==========                              250/1000 | eta:  7s            sampling ===========                             300/1000 | eta:  5s            sampling =============                           350/1000 | eta:  4s            sampling ===============                         400/1000 | eta:  4s            sampling =================                       450/1000 | eta:  3s            sampling ===================                     500/1000 | eta:  2s            sampling =====================                   550/1000 | eta:  2s            sampling =======================                 600/1000 | eta:  2s            sampling =========================               650/1000 | eta:  1s            sampling ===========================             700/1000 | eta:  1s            sampling ============================            750/1000 | eta:  1s            sampling ==============================          800/1000 | eta:  1s            sampling ================================        850/1000 | eta:  0s            sampling ==================================      900/1000 | eta:  0s            sampling ====================================    950/1000 | eta:  0s            sampling ====================================== 1000/1000 | eta:  0s
toc()
#> 5.167 sec elapsed

plot(draws_hmc_adapt, auto.layout = FALSE)

Created on 2025-03-12 with reprex v2.1.1

Session info

sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.4.2 (2024-10-31)
#>  os       macOS Sequoia 15.1
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Australia/Hobart
#>  date     2025-03-12
#>  pandoc   3.2.1 @ /opt/homebrew/bin/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package     * version    date (UTC) lib source
#>  abind         1.4-8      2024-09-12 [1] CRAN (R 4.4.1)
#>  backports     1.5.0      2024-05-23 [1] CRAN (R 4.4.0)
#>  base64enc     0.1-3      2015-07-28 [1] CRAN (R 4.4.0)
#>  callr         3.7.6      2024-03-25 [1] CRAN (R 4.4.0)
#>  cli           3.6.4      2025-02-13 [1] CRAN (R 4.4.1)
#>  coda          0.19-4.1   2024-01-31 [1] CRAN (R 4.4.0)
#>  codetools     0.2-20     2024-03-31 [2] CRAN (R 4.4.2)
#>  crayon        1.5.3      2024-06-20 [1] CRAN (R 4.4.0)
#>  curl          6.2.0      2025-01-23 [1] CRAN (R 4.4.1)
#>  digest        0.6.37     2024-08-19 [1] CRAN (R 4.4.1)
#>  evaluate      1.0.1      2024-10-10 [1] CRAN (R 4.4.1)
#>  fastmap       1.2.0      2024-05-15 [1] CRAN (R 4.4.0)
#>  fs            1.6.5      2024-10-30 [1] CRAN (R 4.4.1)
#>  future        1.34.0     2024-07-29 [1] CRAN (R 4.4.0)
#>  globals       0.16.3     2024-03-08 [1] CRAN (R 4.4.0)
#>  glue          1.8.0      2024-09-30 [1] CRAN (R 4.4.1)
#>  greta       * 0.5.0.9000 2025-03-12 [1] local
#>  hms           1.1.3      2023-03-21 [1] CRAN (R 4.4.0)
#>  htmltools     0.5.8.1    2024-04-04 [1] CRAN (R 4.4.0)
#>  jsonlite      1.8.9      2024-09-20 [1] CRAN (R 4.4.1)
#>  knitr         1.49       2024-11-08 [1] CRAN (R 4.4.1)
#>  lattice       0.22-6     2024-03-20 [2] CRAN (R 4.4.2)
#>  lifecycle     1.0.4      2023-11-07 [1] CRAN (R 4.4.0)
#>  listenv       0.9.1      2024-01-29 [1] CRAN (R 4.4.0)
#>  magrittr      2.0.3      2022-03-30 [1] CRAN (R 4.4.0)
#>  Matrix        1.7-1      2024-10-18 [2] CRAN (R 4.4.2)
#>  parallelly    1.41.0     2024-12-18 [1] CRAN (R 4.4.1)
#>  pkgconfig     2.0.3      2019-09-22 [1] CRAN (R 4.4.0)
#>  png           0.1-8      2022-11-29 [1] CRAN (R 4.4.0)
#>  prettyunits   1.2.0      2023-09-24 [1] CRAN (R 4.4.0)
#>  processx      3.8.5      2025-01-08 [1] CRAN (R 4.4.1)
#>  progress      1.2.3      2023-12-06 [1] CRAN (R 4.4.0)
#>  ps            1.8.1      2024-10-28 [1] CRAN (R 4.4.1)
#>  R6            2.6.1      2025-02-15 [1] CRAN (R 4.4.1)
#>  Rcpp          1.0.14     2025-01-12 [1] CRAN (R 4.4.1)
#>  reprex        2.1.1      2024-07-06 [1] CRAN (R 4.4.0)
#>  reticulate    1.40.0     2024-11-15 [1] CRAN (R 4.4.1)
#>  rlang         1.1.5      2025-01-17 [1] CRAN (R 4.4.1)
#>  rmarkdown     2.29       2024-11-04 [1] CRAN (R 4.4.1)
#>  rstudioapi    0.17.1     2024-10-22 [1] CRAN (R 4.4.1)
#>  sessioninfo   1.2.2      2021-12-06 [1] CRAN (R 4.4.0)
#>  tensorflow    2.16.0     2024-04-15 [1] CRAN (R 4.4.0)
#>  tfautograph   0.3.2      2021-09-17 [1] CRAN (R 4.4.0)
#>  tfruns        1.5.3      2024-04-19 [1] CRAN (R 4.4.0)
#>  tictoc      * 1.2.1      2024-03-18 [1] CRAN (R 4.4.0)
#>  vctrs         0.6.5      2023-12-01 [1] CRAN (R 4.4.0)
#>  whisker       0.4.1      2022-12-05 [1] CRAN (R 4.4.0)
#>  withr         3.0.2      2024-10-28 [1] CRAN (R 4.4.1)
#>  xfun          0.50.5     2025-01-15 [1] Github (yihui/xfun@116d689)
#>  xml2          1.3.6      2023-12-04 [1] CRAN (R 4.4.0)
#>  yaml          2.3.10     2024-07-26 [1] CRAN (R 4.4.0)
#> 
#>  [1] /Users/nick/Library/R/arm64/4.4/library
#>  [2] /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library
#> 
#> ─ Python configuration ───────────────────────────────────────────────────────
#>  python:         /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/bin/python
#>  libpython:      /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/libpython3.10.dylib
#>  pythonhome:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2:/Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2
#>  version:        3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:51:49) [Clang 16.0.6 ]
#>  numpy:          /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.10/site-packages/numpy
#>  numpy_version:  1.26.4
#>  tensorflow:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.10/site-packages/tensorflow
#>  
#>  NOTE: Python version was forced by use_python() function
#> 
#> ──────────────────────────────────────────────────────────────────────────────

njtierney and others added 16 commits March 18, 2025 16:37
- remove unused n_adapt parameter in warm_up_sampler, instead refer to parameter set by user in adaptive_hmc()
- removed unrequired sampler_param_vec in define_tf_kernel()
- Pass chains argument through to `check_geweke` as adaptive_hmc() requires at least 2 chains
- use expect_no_error to test that adaptive_hmc() works with 0 warmup
- update snapshots
…h as the free state, otherwise use the last warmed up parameters to pass through to sampling.
…is is implicitly checked in other tests, e.g., extra_samples etc
…c), and then for rejection (for counting bad samples)
- now returns as `self$warm_results`
…le time.

- Remove extra `make_sampler_function()`
- Add tests for `extra_samples` to ensure it works with `adaptive_hmc()`
… not the current state from the warmed up sampler
- replace `current_state` with `free_state`
- warm_results doesn't need to return the kernel and current state - we actually just only need the kernel_results
- make_sampler_function() doesn't need `sampler` arg
- make_sampler_function() also doesn't need to have the logic inside of it to dispatch different results if it is already warmed up, we can move this inside of `sample_raw`
- create sampling_results object that gets the result of `batch_results$kernel_results`, for use in geweke test
…it from "tune_tf". This doesn't do anything at the moment, but helps set things up for when we want to explore using tuning from the TF side of things and have slower equivalents in R that poll/check more frequently.

Merge commit '6e1a347922086f993f6c75f7a672524c15555195'

#Conflicts:
#	R/samplers.R
@njtierney
Copy link
Collaborator Author

njtierney commented May 6, 2025

Current progress - 10K warmup, 500 iterations:

image

Next steps:

  1. Run separate runners for geweke for each sampler (hmc, rwmh, slice) to get a sense of how long these take to run, as well as for completeness

Then for adaptive HMC geweke test

  1. Run for more iterations
  2. Run for more chains
  3. Add capacity to return all chains in geweke checks
  4. Add capacity to do thinning after geweke checks

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add snaper HMC sampler

2 participants