pymc-extras 0.5.0.tar.gz → 0.7.0.tar.gz

This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (175)
  1. pymc_extras-0.7.0/CONTRIBUTING.md +24 -0
  2. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/PKG-INFO +4 -4
  3. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/_version.py +2 -2
  4. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/conda-envs/environment-test.yml +3 -3
  5. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/docs/api_reference.rst +1 -0
  6. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/docs/statespace/models.rst +2 -1
  7. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/deserialize.py +10 -4
  8. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/distributions/continuous.py +1 -1
  9. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/distributions/histogram_utils.py +6 -4
  10. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/distributions/multivariate/r2d2m2cp.py +4 -3
  11. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/distributions/timeseries.py +14 -12
  12. pymc_extras-0.7.0/pymc_extras/inference/dadvi/dadvi.py +282 -0
  13. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/inference/laplace_approx/find_map.py +16 -39
  14. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/inference/laplace_approx/idata.py +22 -4
  15. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/inference/laplace_approx/laplace.py +196 -151
  16. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/inference/laplace_approx/scipy_interface.py +47 -7
  17. pymc_extras-0.7.0/pymc_extras/inference/pathfinder/idata.py +517 -0
  18. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/inference/pathfinder/pathfinder.py +71 -12
  19. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/inference/smc/sampling.py +2 -2
  20. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/model/marginal/distributions.py +4 -2
  21. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/model/marginal/graph_analysis.py +2 -2
  22. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/model/marginal/marginal_model.py +12 -2
  23. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/model_builder.py +9 -4
  24. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/prior.py +203 -8
  25. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/core/compile.py +1 -1
  26. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/core/statespace.py +2 -1
  27. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/filters/distributions.py +15 -13
  28. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/filters/kalman_filter.py +24 -22
  29. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/filters/kalman_smoother.py +3 -5
  30. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/filters/utilities.py +2 -5
  31. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/DFM.py +12 -27
  32. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/ETS.py +190 -198
  33. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/SARIMAX.py +5 -17
  34. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/VARMAX.py +15 -67
  35. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/structural/components/autoregressive.py +4 -4
  36. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/structural/components/regression.py +4 -26
  37. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/utilities.py +7 -0
  38. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/utils/model_equivalence.py +2 -2
  39. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/utils/prior.py +10 -14
  40. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/utils/spline.py +4 -10
  41. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pyproject.toml +19 -15
  42. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/distributions/test_continuous.py +4 -0
  43. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/distributions/test_discrete.py +8 -5
  44. pymc_extras-0.7.0/tests/inference/dadvi/test_dadvi.py +177 -0
  45. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/inference/laplace_approx/test_laplace.py +33 -21
  46. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/model/marginal/test_distributions.py +1 -1
  47. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/model/marginal/test_graph_analysis.py +1 -1
  48. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/model/marginal/test_marginal_model.py +21 -8
  49. pymc_extras-0.7.0/tests/pathfinder/test_idata.py +489 -0
  50. {pymc_extras-0.5.0/tests → pymc_extras-0.7.0/tests/pathfinder}/test_pathfinder.py +14 -15
  51. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/core/test_statespace.py +3 -5
  52. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/core/test_statespace_JAX.py +9 -9
  53. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/filters/test_distributions.py +2 -2
  54. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/filters/test_kalman_filter.py +47 -42
  55. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/models/structural/components/test_autoregressive.py +9 -1
  56. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/models/structural/components/test_cycle.py +1 -1
  57. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/models/structural/components/test_measurement_error.py +1 -1
  58. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/models/structural/components/test_seasonality.py +1 -1
  59. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/models/test_DFM.py +6 -13
  60. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/models/test_ETS.py +14 -10
  61. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/models/test_SARIMAX.py +11 -10
  62. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/models/test_VARMAX.py +108 -197
  63. pymc_extras-0.7.0/tests/statespace/utils/__init__.py +0 -0
  64. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/test_histogram_approximation.py +1 -0
  65. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/test_splines.py +17 -1
  66. pymc_extras-0.7.0/tests/utils.py +0 -0
  67. pymc_extras-0.5.0/CONTRIBUTING.md +0 -3
  68. pymc_extras-0.5.0/pymc_extras/inference/dadvi/dadvi.py +0 -261
  69. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/.gitignore +0 -0
  70. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/.gitpod.yml +0 -0
  71. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/.pre-commit-config.yaml +0 -0
  72. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/.readthedocs.yaml +0 -0
  73. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/CODE_OF_CONDUCT.md +0 -0
  74. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/LICENSE +0 -0
  75. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/README.md +0 -0
  76. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/codecov.yml +0 -0
  77. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/docs/.nojekyll +0 -0
  78. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/docs/Makefile +0 -0
  79. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/docs/_templates/autosummary/base.rst +0 -0
  80. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/docs/_templates/autosummary/class.rst +0 -0
  81. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/docs/conf.py +0 -0
  82. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/docs/index.rst +0 -0
  83. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/docs/make.bat +0 -0
  84. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/docs/statespace/core.rst +0 -0
  85. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/docs/statespace/filters.rst +0 -0
  86. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/docs/statespace/models/structural.rst +0 -0
  87. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/__init__.py +0 -0
  88. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/distributions/__init__.py +0 -0
  89. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/distributions/discrete.py +0 -0
  90. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/distributions/multivariate/__init__.py +0 -0
  91. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/distributions/transforms/__init__.py +0 -0
  92. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/distributions/transforms/partial_order.py +0 -0
  93. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/gp/__init__.py +0 -0
  94. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/gp/latent_approx.py +0 -0
  95. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/inference/__init__.py +0 -0
  96. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/inference/dadvi/__init__.py +0 -0
  97. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/inference/fit.py +0 -0
  98. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/inference/laplace_approx/__init__.py +0 -0
  99. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/inference/pathfinder/__init__.py +0 -0
  100. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/inference/pathfinder/importance_sampling.py +0 -0
  101. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/inference/pathfinder/lbfgs.py +0 -0
  102. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/inference/smc/__init__.py +0 -0
  103. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/linearmodel.py +0 -0
  104. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/model/__init__.py +0 -0
  105. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/model/marginal/__init__.py +0 -0
  106. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/model/model_api.py +0 -0
  107. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/model/transforms/__init__.py +0 -0
  108. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/model/transforms/autoreparam.py +0 -0
  109. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/preprocessing/__init__.py +0 -0
  110. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/preprocessing/standard_scaler.py +0 -0
  111. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/printing.py +0 -0
  112. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/__init__.py +0 -0
  113. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/core/__init__.py +0 -0
  114. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/core/representation.py +0 -0
  115. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/filters/__init__.py +0 -0
  116. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/__init__.py +0 -0
  117. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/structural/__init__.py +0 -0
  118. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/structural/components/__init__.py +0 -0
  119. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/structural/components/cycle.py +0 -0
  120. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/structural/components/level_trend.py +0 -0
  121. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/structural/components/measurement_error.py +0 -0
  122. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/structural/components/seasonality.py +0 -0
  123. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/structural/core.py +0 -0
  124. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/models/structural/utils.py +0 -0
  125. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/utils/__init__.py +0 -0
  126. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/utils/constants.py +0 -0
  127. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/utils/coord_tools.py +0 -0
  128. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/statespace/utils/data_tools.py +0 -0
  129. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/utils/__init__.py +0 -0
  130. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/utils/linear_cg.py +0 -0
  131. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/__init__.py +0 -0
  132. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/conftest.py +0 -0
  133. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/distributions/__init__.py +0 -0
  134. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/distributions/test_discrete_markov_chain.py +0 -0
  135. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/distributions/test_multivariate.py +0 -0
  136. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/distributions/test_transform.py +0 -0
  137. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/inference/__init__.py +0 -0
  138. {pymc_extras-0.5.0/tests/inference/laplace_approx → pymc_extras-0.7.0/tests/inference/dadvi}/__init__.py +0 -0
  139. {pymc_extras-0.5.0/tests/model → pymc_extras-0.7.0/tests/inference/laplace_approx}/__init__.py +0 -0
  140. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/inference/laplace_approx/test_find_map.py +1 -1
  141. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/inference/laplace_approx/test_idata.py +0 -0
  142. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/inference/laplace_approx/test_scipy_interface.py +0 -0
  143. {pymc_extras-0.5.0/tests/model/marginal → pymc_extras-0.7.0/tests/model}/__init__.py +0 -0
  144. {pymc_extras-0.5.0/tests/statespace → pymc_extras-0.7.0/tests/model/marginal}/__init__.py +0 -0
  145. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/model/test_model_api.py +0 -0
  146. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/model/transforms/test_autoreparam.py +0 -0
  147. {pymc_extras-0.5.0/tests/statespace/core → pymc_extras-0.7.0/tests/pathfinder}/__init__.py +0 -0
  148. {pymc_extras-0.5.0/tests/statespace/filters → pymc_extras-0.7.0/tests/statespace}/__init__.py +0 -0
  149. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/_data/airpass.csv +0 -0
  150. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/_data/airpassangers.csv +0 -0
  151. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/_data/nile.csv +0 -0
  152. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/_data/statsmodels_macrodata_processed.csv +0 -0
  153. {pymc_extras-0.5.0/tests/statespace/models → pymc_extras-0.7.0/tests/statespace/core}/__init__.py +0 -0
  154. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/core/test_representation.py +0 -0
  155. {pymc_extras-0.5.0/tests/statespace/models/structural → pymc_extras-0.7.0/tests/statespace/filters}/__init__.py +0 -0
  156. {pymc_extras-0.5.0/tests/statespace/models/structural/components → pymc_extras-0.7.0/tests/statespace/models}/__init__.py +0 -0
  157. {pymc_extras-0.5.0/tests/statespace/utils → pymc_extras-0.7.0/tests/statespace/models/structural}/__init__.py +0 -0
  158. /pymc_extras-0.5.0/tests/utils.py → /pymc_extras-0.7.0/tests/statespace/models/structural/components/__init__.py +0 -0
  159. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/models/structural/components/test_level_trend.py +0 -0
  160. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/models/structural/components/test_regression.py +0 -0
  161. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/models/structural/conftest.py +0 -0
  162. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/models/structural/test_against_statsmodels.py +0 -0
  163. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/models/structural/test_core.py +0 -0
  164. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/models/test_utilities.py +0 -0
  165. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/shared_fixtures.py +0 -0
  166. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/statsmodel_local_level.py +0 -0
  167. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/test_utilities.py +0 -0
  168. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/statespace/utils/test_coord_assignment.py +0 -0
  169. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/test_blackjax_smc.py +0 -0
  170. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/test_deserialize.py +0 -0
  171. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/test_linearmodel.py +0 -0
  172. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/test_model_builder.py +0 -0
  173. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/test_printing.py +0 -0
  174. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/test_prior.py +0 -0
  175. {pymc_extras-0.5.0 → pymc_extras-0.7.0}/tests/test_prior_from_trace.py +0 -0
pymc_extras-0.7.0/CONTRIBUTING.md (new file)
@@ -0,0 +1,24 @@
+ # Contributing guide
+
+ Page in construction, for now go to https://github.com/pymc-devs/pymc-extras#questions.
+
+ ## Building the documentation
+
+ To build the documentation locally, you need to install the necessary
+ dependencies and then use `make` to build the HTML files.
+
+ First, install the package with the optional documentation dependencies:
+
+ ```bash
+ pip install ".[docs]"
+ ```
+
+ Then, navigate to the `docs` directory and run `make html`:
+
+ ```bash
+ cd docs
+ make html
+ ```
+
+ The generated HTML files will be in the `docs/_build/html` directory. You can
+ open the `index.html` file in that directory to view the documentation.
{pymc_extras-0.5.0 → pymc_extras-0.7.0}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: pymc-extras
- Version: 0.5.0
+ Version: 0.7.0
  Summary: A home for new additions to PyMC, which may include unusual probability distribitions, advanced model fitting algorithms, or any code that may be inappropriate to include in the pymc repository, but may want to be made available to users.
  Project-URL: Documentation, https://pymc-extras.readthedocs.io/
  Project-URL: Repository, https://github.com/pymc-devs/pymc-extras.git
@@ -235,8 +235,8 @@ Requires-Python: >=3.11
  Requires-Dist: better-optimize>=0.1.5
  Requires-Dist: preliz>=0.20.0
  Requires-Dist: pydantic>=2.0.0
- Requires-Dist: pymc>=5.24.1
- Requires-Dist: pytensor>=2.31.4
+ Requires-Dist: pymc>=5.27.0
+ Requires-Dist: pytensor>=2.36.3
  Requires-Dist: scikit-learn
  Provides-Extra: complete
  Requires-Dist: dask[complete]<2025.1.1; extra == 'complete'
@@ -245,7 +245,7 @@ Provides-Extra: dask-histogram
  Requires-Dist: dask[complete]<2025.1.1; extra == 'dask-histogram'
  Requires-Dist: xhistogram; extra == 'dask-histogram'
  Provides-Extra: dev
- Requires-Dist: blackjax; extra == 'dev'
+ Requires-Dist: blackjax>=0.12; extra == 'dev'
  Requires-Dist: dask[all]<2025.1.1; extra == 'dev'
  Requires-Dist: pytest-mock; extra == 'dev'
  Requires-Dist: pytest>=6.0; extra == 'dev'
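
The notable metadata change is the dependency floor: pymc moves from >=5.24.1 to >=5.27.0 and pytensor from >=2.31.4 to >=2.36.3, and the dev extra now pins blackjax>=0.12. As a quick sanity check, the pins of an installed copy can be inspected from the metadata itself; the values in the comments below are simply what the PKG-INFO above implies for 0.7.0:

```python
from importlib.metadata import requires, version

# Inspect the installed distribution's metadata.
print(version("pymc-extras"))  # 0.7.0
for req in requires("pymc-extras") or []:
    if req.startswith(("pymc>", "pytensor>")):
        print(req)  # pymc>=5.27.0, pytensor>=2.36.3
```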
{pymc_extras-0.5.0 → pymc_extras-0.7.0}/_version.py
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
  commit_id: COMMIT_ID
  __commit_id__: COMMIT_ID

- __version__ = version = '0.5.0'
- __version_tuple__ = version_tuple = (0, 5, 0)
+ __version__ = version = '0.7.0'
+ __version_tuple__ = version_tuple = (0, 7, 0)

  __commit_id__ = commit_id = None
{pymc_extras-0.5.0 → pymc_extras-0.7.0}/conda-envs/environment-test.yml
@@ -1,10 +1,10 @@
- name: pymc-extras-test
+ name: pymc-extras
  channels:
  - conda-forge
  - nodefaults
  dependencies:
- - pymc>=5.24.1
- - pytensor>=2.31.4
+ - pymc>=5.27.0
+ - pytensor>=2.36.3
  - scikit-learn
  - better-optimize>=0.1.5
  - dask<2025.1.1
{pymc_extras-0.5.0 → pymc_extras-0.7.0}/docs/api_reference.rst
@@ -56,6 +56,7 @@ Prior
  create_dim_handler
  handle_dims
  Prior
+ register_tensor_transform
  VariableFactory
  sample_prior
  Censored
{pymc_extras-0.5.0 → pymc_extras-0.7.0}/docs/statespace/models.rst
@@ -6,7 +6,8 @@ Statespace Models
  .. autosummary::
     :toctree: generated

- BayesianSARIMA
+ BayesianETS
+ BayesianSARIMAX
  BayesianVARMAX

  *********************
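
The docs index now lists BayesianETS and renames BayesianSARIMA to BayesianSARIMAX. Assuming these classes are exported from `pymc_extras.statespace` as the autosummary entries suggest, code written against 0.5.0 would update along these lines:

```python
# 0.5.0:
# from pymc_extras.statespace import BayesianSARIMA
# 0.7.0 (assumed export path, following the docs index above):
from pymc_extras.statespace import BayesianETS, BayesianSARIMAX, BayesianVARMAX
```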
{pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/deserialize.py
@@ -13,10 +13,7 @@ Make use of the already registered deserializers:

  from pymc_extras.deserialize import deserialize

- prior_class_data = {
-     "dist": "Normal",
-     "kwargs": {"mu": 0, "sigma": 1}
- }
+ prior_class_data = {"dist": "Normal", "kwargs": {"mu": 0, "sigma": 1}}
  prior = deserialize(prior_class_data)
  # Prior("Normal", mu=0, sigma=1)

@@ -26,6 +23,7 @@ Register custom class deserialization:

  from pymc_extras.deserialize import register_deserialization

+
  class MyClass:
      def __init__(self, value: int):
          self.value = value
@@ -34,6 +32,7 @@ Register custom class deserialization:
          # Example of what the to_dict method might look like.
          return {"value": self.value}

+
  register_deserialization(
      is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
      deserialize=lambda data: MyClass(value=data["value"]),
@@ -80,18 +79,23 @@ class Deserializer:

  from typing import Any

+
  class MyClass:
      def __init__(self, value: int):
          self.value = value

+
  from pymc_extras.deserialize import Deserializer

+
  def is_type(data: Any) -> bool:
      return data.keys() == {"value"} and isinstance(data["value"], int)

+
  def deserialize(data: dict) -> MyClass:
      return MyClass(value=data["value"])

+
  deserialize_logic = Deserializer(is_type=is_type, deserialize=deserialize)

  """
@@ -196,6 +200,7 @@ def register_deserialization(is_type: IsType, deserialize: Deserialize) -> None:

  from pymc_extras.deserialize import register_deserialization

+
  class MyClass:
      def __init__(self, value: int):
          self.value = value
@@ -204,6 +209,7 @@ def register_deserialization(is_type: IsType, deserialize: Deserialize) -> None:
          # Example of what the to_dict method might look like.
          return {"value": self.value}

+
  register_deserialization(
      is_type=lambda data: data.keys() == {"value"} and isinstance(data["value"], int),
      deserialize=lambda data: MyClass(value=data["value"]),
{pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/distributions/continuous.py
@@ -265,7 +265,7 @@ class Chi:
  from pymc_extras.distributions import Chi

  with pm.Model():
-     x = Chi('x', nu=1)
+     x = Chi("x", nu=1)
  """

  @staticmethod
{pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/distributions/histogram_utils.py
@@ -130,8 +130,7 @@ def histogram_approximation(name, dist, *, observed, **h_kwargs):
  ...     m = pm.Normal("m", dims="tests")
  ...     s = pm.LogNormal("s", dims="tests")
  ...     pot = pmx.distributions.histogram_approximation(
- ...         "pot", pm.Normal.dist(m, s),
- ...         observed=measurements, n_quantiles=50
+ ...         "pot", pm.Normal.dist(m, s), observed=measurements, n_quantiles=50
  ...     )

  For special cases like Zero Inflation in Continuous variables there is a flag.
@@ -143,8 +142,11 @@ def histogram_approximation(name, dist, *, observed, **h_kwargs):
  ...     m = pm.Normal("m", dims="tests")
  ...     s = pm.LogNormal("s", dims="tests")
  ...     pot = pmx.distributions.histogram_approximation(
- ...         "pot", pm.Normal.dist(m, s),
- ...         observed=measurements, n_quantiles=50, zero_inflation=True
+ ...         "pot",
+ ...         pm.Normal.dist(m, s),
+ ...         observed=measurements,
+ ...         n_quantiles=50,
+ ...         zero_inflation=True,
  ...     )
  """
  try:
{pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/distributions/multivariate/r2d2m2cp.py
@@ -305,6 +305,7 @@ def R2D2M2CP(
  import pymc_extras as pmx
  import pymc as pm
  import numpy as np
+
  X = np.random.randn(10, 3)
  b = np.random.randn(3)
  y = X @ b + np.random.randn(10) * 0.04 + 5
@@ -339,7 +340,7 @@ def R2D2M2CP(
      # "c" - a must have in the relation
      variables_importance=[10, 1, 34],
      # NOTE: try both
-     centered=True
+     centered=True,
  )
  # intercept prior centering should be around prior predictive mean
  intercept = y.mean()
@@ -365,7 +366,7 @@ def R2D2M2CP(
      r2_std=0.2,
      # NOTE: if you know where a variable should go
      # if you do not know, leave as 0.5
-     centered=False
+     centered=False,
  )
  # intercept prior centering should be around prior predictive mean
  intercept = y.mean()
@@ -394,7 +395,7 @@ def R2D2M2CP(
      # if you do not know, leave as 0.5
      positive_probs=[0.8, 0.5, 0.1],
      # NOTE: try both
-     centered=True
+     centered=True,
  )
  intercept = y.mean()
  obs = pm.Normal("obs", intercept + X @ beta, eps, observed=y)
{pymc_extras-0.5.0 → pymc_extras-0.7.0}/pymc_extras/distributions/timeseries.py
@@ -113,8 +113,10 @@ class DiscreteMarkovChain(Distribution):

  with pm.Model() as markov_chain:
      P = pm.Dirichlet("P", a=[1, 1, 1], size=(3,))
-     init_dist = pm.Categorical.dist(p = np.full(3, 1 / 3))
-     markov_chain = pmx.DiscreteMarkovChain("markov_chain", P=P, init_dist=init_dist, shape=(100,))
+     init_dist = pm.Categorical.dist(p=np.full(3, 1 / 3))
+     markov_chain = pmx.DiscreteMarkovChain(
+         "markov_chain", P=P, init_dist=init_dist, shape=(100,)
+     )

  """
@@ -194,21 +196,20 @@ class DiscreteMarkovChain(Distribution):
  state_rng = pytensor.shared(np.random.default_rng())

  def transition(*args):
-     *states, transition_probs, old_rng = args
+     old_rng, *states, transition_probs = args
      p = transition_probs[tuple(states)]
      next_rng, next_state = pm.Categorical.dist(p=p, rng=old_rng).owner.outputs
-     return next_state, {old_rng: next_rng}
+     return next_rng, next_state

- markov_chain, state_updates = pytensor.scan(
+ state_next_rng, markov_chain = pytensor.scan(
      transition,
-     non_sequences=[P_, state_rng],
-     outputs_info=_make_outputs_info(n_lags, init_dist_),
+     outputs_info=[state_rng, *_make_outputs_info(n_lags, init_dist_)],
+     non_sequences=[P_],
      n_steps=steps_,
      strict=True,
+     return_updates=False,
  )

- (state_next_rng,) = tuple(state_updates.values())
-
  discrete_mc_ = pt.moveaxis(pt.concatenate([init_dist_, markov_chain], axis=0), 0, -1)

  discrete_mc_op = DiscreteMarkovChainRV(
@@ -237,16 +238,17 @@ def discrete_mc_moment(op, rv, P, steps, init_dist, state_rng):
  n_lags = op.n_lags

  def greedy_transition(*args):
-     *states, transition_probs, old_rng = args
+     *states, transition_probs = args
      p = transition_probs[tuple(states)]
      return pt.argmax(p)

- chain_moment, moment_updates = pytensor.scan(
+ chain_moment = pytensor.scan(
      greedy_transition,
-     non_sequences=[P, state_rng],
+     non_sequences=[P],
      outputs_info=_make_outputs_info(n_lags, init_dist),
      n_steps=steps,
      strict=True,
+     return_updates=False,
  )
  chain_moment = pt.concatenate([init_dist_moment, chain_moment])
  return chain_moment
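
The timeseries change above migrates to the newer `pytensor.scan` calling convention: the RNG is threaded through `outputs_info` and returned from the step function, instead of being passed as a non-sequence and collected from an updates dict. A minimal sketch of that pattern, assuming a pytensor version that supports `return_updates` (this release pins pytensor>=2.36.3):

```python
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt

# Sketch of the scan pattern used in the diff above: the RNG travels
# through `outputs_info` and is returned by the step function.
rng = pytensor.shared(np.random.default_rng(0))


def step(old_rng, prev):
    # As in the diff, pm.<Dist>.dist(...).owner.outputs yields (next_rng, draw).
    next_rng, eps = pm.Normal.dist(0.0, 1.0, rng=old_rng).owner.outputs
    return next_rng, prev + eps


# A simple Gaussian random walk; with return_updates=False, scan returns
# the final RNG as an explicit output rather than via an updates dict.
final_rng, walk = pytensor.scan(
    step,
    outputs_info=[rng, pt.zeros(())],
    n_steps=10,
    strict=True,
    return_updates=False,
)
```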
pymc_extras-0.7.0/pymc_extras/inference/dadvi/dadvi.py (new file)
@@ -0,0 +1,282 @@
+ import arviz as az
+ import numpy as np
+ import pymc
+ import pytensor
+ import pytensor.tensor as pt
+
+ from arviz import InferenceData
+ from better_optimize import basinhopping, minimize
+ from better_optimize.constants import minimize_method
+ from pymc import DictToArrayBijection, Model, join_nonshared_inputs
+ from pymc.blocking import RaveledVars
+ from pymc.util import RandomSeed
+ from pytensor.tensor.variable import TensorVariable
+
+ from pymc_extras.inference.laplace_approx.idata import (
+     add_data_to_inference_data,
+     add_optimizer_result_to_inference_data,
+ )
+ from pymc_extras.inference.laplace_approx.laplace import draws_from_laplace_approx
+ from pymc_extras.inference.laplace_approx.scipy_interface import (
+     scipy_optimize_funcs_from_loss,
+     set_optimizer_function_defaults,
+ )
+
+
+ def fit_dadvi(
+     model: Model | None = None,
+     n_fixed_draws: int = 30,
+     n_draws: int = 1000,
+     include_transformed: bool = False,
+     optimizer_method: minimize_method = "trust-ncg",
+     use_grad: bool | None = None,
+     use_hessp: bool | None = None,
+     use_hess: bool | None = None,
+     gradient_backend: str = "pytensor",
+     compile_kwargs: dict | None = None,
+     random_seed: RandomSeed = None,
+     progressbar: bool = True,
+     **optimizer_kwargs,
+ ) -> az.InferenceData:
+     """
+     Does inference using Deterministic ADVI (Automatic Differentiation Variational Inference), DADVI for short.
+
+     For full details see the paper cited in the references: https://www.jmlr.org/papers/v25/23-1015.html
+
+     Parameters
+     ----------
+     model : pm.Model
+         The PyMC model to be fit. If None, the current model context is used.
+
+     n_fixed_draws : int
+         The number of fixed draws to use for the optimisation. More draws will result in more accurate estimates, but
+         also increase inference time. Usually, the default of 30 is a good tradeoff between speed and accuracy.
+
+     random_seed: int
+         The random seed to use for the fixed draws. Running the optimisation twice with the same seed should arrive at
+         the same result.
+
+     n_draws: int
+         The number of draws to return from the variational approximation.
+
+     include_transformed: bool
+         Whether or not to keep the unconstrained variables (such as logs of positive-constrained parameters) in the
+         output.
+
+     optimizer_method: str
+         Which optimization method to use. The function calls ``scipy.optimize.minimize``, so any of the methods there
+         can be used. The default is trust-ncg, which uses second-order information and is generally very reliable.
+         Other methods such as L-BFGS-B might be faster but potentially more brittle and may not converge exactly to
+         the optimum.
+
+     gradient_backend: str
+         Which backend to use to compute gradients. Must be one of "jax" or "pytensor". Default is "pytensor".
+
+     compile_kwargs: dict, optional
+         Additional keyword arguments to pass to `pytensor.function`
+
+     use_grad: bool, optional
+         If True, pass the gradient function to `scipy.optimize.minimize` (where it is referred to as `jac`).
+
+     use_hessp: bool, optional
+         If True, pass the hessian vector product to `scipy.optimize.minimize`.
+
+     use_hess: bool, optional
+         If True, pass the hessian to `scipy.optimize.minimize`. Note that this is generally not recommended since its
+         computation can be slow and memory-intensive if there are many parameters.
+
+     progressbar: bool
+         Whether or not to show a progress bar during optimization. Default is True.
+
+     optimizer_kwargs:
+         Additional keyword arguments to pass to the ``scipy.optimize.minimize`` function. See the documentation of
+         that function for details.
+
+     Returns
+     -------
+     :class:`~arviz.InferenceData`
+         The inference data containing the results of the DADVI algorithm.
+
+     References
+     ----------
+     Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box Variational Inference with a Deterministic Objective:
+     Faster, More Accurate, and Even More Black Box. Journal of Machine Learning Research, 25(18), 1–39.
+     """
+
+     model = pymc.modelcontext(model) if model is None else model
+     do_basinhopping = optimizer_method == "basinhopping"
+     minimizer_kwargs = optimizer_kwargs.pop("minimizer_kwargs", {})
+
+     if do_basinhopping:
+         # For a nice API, we let the user set method="basinhopping", but if we're doing basinhopping we still need
+         # another method for the inner optimizer. This will be set in the minimizer_kwargs, but also needs a default
+         # if one isn't provided.
+
+         optimizer_method = minimizer_kwargs.pop("method", "L-BFGS-B")
+         minimizer_kwargs["method"] = optimizer_method
+
+     initial_point_dict = model.initial_point()
+     initial_point = DictToArrayBijection.map(initial_point_dict)
+     n_params = initial_point.data.shape[0]
+
+     var_params, objective = create_dadvi_graph(
+         model,
+         n_fixed_draws=n_fixed_draws,
+         random_seed=random_seed,
+         n_params=n_params,
+     )
+
+     use_grad, use_hess, use_hessp = set_optimizer_function_defaults(
+         optimizer_method, use_grad, use_hess, use_hessp
+     )
+
+     f_fused, f_hessp = scipy_optimize_funcs_from_loss(
+         loss=objective,
+         inputs=[var_params],
+         initial_point_dict=None,
+         use_grad=use_grad,
+         use_hessp=use_hessp,
+         use_hess=use_hess,
+         gradient_backend=gradient_backend,
+         compile_kwargs=compile_kwargs,
+         inputs_are_flat=True,
+     )
+
+     dadvi_initial_point = {
+         f"{var_name}_mu": np.zeros_like(value).ravel()
+         for var_name, value in initial_point_dict.items()
+     }
+     dadvi_initial_point.update(
+         {
+             f"{var_name}_sigma__log": np.zeros_like(value).ravel()
+             for var_name, value in initial_point_dict.items()
+         }
+     )
+
+     dadvi_initial_point = DictToArrayBijection.map(dadvi_initial_point)
+     args = optimizer_kwargs.pop("args", ())
+
+     if do_basinhopping:
+         if "args" not in minimizer_kwargs:
+             minimizer_kwargs["args"] = args
+         if "hessp" not in minimizer_kwargs:
+             minimizer_kwargs["hessp"] = f_hessp
+         if "method" not in minimizer_kwargs:
+             minimizer_kwargs["method"] = optimizer_method
+
+         result = basinhopping(
+             func=f_fused,
+             x0=dadvi_initial_point.data,
+             progressbar=progressbar,
+             minimizer_kwargs=minimizer_kwargs,
+             **optimizer_kwargs,
+         )
+
+     else:
+         result = minimize(
+             f=f_fused,
+             x0=dadvi_initial_point.data,
+             args=args,
+             method=optimizer_method,
+             hessp=f_hessp,
+             progressbar=progressbar,
+             **optimizer_kwargs,
+         )
+
+     raveled_optimized = RaveledVars(result.x, dadvi_initial_point.point_map_info)
+
+     opt_var_params = result.x
+     opt_means, opt_log_sds = np.split(opt_var_params, 2)
+
+     posterior, unconstrained_posterior = draws_from_laplace_approx(
+         mean=opt_means,
+         standard_deviation=np.exp(opt_log_sds),
+         draws=n_draws,
+         model=model,
+         vectorize_draws=False,
+         return_unconstrained=include_transformed,
+         random_seed=random_seed,
+     )
+     idata = InferenceData(posterior=posterior)
+     if include_transformed:
+         idata.add_groups(unconstrained_posterior=unconstrained_posterior)
+
+     var_name_to_model_var = {f"{var_name}_mu": var_name for var_name in initial_point_dict.keys()}
+     var_name_to_model_var.update(
+         {f"{var_name}_sigma__log": var_name for var_name in initial_point_dict.keys()}
+     )
+
+     idata = add_optimizer_result_to_inference_data(
+         idata=idata,
+         result=result,
+         method=optimizer_method,
+         mu=raveled_optimized,
+         model=model,
+         var_name_to_model_var=var_name_to_model_var,
+     )
+
+     idata = add_data_to_inference_data(
+         idata=idata, progressbar=False, model=model, compile_kwargs=compile_kwargs
+     )
+
+     return idata
+
+
+ def create_dadvi_graph(
+     model: Model,
+     n_params: int,
+     n_fixed_draws: int = 30,
+     random_seed: RandomSeed = None,
+ ) -> tuple[TensorVariable, TensorVariable]:
+     """
+     Sets up the DADVI graph in pytensor and returns it.
+
+     Parameters
+     ----------
+     model : pm.Model
+         The PyMC model to be fit.
+
+     n_params: int
+         The total number of parameters in the model.
+
+     n_fixed_draws : int
+         The number of fixed draws to use.
+
+     random_seed: int
+         The random seed to use for the fixed draws.
+
+     Returns
+     -------
+     Tuple[TensorVariable, TensorVariable]
+         A tuple whose first element contains the variational parameters,
+         and whose second contains the DADVI objective.
+     """
+
+     # Make the fixed draws
+     generator = np.random.default_rng(seed=random_seed)
+     draws = generator.standard_normal(size=(n_fixed_draws, n_params))
+
+     inputs = model.continuous_value_vars + model.discrete_value_vars
+     initial_point_dict = model.initial_point()
+     logp = model.logp()
+
+     # Graph in terms of a flat input
+     [logp], flat_input = join_nonshared_inputs(
+         point=initial_point_dict, outputs=[logp], inputs=inputs
+     )
+
+     var_params = pt.vector(name="eta", shape=(2 * n_params,))
+
+     means, log_sds = pt.split(var_params, axis=0, splits_size=[n_params, n_params], n_splits=2)
+
+     draw_matrix = pt.constant(draws)
+     samples = means + pt.exp(log_sds) * draw_matrix
+
+     logp_vectorized_draws = pytensor.graph.vectorize_graph(logp, replace={flat_input: samples})
+
+     mean_log_density = pt.mean(logp_vectorized_draws)
+     entropy = pt.sum(log_sds)
+
+     objective = -mean_log_density - entropy
+
+     return var_params, objective
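
For orientation: `create_dadvi_graph` builds a mean-field Gaussian approximation with parameters eta = (mu, log sigma). With S fixed draws z_s ~ N(0, I) reused at every optimizer step, the objective is the deterministic negative-ELBO estimate objective(eta) = -(1/S) * sum_s log p(mu + exp(log sigma) * z_s) - sum_i log sigma_i (the entropy term, up to an additive constant), which is why an off-the-shelf second-order optimizer like trust-ncg can be run to convergence. A minimal usage sketch, assuming only the signature shown above (the toy model and data here are hypothetical; the import path follows the new module's location, though the package may also re-export `fit_dadvi` elsewhere):

```python
import numpy as np
import pymc as pm

from pymc_extras.inference.dadvi.dadvi import fit_dadvi

# Hypothetical toy data: noisy observations of a Gaussian.
y_obs = np.random.default_rng(0).normal(loc=1.0, scale=2.0, size=200)

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 10.0)
    sigma = pm.HalfNormal("sigma", 5.0)
    pm.Normal("y", mu=mu, sigma=sigma, observed=y_obs)

    # Fixed draws make the objective deterministic; idata.posterior then
    # holds n_draws samples from the fitted mean-field Gaussian, mapped
    # back to the constrained space.
    idata = fit_dadvi(n_fixed_draws=30, n_draws=1000, random_seed=0)
```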