pymc-extras 0.6.0__tar.gz → 0.8.0__tar.gz

This diff shows the changes between two publicly released versions of this package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (176)
  1. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/PKG-INFO +4 -4
  2. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/_version.py +2 -2
  3. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/conda-envs/environment-test.yml +2 -2
  4. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/docs/statespace/models/structural.rst +3 -3
  5. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/distributions/timeseries.py +10 -10
  6. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/inference/dadvi/dadvi.py +14 -83
  7. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/inference/laplace_approx/laplace.py +187 -159
  8. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/inference/pathfinder/pathfinder.py +12 -7
  9. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/inference/smc/sampling.py +2 -2
  10. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/model/marginal/distributions.py +4 -2
  11. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/model/marginal/marginal_model.py +12 -2
  12. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/prior.py +3 -3
  13. pymc_extras-0.8.0/pymc_extras/statespace/core/properties.py +276 -0
  14. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/statespace/core/statespace.py +182 -45
  15. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/statespace/filters/distributions.py +19 -34
  16. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/statespace/filters/kalman_filter.py +13 -12
  17. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/statespace/filters/kalman_smoother.py +2 -2
  18. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/statespace/models/DFM.py +179 -168
  19. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/statespace/models/ETS.py +177 -151
  20. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/statespace/models/SARIMAX.py +149 -152
  21. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/statespace/models/VARMAX.py +134 -145
  22. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/statespace/models/__init__.py +8 -1
  23. pymc_extras-0.8.0/pymc_extras/statespace/models/structural/__init__.py +43 -0
  24. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/statespace/models/structural/components/autoregressive.py +87 -45
  25. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/statespace/models/structural/components/cycle.py +119 -80
  26. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/statespace/models/structural/components/level_trend.py +95 -42
  27. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/statespace/models/structural/components/measurement_error.py +27 -17
  28. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/statespace/models/structural/components/regression.py +105 -68
  29. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/statespace/models/structural/components/seasonality.py +138 -100
  30. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/statespace/models/structural/core.py +397 -286
  31. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/statespace/models/utilities.py +5 -20
  32. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pyproject.toml +12 -5
  33. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/distributions/test_continuous.py +4 -0
  34. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/distributions/test_discrete.py +8 -5
  35. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/inference/dadvi/test_dadvi.py +1 -5
  36. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/inference/laplace_approx/test_laplace.py +25 -14
  37. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/model/marginal/test_marginal_model.py +18 -5
  38. pymc_extras-0.8.0/tests/statespace/core/test_properties.py +203 -0
  39. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/core/test_statespace.py +13 -30
  40. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/filters/test_distributions.py +4 -6
  41. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/models/structural/components/test_autoregressive.py +34 -31
  42. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/models/structural/components/test_cycle.py +57 -44
  43. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/models/structural/components/test_level_trend.py +36 -29
  44. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/models/structural/components/test_measurement_error.py +25 -12
  45. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/models/structural/components/test_regression.py +24 -25
  46. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/models/structural/components/test_seasonality.py +79 -58
  47. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/models/structural/conftest.py +1 -1
  48. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/models/structural/test_against_statsmodels.py +5 -5
  49. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/models/structural/test_core.py +10 -10
  50. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/models/test_DFM.py +110 -60
  51. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/models/test_ETS.py +45 -0
  52. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/models/test_SARIMAX.py +43 -3
  53. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/models/test_VARMAX.py +56 -11
  54. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/test_utilities.py +3 -2
  55. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/utils/test_coord_assignment.py +4 -4
  56. pymc_extras-0.6.0/pymc_extras/statespace/models/structural/__init__.py +0 -21
  57. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/.gitignore +0 -0
  58. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/.gitpod.yml +0 -0
  59. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/.pre-commit-config.yaml +0 -0
  60. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/.readthedocs.yaml +0 -0
  61. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/CODE_OF_CONDUCT.md +0 -0
  62. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/CONTRIBUTING.md +0 -0
  63. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/LICENSE +0 -0
  64. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/README.md +0 -0
  65. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/codecov.yml +0 -0
  66. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/docs/.nojekyll +0 -0
  67. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/docs/Makefile +0 -0
  68. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/docs/_templates/autosummary/base.rst +0 -0
  69. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/docs/_templates/autosummary/class.rst +0 -0
  70. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/docs/api_reference.rst +0 -0
  71. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/docs/conf.py +0 -0
  72. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/docs/index.rst +0 -0
  73. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/docs/make.bat +0 -0
  74. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/docs/statespace/core.rst +0 -0
  75. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/docs/statespace/filters.rst +0 -0
  76. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/docs/statespace/models.rst +0 -0
  77. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/__init__.py +0 -0
  78. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/deserialize.py +0 -0
  79. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/distributions/__init__.py +0 -0
  80. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/distributions/continuous.py +0 -0
  81. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/distributions/discrete.py +0 -0
  82. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/distributions/histogram_utils.py +0 -0
  83. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/distributions/multivariate/__init__.py +0 -0
  84. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/distributions/multivariate/r2d2m2cp.py +0 -0
  85. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/distributions/transforms/__init__.py +0 -0
  86. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/distributions/transforms/partial_order.py +0 -0
  87. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/gp/__init__.py +0 -0
  88. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/gp/latent_approx.py +0 -0
  89. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/inference/__init__.py +0 -0
  90. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/inference/dadvi/__init__.py +0 -0
  91. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/inference/fit.py +0 -0
  92. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/inference/laplace_approx/__init__.py +0 -0
  93. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/inference/laplace_approx/find_map.py +0 -0
  94. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/inference/laplace_approx/idata.py +0 -0
  95. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/inference/laplace_approx/scipy_interface.py +0 -0
  96. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/inference/pathfinder/__init__.py +0 -0
  97. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/inference/pathfinder/idata.py +0 -0
  98. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/inference/pathfinder/importance_sampling.py +0 -0
  99. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/inference/pathfinder/lbfgs.py +0 -0
  100. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/inference/smc/__init__.py +0 -0
  101. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/linearmodel.py +0 -0
  102. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/model/__init__.py +0 -0
  103. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/model/marginal/__init__.py +0 -0
  104. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/model/marginal/graph_analysis.py +0 -0
  105. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/model/model_api.py +0 -0
  106. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/model/transforms/__init__.py +0 -0
  107. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/model/transforms/autoreparam.py +0 -0
  108. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/model_builder.py +0 -0
  109. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/preprocessing/__init__.py +0 -0
  110. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/preprocessing/standard_scaler.py +0 -0
  111. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/printing.py +0 -0
  112. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/statespace/__init__.py +0 -0
  113. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/statespace/core/__init__.py +0 -0
  114. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/statespace/core/compile.py +0 -0
  115. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/statespace/core/representation.py +0 -0
  116. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/statespace/filters/__init__.py +0 -0
  117. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/statespace/filters/utilities.py +0 -0
  118. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/statespace/models/structural/components/__init__.py +0 -0
  119. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/statespace/models/structural/utils.py +0 -0
  120. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/statespace/utils/__init__.py +0 -0
  121. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/statespace/utils/constants.py +0 -0
  122. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/statespace/utils/coord_tools.py +0 -0
  123. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/statespace/utils/data_tools.py +0 -0
  124. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/utils/__init__.py +0 -0
  125. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/utils/linear_cg.py +0 -0
  126. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/utils/model_equivalence.py +0 -0
  127. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/utils/prior.py +0 -0
  128. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/utils/spline.py +0 -0
  129. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/__init__.py +0 -0
  130. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/conftest.py +0 -0
  131. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/distributions/__init__.py +0 -0
  132. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/distributions/test_discrete_markov_chain.py +0 -0
  133. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/distributions/test_multivariate.py +0 -0
  134. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/distributions/test_transform.py +0 -0
  135. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/inference/__init__.py +0 -0
  136. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/inference/dadvi/__init__.py +0 -0
  137. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/inference/laplace_approx/__init__.py +0 -0
  138. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/inference/laplace_approx/test_find_map.py +0 -0
  139. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/inference/laplace_approx/test_idata.py +0 -0
  140. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/inference/laplace_approx/test_scipy_interface.py +0 -0
  141. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/model/__init__.py +0 -0
  142. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/model/marginal/__init__.py +0 -0
  143. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/model/marginal/test_distributions.py +0 -0
  144. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/model/marginal/test_graph_analysis.py +0 -0
  145. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/model/test_model_api.py +0 -0
  146. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/model/transforms/test_autoreparam.py +0 -0
  147. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/pathfinder/__init__.py +0 -0
  148. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/pathfinder/test_idata.py +0 -0
  149. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/pathfinder/test_pathfinder.py +0 -0
  150. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/__init__.py +0 -0
  151. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/_data/airpass.csv +0 -0
  152. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/_data/airpassangers.csv +0 -0
  153. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/_data/nile.csv +0 -0
  154. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/_data/statsmodels_macrodata_processed.csv +0 -0
  155. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/core/__init__.py +0 -0
  156. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/core/test_representation.py +0 -0
  157. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/core/test_statespace_JAX.py +0 -0
  158. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/filters/__init__.py +0 -0
  159. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/filters/test_kalman_filter.py +0 -0
  160. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/models/__init__.py +0 -0
  161. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/models/structural/__init__.py +0 -0
  162. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/models/structural/components/__init__.py +0 -0
  163. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/models/test_utilities.py +0 -0
  164. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/shared_fixtures.py +0 -0
  165. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/statsmodel_local_level.py +0 -0
  166. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/statespace/utils/__init__.py +0 -0
  167. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/test_blackjax_smc.py +0 -0
  168. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/test_deserialize.py +0 -0
  169. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/test_histogram_approximation.py +0 -0
  170. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/test_linearmodel.py +0 -0
  171. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/test_model_builder.py +0 -0
  172. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/test_printing.py +0 -0
  173. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/test_prior.py +0 -0
  174. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/test_prior_from_trace.py +0 -0
  175. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/test_splines.py +0 -0
  176. {pymc_extras-0.6.0 → pymc_extras-0.8.0}/tests/utils.py +0 -0
{pymc_extras-0.6.0 → pymc_extras-0.8.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: pymc-extras
-Version: 0.6.0
+Version: 0.8.0
 Summary: A home for new additions to PyMC, which may include unusual probability distribitions, advanced model fitting algorithms, or any code that may be inappropriate to include in the pymc repository, but may want to be made available to users.
 Project-URL: Documentation, https://pymc-extras.readthedocs.io/
 Project-URL: Repository, https://github.com/pymc-devs/pymc-extras.git
@@ -235,8 +235,8 @@ Requires-Python: >=3.11
 Requires-Dist: better-optimize>=0.1.5
 Requires-Dist: preliz>=0.20.0
 Requires-Dist: pydantic>=2.0.0
-Requires-Dist: pymc>=5.26.1
-Requires-Dist: pytensor>=2.35.1
+Requires-Dist: pymc>=5.27.1
+Requires-Dist: pytensor>=2.37.0
 Requires-Dist: scikit-learn
 Provides-Extra: complete
 Requires-Dist: dask[complete]<2025.1.1; extra == 'complete'
@@ -245,7 +245,7 @@ Provides-Extra: dask-histogram
 Requires-Dist: dask[complete]<2025.1.1; extra == 'dask-histogram'
 Requires-Dist: xhistogram; extra == 'dask-histogram'
 Provides-Extra: dev
-Requires-Dist: blackjax; extra == 'dev'
+Requires-Dist: blackjax>=0.12; extra == 'dev'
 Requires-Dist: dask[all]<2025.1.1; extra == 'dev'
 Requires-Dist: pytest-mock; extra == 'dev'
 Requires-Dist: pytest>=6.0; extra == 'dev'
{pymc_extras-0.6.0 → pymc_extras-0.8.0}/_version.py
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID

-__version__ = version = '0.6.0'
-__version_tuple__ = version_tuple = (0, 6, 0)
+__version__ = version = '0.8.0'
+__version_tuple__ = version_tuple = (0, 8, 0)

 __commit_id__ = commit_id = None
{pymc_extras-0.6.0 → pymc_extras-0.8.0}/conda-envs/environment-test.yml
@@ -3,8 +3,8 @@ channels:
 - conda-forge
 - nodefaults
 dependencies:
-- pymc>=5.26.1
-- pytensor>=2.35.1
+- pymc>=5.27.1
+- pytensor>=2.37.0
 - scikit-learn
 - better-optimize>=0.1.5
 - dask<2025.1.1
{pymc_extras-0.6.0 → pymc_extras-0.8.0}/docs/statespace/models/structural.rst
@@ -6,9 +6,9 @@ Structural Components
 .. autosummary::
    :toctree: generated

-   LevelTrendComponent
-   AutoregressiveComponent
+   LevelTrend
+   Autoregressive
    TimeSeasonality
    FrequencySeasonality
    MeasurementError
-   CycleComponent
+   Cycle
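Note: 0.8.0 renames the structural components (LevelTrendComponent → LevelTrend, AutoregressiveComponent → Autoregressive, CycleComponent → Cycle), matching the rewritten structural/__init__.py in the files list above. A minimal sketch of building a model under the new names; the constructor keywords are assumed to carry over from the 0.6.0 *Component classes and are not verified against the 0.8.0 API:

    from pymc_extras.statespace import structural as st

    # Hypothetical usage; keyword names assumed unchanged from 0.6.0
    trend = st.LevelTrend(order=2, innovations_order=[0, 1])  # was LevelTrendComponent
    cycle = st.Cycle(cycle_length=12, innovations=True)       # was CycleComponent
    error = st.MeasurementError()

    ss_mod = (trend + cycle + error).build()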
{pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/distributions/timeseries.py
@@ -196,21 +196,20 @@ class DiscreteMarkovChain(Distribution):
         state_rng = pytensor.shared(np.random.default_rng())

         def transition(*args):
-            *states, transition_probs, old_rng = args
+            old_rng, *states, transition_probs = args
             p = transition_probs[tuple(states)]
             next_rng, next_state = pm.Categorical.dist(p=p, rng=old_rng).owner.outputs
-            return next_state, {old_rng: next_rng}
+            return next_rng, next_state

-        markov_chain, state_updates = pytensor.scan(
+        state_next_rng, markov_chain = pytensor.scan(
             transition,
-            non_sequences=[P_, state_rng],
-            outputs_info=_make_outputs_info(n_lags, init_dist_),
+            outputs_info=[state_rng, *_make_outputs_info(n_lags, init_dist_)],
+            non_sequences=[P_],
             n_steps=steps_,
             strict=True,
+            return_updates=False,
         )

-        (state_next_rng,) = tuple(state_updates.values())
-
         discrete_mc_ = pt.moveaxis(pt.concatenate([init_dist_, markov_chain], axis=0), 0, -1)

         discrete_mc_op = DiscreteMarkovChainRV(
@@ -239,16 +238,17 @@ def discrete_mc_moment(op, rv, P, steps, init_dist, state_rng):
     n_lags = op.n_lags

     def greedy_transition(*args):
-        *states, transition_probs, old_rng = args
+        *states, transition_probs = args
         p = transition_probs[tuple(states)]
         return pt.argmax(p)

-    chain_moment, moment_updates = pytensor.scan(
+    chain_moment = pytensor.scan(
         greedy_transition,
-        non_sequences=[P, state_rng],
+        non_sequences=[P],
         outputs_info=_make_outputs_info(n_lags, init_dist),
         n_steps=steps,
         strict=True,
+        return_updates=False,
     )
     chain_moment = pt.concatenate([init_dist_moment, chain_moment])
     return chain_moment
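Note: both hunks above move to the newer pytensor.scan convention (presumably what the pytensor>=2.37.0 pin is for): the RNG is threaded through the recursion as an explicit recurring state rather than returned in an updates dictionary, and return_updates=False opts out of the legacy (outputs, updates) return value. A standalone sketch of the same pattern on a Gaussian random walk, mirroring the unpacking used above; assumes pytensor>=2.37.0:

    import numpy as np
    import pytensor
    import pytensor.tensor as pt

    rng = pytensor.shared(np.random.default_rng(0))

    def step(old_rng, prev):
        # RandomVariable Ops output (next_rng, draw); return the rng as a recurring state
        next_rng, draw = pt.random.normal(prev, 1.0, rng=old_rng).owner.outputs
        return next_rng, draw

    final_rng, walk = pytensor.scan(
        step,
        outputs_info=[rng, pt.zeros(())],
        n_steps=10,
        strict=True,
        return_updates=False,
    )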
{pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/inference/dadvi/dadvi.py
@@ -3,25 +3,20 @@ import numpy as np
 import pymc
 import pytensor
 import pytensor.tensor as pt
-import xarray

+from arviz import InferenceData
 from better_optimize import basinhopping, minimize
 from better_optimize.constants import minimize_method
 from pymc import DictToArrayBijection, Model, join_nonshared_inputs
-from pymc.backends.arviz import (
-    PointFunc,
-    apply_function_over_dataset,
-    coords_and_dims_for_inferencedata,
-)
 from pymc.blocking import RaveledVars
-from pymc.util import RandomSeed, get_default_varnames
+from pymc.util import RandomSeed
 from pytensor.tensor.variable import TensorVariable

 from pymc_extras.inference.laplace_approx.idata import (
     add_data_to_inference_data,
     add_optimizer_result_to_inference_data,
 )
-from pymc_extras.inference.laplace_approx.laplace import unstack_laplace_draws
+from pymc_extras.inference.laplace_approx.laplace import draws_from_laplace_approx
 from pymc_extras.inference.laplace_approx.scipy_interface import (
     scipy_optimize_funcs_from_loss,
     set_optimizer_function_defaults,
@@ -193,16 +188,18 @@ def fit_dadvi(
     opt_var_params = result.x
     opt_means, opt_log_sds = np.split(opt_var_params, 2)

-    # Make the draws:
-    generator = np.random.default_rng(seed=random_seed)
-    draws_raw = generator.standard_normal(size=(n_draws, n_params))
-
-    draws = opt_means + draws_raw * np.exp(opt_log_sds)
-    draws_arviz = unstack_laplace_draws(draws, model, chains=1, draws=n_draws)
-
-    idata = dadvi_result_to_idata(
-        draws_arviz, model, include_transformed=include_transformed, progressbar=progressbar
+    posterior, unconstrained_posterior = draws_from_laplace_approx(
+        mean=opt_means,
+        standard_deviation=np.exp(opt_log_sds),
+        draws=n_draws,
+        model=model,
+        vectorize_draws=False,
+        return_unconstrained=include_transformed,
+        random_seed=random_seed,
     )
+    idata = InferenceData(posterior=posterior)
+    if include_transformed:
+        idata.add_groups(unconstrained_posterior=unconstrained_posterior)

     var_name_to_model_var = {f"{var_name}_mu": var_name for var_name in initial_point_dict.keys()}
     var_name_to_model_var.update(
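Note: fit_dadvi no longer hand-rolls its normal draws and no longer needs dadvi_result_to_idata (deleted in the next hunk); DADVI and fit_laplace now share draws_from_laplace_approx, and the InferenceData is assembled directly from the returned datasets. A hypothetical call site (import path per the file layout above, keyword names per this hunk and the fit_dadvi signature; other defaults assumed):

    import pymc as pm
    from pymc_extras.inference.dadvi.dadvi import fit_dadvi

    with pm.Model() as model:
        mu = pm.Normal("mu", 0, 1)
        pm.Normal("y", mu, 1, observed=[0.1, -0.3, 0.2])

    idata = fit_dadvi(model=model, n_draws=500, include_transformed=True, random_seed=0)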
@@ -283,69 +280,3 @@ def create_dadvi_graph(
     objective = -mean_log_density - entropy

     return var_params, objective
-
-
-def dadvi_result_to_idata(
-    unstacked_draws: xarray.Dataset,
-    model: Model,
-    include_transformed: bool = False,
-    progressbar: bool = True,
-):
-    """
-    Transforms the unconstrained draws back into the constrained space.
-
-    Parameters
-    ----------
-    unstacked_draws : xarray.Dataset
-        The draws to constrain back into the original space.
-
-    model : Model
-        The PyMC model the variables were derived from.
-
-    n_draws: int
-        The number of draws to return from the variational approximation.
-
-    include_transformed: bool
-        Whether or not to keep the unconstrained variables in the output.
-
-    progressbar: bool
-        Whether or not to show a progress bar during the transformation. Default is True.
-
-    Returns
-    -------
-    :class:`~arviz.InferenceData`
-        Draws from the original constrained parameters.
-    """
-
-    filtered_var_names = model.unobserved_value_vars
-    vars_to_sample = list(
-        get_default_varnames(filtered_var_names, include_transformed=include_transformed)
-    )
-    fn = pytensor.function(model.value_vars, vars_to_sample)
-    point_func = PointFunc(fn)
-
-    coords, dims = coords_and_dims_for_inferencedata(model)
-
-    transformed_result = apply_function_over_dataset(
-        point_func,
-        unstacked_draws,
-        output_var_names=[x.name for x in vars_to_sample],
-        coords=coords,
-        dims=dims,
-        progressbar=progressbar,
-    )
-
-    constrained_names = [
-        x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=False)
-    ]
-    all_varnames = [
-        x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=True)
-    ]
-    unconstrained_names = sorted(set(all_varnames) - set(constrained_names))
-
-    idata = az.InferenceData(posterior=transformed_result[constrained_names])
-
-    if unconstrained_names and include_transformed:
-        idata["unconstrained_posterior"] = transformed_result[unconstrained_names]
-
-    return idata
{pymc_extras-0.6.0 → pymc_extras-0.8.0}/pymc_extras/inference/laplace_approx/laplace.py
@@ -16,9 +16,7 @@
 import logging

 from collections.abc import Callable
-from functools import partial
 from typing import Literal
-from typing import cast as type_cast

 import arviz as az
 import numpy as np
@@ -27,16 +25,18 @@ import pytensor
 import pytensor.tensor as pt
 import xarray as xr

+from arviz import dict_to_dataset
 from better_optimize.constants import minimize_method
 from numpy.typing import ArrayLike
+from pymc import Model
+from pymc.backends.arviz import coords_and_dims_for_inferencedata
 from pymc.blocking import DictToArrayBijection
 from pymc.model.transform.optimization import freeze_dims_and_data
-from pymc.pytensorf import join_nonshared_inputs
-from pymc.util import get_default_varnames
+from pymc.util import get_untransformed_name, is_transformed_name
 from pytensor.graph import vectorize_graph
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.optimize import minimize
-from pytensor.tensor.type import Variable
+from xarray import Dataset

 from pymc_extras.inference.laplace_approx.find_map import (
     _compute_inverse_hessian,
@@ -137,7 +137,7 @@ def get_conditional_gaussian_approximation(
     hess = pytensor.graph.replace.graph_replace(hess, {x: x0})

     # Full log(p(x | y, params)) using the Laplace approximation (up to a constant)
-    _, logdetQ = pt.nlinalg.slogdet(Q)
+    _, logdetQ = pt.linalg.slogdet(Q)
     conditional_gaussian_approx = (
         -0.5 * x.T @ (-hess + Q) @ x + x.T @ (Q @ mu + jac - hess @ x0) + 0.5 * logdetQ
     )
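Note: this is a namespace move only; pt.linalg.slogdet computes the same (sign, log|det|) pair as the older pt.nlinalg.slogdet, e.g.:

    import pytensor.tensor as pt

    Q = pt.matrix("Q")
    sign, logdetQ = pt.linalg.slogdet(Q)  # previously pt.nlinalg.slogdet(Q)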
@@ -147,138 +147,175 @@
     return pytensor.function(args, [x0, conditional_gaussian_approx])


-def _unconstrained_vector_to_constrained_rvs(model):
-    outputs = get_default_varnames(model.unobserved_value_vars, include_transformed=True)
-    constrained_names = [
-        x.name for x in get_default_varnames(model.unobserved_value_vars, include_transformed=False)
-    ]
-    names = [x.name for x in outputs]
+def unpack_last_axis(packed_input, packed_shapes):
+    if len(packed_shapes) == 1:
+        # Single case currently fails in unpack
+        return [pt.split_dims(packed_input, packed_shapes[0], axis=-1)]

-    unconstrained_names = [name for name in names if name not in constrained_names]
+    keep_axes = tuple(range(packed_input.ndim))[:-1]
+    return pt.unpack(packed_input, keep_axes=keep_axes, packed_shapes=packed_shapes)

-    new_outputs, unconstrained_vector = join_nonshared_inputs(
-        model.initial_point(),
-        inputs=model.value_vars,
-        outputs=outputs,
-    )
-
-    constrained_rvs = [x for x, name in zip(new_outputs, names) if name in constrained_names]
-    value_rvs = [x for x in new_outputs if x not in constrained_rvs]
-
-    unconstrained_vector.name = "unconstrained_vector"

-    # Redo the names list to ensure it is sorted to match the return order
-    constrained_rvs_and_names = [(rv, name) for rv, name in zip(constrained_rvs, constrained_names)]
-    value_rvs_and_names = [
-        (rv, name) for rv, name in zip(value_rvs, names) for name in unconstrained_names
-    ]
-    # names = [*constrained_names, *unconstrained_names]
+def draws_from_laplace_approx(
+    *,
+    mean,
+    covariance=None,
+    standard_deviation=None,
+    draws: int,
+    model: Model,
+    vectorize_draws: bool = True,
+    return_unconstrained: bool = True,
+    random_seed=None,
+    compile_kwargs: dict | None = None,
+) -> tuple[Dataset, Dataset | None]:
+    """
+    Generate draws from the Laplace approximation of the posterior.

-    return constrained_rvs_and_names, value_rvs_and_names, unconstrained_vector
+    Parameters
+    ----------
+    mean : np.ndarray
+        The mean of the Laplace approximation (MAP estimate).
+    covariance : np.ndarray, optional
+        The covariance matrix of the Laplace approximation.
+        Mutually exclusive with `standard_deviation`.
+    standard_deviation : np.ndarray, optional
+        The standard deviation of the Laplace approximation (diagonal approximation).
+        Mutually exclusive with `covariance`.
+    draws : int
+        The number of draws.
+    model : pm.Model
+        The PyMC model.
+    vectorize_draws : bool, default True
+        Whether to vectorize the draws.
+    return_unconstrained : bool, default True
+        Whether to return the unconstrained draws in addition to the constrained ones.
+    random_seed : int, optional
+        Random seed for reproducibility.
+    compile_kwargs: dict, optional
+        Optional compile kwargs

+    Returns
+    -------
+    tuple[Dataset, Dataset | None]
+        A tuple containing the constrained draws (trace) and optionally the unconstrained draws.
+
+    Raises
+    ------
+    ValueError
+        If neither `covariance` nor `standard_deviation` is provided,
+        or if both are provided.
+    """
+    # This function assumes that mean/covariance/standard_deviation are aligned with model.initial_point()
+    if covariance is None and standard_deviation is None:
+        raise ValueError("Must specify either covariance or standard_deviation")
+    if covariance is not None and standard_deviation is not None:
+        raise ValueError("Cannot specify both covariance and standard_deviation")
+    if compile_kwargs is None:
+        compile_kwargs = {}

-def model_to_laplace_approx(
-    model: pm.Model, unpacked_variable_names: list[str], chains: int = 1, draws: int = 500
-):
     initial_point = model.initial_point()
-    raveled_vars = DictToArrayBijection.map(initial_point)
-    raveled_shape = raveled_vars.data.shape[0]
-
-    # temp_chain and temp_draw are a hack to allow sampling from the Laplace approximation. We only have one mu and cov,
-    # so we add batch dims (which correspond to chains and draws). But the names "chain" and "draw" are reserved.
-
-    # The model was frozen during the find_MAP procedure. To ensure we're operating on the same model, freeze it again.
-    frozen_model = freeze_dims_and_data(model)
-    constrained_rvs_and_names, _, unconstrained_vector = _unconstrained_vector_to_constrained_rvs(
-        frozen_model
+    n = int(np.sum([np.prod(v.shape) for v in initial_point.values()]))
+    assert mean.shape == (n,)
+    if covariance is not None:
+        assert covariance.shape == (n, n)
+    elif standard_deviation is not None:
+        assert standard_deviation.shape == (n,)
+
+    vars_to_sample = [v for v in model.free_RVs + model.deterministics]
+    var_names = [v.name for v in vars_to_sample]
+
+    orig_constrained_vars = model.value_vars
+    orig_outputs = model.replace_rvs_by_values(vars_to_sample)
+    if return_unconstrained:
+        orig_outputs.extend(model.value_vars)
+
+    mu_pt = pt.vector("mu", shape=(n,), dtype=mean.dtype)
+    size = (draws,) if vectorize_draws else ()
+    if covariance is not None:
+        sigma_pt = pt.matrix("cov", shape=(n, n), dtype=covariance.dtype)
+        laplace_approximation = pm.MvNormal.dist(mu=mu_pt, cov=sigma_pt, size=size, method="svd")
+    else:
+        sigma_pt = pt.vector("sigma", shape=(n,), dtype=standard_deviation.dtype)
+        laplace_approximation = pm.Normal.dist(mu=mu_pt, sigma=sigma_pt, size=(*size, n))
+
+    constrained_vars = unpack_last_axis(
+        laplace_approximation,
+        [initial_point[v.name].shape for v in orig_constrained_vars],
+    )
+    outputs = vectorize_graph(
+        orig_outputs, replace=dict(zip(orig_constrained_vars, constrained_vars))
     )

-    coords = model.coords | {
-        "temp_chain": np.arange(chains),
-        "temp_draw": np.arange(draws),
-        "unpacked_variable_names": unpacked_variable_names,
-    }
-
-    with pm.Model(coords=coords, model=None) as laplace_model:
-        mu = pm.Flat("mean_vector", shape=(raveled_shape,))
-        cov = pm.Flat("covariance_matrix", shape=(raveled_shape, raveled_shape))
-        laplace_approximation = pm.MvNormal(
-            "laplace_approximation",
-            mu=mu,
-            cov=cov,
-            dims=["temp_chain", "temp_draw", "unpacked_variable_names"],
-            method="svd",
-        )
-
-        cast_to_var = partial(type_cast, Variable)
-        constrained_rvs, constrained_names = zip(*constrained_rvs_and_names)
-        batched_rvs = vectorize_graph(
-            type_cast(list[Variable], constrained_rvs),
-            replace={cast_to_var(unconstrained_vector): cast_to_var(laplace_approximation)},
-        )
-
-        for name, batched_rv in zip(constrained_names, batched_rvs):
-            batch_dims = ("temp_chain", "temp_draw")
-            if batched_rv.ndim == 2:
-                dims = batch_dims
-            elif name in model.named_vars_to_dims:
-                dims = (*batch_dims, *model.named_vars_to_dims[name])
-            else:
-                dims = (*batch_dims, *[f"{name}_dim_{i}" for i in range(batched_rv.ndim - 2)])
-            initval = initial_point.get(name, None)
-            dim_shapes = initval.shape if initval is not None else batched_rv.type.shape[2:]
-            laplace_model.add_coords(
-                {name: np.arange(shape) for name, shape in zip(dims[2:], dim_shapes)}
-            )
-
-            pm.Deterministic(name, batched_rv, dims=dims)
-
-    return laplace_model
-
-
-def unstack_laplace_draws(laplace_data, model, chains=2, draws=500):
-    """
-    The `model_to_laplace_approx` function returns a model with a single MvNormal distribution, draws from which are
-    in the unconstrained variable space. These might be interesting to the user, but since they come back stacked in a
-    single vector, it's not easy to work with.
-
-    This function unpacks each component of the vector into its own DataArray, with the appropriate dimensions and
-    coordinates, where possible.
-    """
-    initial_point = DictToArrayBijection.map(model.initial_point())
-
-    cursor = 0
-    unstacked_laplace_draws = {}
-    coords = model.coords | {"chain": range(chains), "draw": range(draws)}
-
-    # There are corner cases where the value_vars will not have the same dimensions as the random variable (e.g.
-    # simplex transform of a Dirichlet). In these cases, we don't try to guess what the labels should be, and just
-    # add an arviz-style default dim and label.
-    for rv, (name, shape, size, dtype) in zip(model.free_RVs, initial_point.point_map_info):
-        rv_dims = []
-        for i, dim in enumerate(
-            model.named_vars_to_dims.get(rv.name, [f"{name}_dim_{i}" for i in range(len(shape))])
-        ):
-            if coords.get(dim) and shape[i] == len(coords[dim]):
-                rv_dims.append(dim)
-            else:
-                rv_dims.append(f"{name}_dim_{i}")
-                coords[f"{name}_dim_{i}"] = np.arange(shape[i])
-
-        dims = ("chain", "draw", *rv_dims)
-
-        values = (
-            laplace_data[..., cursor : cursor + size].reshape((chains, draws, *shape)).astype(dtype)
+    fn = pm.pytensorf.compile(
+        [mu_pt, sigma_pt],
+        outputs,
+        random_seed=random_seed,
+        trust_input=True,
+        **compile_kwargs,
+    )
+    sigma = covariance if covariance is not None else standard_deviation
+    if vectorize_draws:
+        output_buffers = fn(mean, sigma)
+    else:
+        # Take one draw to find the shape of the outputs
+        output_buffers = []
+        for out_draw in fn(mean, sigma):
+            output_buffer = np.empty((draws, *out_draw.shape), dtype=out_draw.dtype)
+            output_buffer[0] = out_draw
+            output_buffers.append(output_buffer)
+        # Fill one draws at a time
+        for i in range(1, draws):
+            for out_buffer, out_draw in zip(output_buffers, fn(mean, sigma)):
+                out_buffer[i] = out_draw
+
+    model_coords, model_dims = coords_and_dims_for_inferencedata(model)
+    posterior = {
+        var_name: out_buffer[None]
+        for var_name, out_buffer in (
+            zip(var_names, output_buffers, strict=not return_unconstrained)
         )
-        unstacked_laplace_draws[name] = xr.DataArray(
-            values, dims=dims, coords={dim: list(coords[dim]) for dim in dims}
+    }
+    posterior_dataset = dict_to_dataset(posterior, coords=model_coords, dims=model_dims, library=pm)
+    unconstrained_posterior_dataset = None
+
+    if return_unconstrained:
+        unconstrained_posterior = {
+            var.name: out_buffer[None]
+            for var, out_buffer in zip(
+                model.value_vars, output_buffers[len(posterior) :], strict=True
+            )
+        }
+        # Attempt to map constrained dims to unconstrained dims
+        for var_name, var_draws in unconstrained_posterior.items():
+            if not is_transformed_name(var_name):
+                # constrained == unconstrained, dims already shared
+                continue
+            constrained_dims = model_dims.get(get_untransformed_name(var_name))
+            if constrained_dims is None or (len(constrained_dims) != (var_draws.ndim - 2)):
+                continue
+            # Reuse dims from constrained variable if they match in length with unconstrained draws
+            inferred_dims = []
+            for i, (constrained_dim, unconstrained_dim_length) in enumerate(
+                zip(constrained_dims, var_draws.shape[2:], strict=True)
+            ):
+                if model_coords.get(constrained_dim) is not None and (
+                    len(model_coords[constrained_dim]) == unconstrained_dim_length
+                ):
+                    # Assume coordinates map. This could be fooled, by e.g., having a transform that reverses values
+                    inferred_dims.append(constrained_dim)
+                else:
+                    # Size mismatch (e.g., Simplex), make no assumption about mapping
+                    inferred_dims.append(f"{var_name}_dim_{i}")
+            model_dims[var_name] = inferred_dims
+
+        unconstrained_posterior_dataset = dict_to_dataset(
+            unconstrained_posterior,
+            coords=model_coords,
+            dims=model_dims,
+            library=pm,
        )

-        cursor += size
-
-    unstacked_laplace_draws = xr.Dataset(unstacked_laplace_draws)
-
-    return unstacked_laplace_draws
+    return posterior_dataset, unconstrained_posterior_dataset


 def fit_laplace(
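Note: draws_from_laplace_approx replaces the model_to_laplace_approx/unstack_laplace_draws pair deleted above: rather than building a second PyMC model and routing draws through sample_posterior_predictive, it samples the packed unconstrained vector directly and unpacks/constrains it through a vectorized graph. A minimal sketch of a direct call, using the keyword-only signature shown in this hunk (mean and standard_deviation must be aligned with model.initial_point()):

    import numpy as np
    import pymc as pm

    from pymc_extras.inference.laplace_approx.laplace import draws_from_laplace_approx

    with pm.Model() as model:
        x = pm.Normal("x", 0, 1, shape=3)

    posterior, unconstrained = draws_from_laplace_approx(
        mean=np.zeros(3),
        standard_deviation=np.ones(3),  # or covariance=..., never both
        draws=200,
        model=model,
        vectorize_draws=True,
        return_unconstrained=False,  # unconstrained is then None
        random_seed=42,
    )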
@@ -295,8 +332,9 @@ def fit_laplace(
     include_transformed: bool = True,
     freeze_model: bool = True,
     gradient_backend: GradientBackend = "pytensor",
-    chains: int = 2,
+    chains: None | int = None,
     draws: int = 500,
+    vectorize_draws: bool = True,
     optimizer_kwargs: dict | None = None,
     compile_kwargs: dict | None = None,
 ) -> az.InferenceData:
@@ -343,16 +381,14 @@
         True.
     gradient_backend: str, default "pytensor"
         The backend to use for gradient computations. Must be one of "pytensor" or "jax".
-    chains: int, default: 2
-        The number of chain dimensions to sample. Note that this is *not* the number of chains to run in parallel,
-        because the Laplace approximation is not an MCMC method. This argument exists to ensure that outputs are
-        compatible with the ArviZ library.
     draws: int, default: 500
-        The number of samples to draw from the approximated posterior. Totals samples will be chains * draws.
+        The number of samples to draw from the approximated posterior.
     optimizer_kwargs
         Additional keyword arguments to pass to the ``scipy.optimize`` function being used. Unless
         ``method = "basinhopping"``, ``scipy.optimize.minimize`` will be used. For ``basinhopping``,
         ``scipy.optimize.basinhopping`` will be used. See the documentation of these functions for details.
+    vectorize_draws: bool, default True
+        Whether to natively vectorize the random function or take one at a time in a python loop.
     compile_kwargs: dict, optional
         Additional keyword arguments to pass to pytensor.function.

@@ -385,6 +421,12 @@
     will forward the call to 'fit_laplace'.

     """
+    if chains is not None:
+        raise ValueError(
+            "chains argument has been deprecated. "
+            "The behavior can be recreated by unstacking draws into multiple chains after fitting"
+        )
+
     compile_kwargs = {} if compile_kwargs is None else compile_kwargs
     optimizer_kwargs = {} if optimizer_kwargs is None else optimizer_kwargs
     model = pm.modelcontext(model) if model is None else model
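Note: per the ValueError above, fit_laplace now always returns a single chain of draws; the old chains behavior can be recreated by reshaping afterwards, as the error message suggests. A sketch using xarray (sizes hypothetical; assumes draws is divisible by the number of pseudo-chains):

    import pymc_extras as pmx

    idata = pmx.fit_laplace(model=model, draws=1000)  # model defined elsewhere

    # Split 1000 draws into 2 pseudo-chains of 500 each
    post = idata.posterior.squeeze("chain", drop=True).drop_vars("draw")
    post = post.coarsen(draw=500).construct(draw=("chain", "draw"))
    idata.posterior = post.assign_coords(chain=[0, 1], draw=range(500))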
@@ -410,11 +452,10 @@
         **optimizer_kwargs,
     )

-    unpacked_variable_names = idata.fit["mean_vector"].coords["rows"].values.tolist()
-
     if "covariance_matrix" not in idata.fit:
         # The user didn't use `use_hess` or `use_hessp` (or an optimization method that returns an inverse Hessian), so
         # we have to go back and compute the Hessian at the MAP point now.
+        unpacked_variable_names = idata.fit["mean_vector"].coords["rows"].values.tolist()
         frozen_model = freeze_dims_and_data(model)
         initial_params = _make_initial_point(frozen_model, initvals, random_seed, jitter_rvs)
@@ -443,29 +484,16 @@
             coords={"rows": unpacked_variable_names, "columns": unpacked_variable_names},
         )

-    with model_to_laplace_approx(model, unpacked_variable_names, chains, draws) as laplace_model:
-        new_posterior = (
-            pm.sample_posterior_predictive(
-                idata.fit.expand_dims(chain=[0], draw=[0]),
-                extend_inferencedata=False,
-                random_seed=random_seed,
-                var_names=[
-                    "laplace_approximation",
-                    *[x.name for x in laplace_model.deterministics],
-                ],
-            )
-            .posterior_predictive.squeeze(["chain", "draw"])
-            .drop_vars(["chain", "draw"])
-            .rename({"temp_chain": "chain", "temp_draw": "draw"})
-        )
-
-        if include_transformed:
-            idata.unconstrained_posterior = unstack_laplace_draws(
-                new_posterior.laplace_approximation.values, model, chains=chains, draws=draws
-            )
-
-        idata.posterior = new_posterior.drop_vars(
-            ["laplace_approximation", "unpacked_variable_names"]
-        )
-
+    # We override the posterior/unconstrained_posterior from find_MAP
+    idata.posterior, unconstrained_posterior = draws_from_laplace_approx(
+        mean=idata.fit["mean_vector"].values,
+        covariance=idata.fit["covariance_matrix"].values,
+        draws=draws,
+        return_unconstrained=include_transformed,
+        model=model,
+        vectorize_draws=vectorize_draws,
+        random_seed=random_seed,
+    )
+    if include_transformed:
+        idata.unconstrained_posterior = unconstrained_posterior
     return idata