pymc-extras 0.4.1__tar.gz → 0.5.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (168) hide show
  1. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/PKG-INFO +1 -1
  2. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/_version.py +2 -2
  3. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/inference/__init__.py +8 -1
  4. pymc_extras-0.5.0/pymc_extras/inference/dadvi/dadvi.py +261 -0
  5. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/inference/fit.py +5 -0
  6. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/inference/laplace_approx/find_map.py +16 -8
  7. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/inference/laplace_approx/idata.py +5 -2
  8. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/inference/laplace_approx/laplace.py +1 -0
  9. pymc_extras-0.5.0/pymc_extras/statespace/models/DFM.py +849 -0
  10. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/statespace/models/SARIMAX.py +4 -4
  11. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/statespace/models/VARMAX.py +7 -7
  12. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/statespace/utils/constants.py +3 -1
  13. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/inference/laplace_approx/test_find_map.py +6 -2
  14. pymc_extras-0.5.0/tests/statespace/models/test_DFM.py +727 -0
  15. pymc_extras-0.5.0/tests/utils.py +0 -0
  16. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/.gitignore +0 -0
  17. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/.gitpod.yml +0 -0
  18. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/.pre-commit-config.yaml +0 -0
  19. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/.readthedocs.yaml +0 -0
  20. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/CODE_OF_CONDUCT.md +0 -0
  21. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/CONTRIBUTING.md +0 -0
  22. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/LICENSE +0 -0
  23. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/README.md +0 -0
  24. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/codecov.yml +0 -0
  25. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/conda-envs/environment-test.yml +0 -0
  26. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/docs/.nojekyll +0 -0
  27. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/docs/Makefile +0 -0
  28. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/docs/_templates/autosummary/base.rst +0 -0
  29. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/docs/_templates/autosummary/class.rst +0 -0
  30. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/docs/api_reference.rst +0 -0
  31. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/docs/conf.py +0 -0
  32. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/docs/index.rst +0 -0
  33. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/docs/make.bat +0 -0
  34. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/docs/statespace/core.rst +0 -0
  35. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/docs/statespace/filters.rst +0 -0
  36. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/docs/statespace/models/structural.rst +0 -0
  37. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/docs/statespace/models.rst +0 -0
  38. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/__init__.py +0 -0
  39. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/deserialize.py +0 -0
  40. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/distributions/__init__.py +0 -0
  41. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/distributions/continuous.py +0 -0
  42. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/distributions/discrete.py +0 -0
  43. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/distributions/histogram_utils.py +0 -0
  44. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/distributions/multivariate/__init__.py +0 -0
  45. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/distributions/multivariate/r2d2m2cp.py +0 -0
  46. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/distributions/timeseries.py +0 -0
  47. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/distributions/transforms/__init__.py +0 -0
  48. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/distributions/transforms/partial_order.py +0 -0
  49. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/gp/__init__.py +0 -0
  50. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/gp/latent_approx.py +0 -0
  51. {pymc_extras-0.4.1/pymc_extras/inference/laplace_approx → pymc_extras-0.5.0/pymc_extras/inference/dadvi}/__init__.py +0 -0
  52. {pymc_extras-0.4.1/pymc_extras/model → pymc_extras-0.5.0/pymc_extras/inference/laplace_approx}/__init__.py +0 -0
  53. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/inference/laplace_approx/scipy_interface.py +0 -0
  54. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/inference/pathfinder/__init__.py +0 -0
  55. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/inference/pathfinder/importance_sampling.py +0 -0
  56. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/inference/pathfinder/lbfgs.py +0 -0
  57. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/inference/pathfinder/pathfinder.py +0 -0
  58. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/inference/smc/__init__.py +0 -0
  59. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/inference/smc/sampling.py +0 -0
  60. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/linearmodel.py +0 -0
  61. {pymc_extras-0.4.1/pymc_extras/model/marginal → pymc_extras-0.5.0/pymc_extras/model}/__init__.py +0 -0
  62. {pymc_extras-0.4.1/pymc_extras/model/transforms → pymc_extras-0.5.0/pymc_extras/model/marginal}/__init__.py +0 -0
  63. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/model/marginal/distributions.py +0 -0
  64. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/model/marginal/graph_analysis.py +0 -0
  65. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/model/marginal/marginal_model.py +0 -0
  66. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/model/model_api.py +0 -0
  67. {pymc_extras-0.4.1/pymc_extras/preprocessing → pymc_extras-0.5.0/pymc_extras/model/transforms}/__init__.py +0 -0
  68. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/model/transforms/autoreparam.py +0 -0
  69. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/model_builder.py +0 -0
  70. {pymc_extras-0.4.1/pymc_extras/statespace/models/structural/components → pymc_extras-0.5.0/pymc_extras/preprocessing}/__init__.py +0 -0
  71. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/preprocessing/standard_scaler.py +0 -0
  72. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/printing.py +0 -0
  73. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/prior.py +0 -0
  74. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/statespace/__init__.py +0 -0
  75. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/statespace/core/__init__.py +0 -0
  76. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/statespace/core/compile.py +0 -0
  77. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/statespace/core/representation.py +0 -0
  78. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/statespace/core/statespace.py +0 -0
  79. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/statespace/filters/__init__.py +0 -0
  80. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/statespace/filters/distributions.py +0 -0
  81. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/statespace/filters/kalman_filter.py +0 -0
  82. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/statespace/filters/kalman_smoother.py +0 -0
  83. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/statespace/filters/utilities.py +0 -0
  84. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/statespace/models/ETS.py +0 -0
  85. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/statespace/models/__init__.py +0 -0
  86. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/statespace/models/structural/__init__.py +0 -0
  87. {pymc_extras-0.4.1/pymc_extras/statespace/utils → pymc_extras-0.5.0/pymc_extras/statespace/models/structural/components}/__init__.py +0 -0
  88. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/statespace/models/structural/components/autoregressive.py +0 -0
  89. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/statespace/models/structural/components/cycle.py +0 -0
  90. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/statespace/models/structural/components/level_trend.py +0 -0
  91. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/statespace/models/structural/components/measurement_error.py +0 -0
  92. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/statespace/models/structural/components/regression.py +0 -0
  93. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/statespace/models/structural/components/seasonality.py +0 -0
  94. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/statespace/models/structural/core.py +0 -0
  95. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/statespace/models/structural/utils.py +0 -0
  96. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/statespace/models/utilities.py +0 -0
  97. {pymc_extras-0.4.1/tests/inference → pymc_extras-0.5.0/pymc_extras/statespace/utils}/__init__.py +0 -0
  98. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/statespace/utils/coord_tools.py +0 -0
  99. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/statespace/utils/data_tools.py +0 -0
  100. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/utils/__init__.py +0 -0
  101. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/utils/linear_cg.py +0 -0
  102. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/utils/model_equivalence.py +0 -0
  103. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/utils/prior.py +0 -0
  104. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pymc_extras/utils/spline.py +0 -0
  105. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/pyproject.toml +0 -0
  106. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/__init__.py +0 -0
  107. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/conftest.py +0 -0
  108. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/distributions/__init__.py +0 -0
  109. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/distributions/test_continuous.py +0 -0
  110. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/distributions/test_discrete.py +0 -0
  111. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/distributions/test_discrete_markov_chain.py +0 -0
  112. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/distributions/test_multivariate.py +0 -0
  113. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/distributions/test_transform.py +0 -0
  114. {pymc_extras-0.4.1/tests/inference/laplace_approx → pymc_extras-0.5.0/tests/inference}/__init__.py +0 -0
  115. {pymc_extras-0.4.1/tests/model → pymc_extras-0.5.0/tests/inference/laplace_approx}/__init__.py +0 -0
  116. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/inference/laplace_approx/test_idata.py +0 -0
  117. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/inference/laplace_approx/test_laplace.py +0 -0
  118. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/inference/laplace_approx/test_scipy_interface.py +0 -0
  119. {pymc_extras-0.4.1/tests/model/marginal → pymc_extras-0.5.0/tests/model}/__init__.py +0 -0
  120. {pymc_extras-0.4.1/tests/statespace → pymc_extras-0.5.0/tests/model/marginal}/__init__.py +0 -0
  121. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/model/marginal/test_distributions.py +0 -0
  122. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/model/marginal/test_graph_analysis.py +0 -0
  123. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/model/marginal/test_marginal_model.py +0 -0
  124. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/model/test_model_api.py +0 -0
  125. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/model/transforms/test_autoreparam.py +0 -0
  126. {pymc_extras-0.4.1/tests/statespace/core → pymc_extras-0.5.0/tests/statespace}/__init__.py +0 -0
  127. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/statespace/_data/airpass.csv +0 -0
  128. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/statespace/_data/airpassangers.csv +0 -0
  129. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/statespace/_data/nile.csv +0 -0
  130. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/statespace/_data/statsmodels_macrodata_processed.csv +0 -0
  131. {pymc_extras-0.4.1/tests/statespace/filters → pymc_extras-0.5.0/tests/statespace/core}/__init__.py +0 -0
  132. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/statespace/core/test_representation.py +0 -0
  133. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/statespace/core/test_statespace.py +0 -0
  134. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/statespace/core/test_statespace_JAX.py +0 -0
  135. {pymc_extras-0.4.1/tests/statespace/models → pymc_extras-0.5.0/tests/statespace/filters}/__init__.py +0 -0
  136. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/statespace/filters/test_distributions.py +0 -0
  137. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/statespace/filters/test_kalman_filter.py +0 -0
  138. {pymc_extras-0.4.1/tests/statespace/models/structural → pymc_extras-0.5.0/tests/statespace/models}/__init__.py +0 -0
  139. {pymc_extras-0.4.1/tests/statespace/models/structural/components → pymc_extras-0.5.0/tests/statespace/models/structural}/__init__.py +0 -0
  140. {pymc_extras-0.4.1/tests/statespace/utils → pymc_extras-0.5.0/tests/statespace/models/structural/components}/__init__.py +0 -0
  141. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/statespace/models/structural/components/test_autoregressive.py +0 -0
  142. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/statespace/models/structural/components/test_cycle.py +0 -0
  143. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/statespace/models/structural/components/test_level_trend.py +0 -0
  144. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/statespace/models/structural/components/test_measurement_error.py +0 -0
  145. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/statespace/models/structural/components/test_regression.py +0 -0
  146. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/statespace/models/structural/components/test_seasonality.py +0 -0
  147. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/statespace/models/structural/conftest.py +0 -0
  148. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/statespace/models/structural/test_against_statsmodels.py +0 -0
  149. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/statespace/models/structural/test_core.py +0 -0
  150. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/statespace/models/test_ETS.py +0 -0
  151. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/statespace/models/test_SARIMAX.py +0 -0
  152. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/statespace/models/test_VARMAX.py +0 -0
  153. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/statespace/models/test_utilities.py +0 -0
  154. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/statespace/shared_fixtures.py +0 -0
  155. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/statespace/statsmodel_local_level.py +0 -0
  156. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/statespace/test_utilities.py +0 -0
  157. /pymc_extras-0.4.1/tests/utils.py → /pymc_extras-0.5.0/tests/statespace/utils/__init__.py +0 -0
  158. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/statespace/utils/test_coord_assignment.py +0 -0
  159. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/test_blackjax_smc.py +0 -0
  160. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/test_deserialize.py +0 -0
  161. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/test_histogram_approximation.py +0 -0
  162. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/test_linearmodel.py +0 -0
  163. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/test_model_builder.py +0 -0
  164. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/test_pathfinder.py +0 -0
  165. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/test_printing.py +0 -0
  166. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/test_prior.py +0 -0
  167. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/test_prior_from_trace.py +0 -0
  168. {pymc_extras-0.4.1 → pymc_extras-0.5.0}/tests/test_splines.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pymc-extras
3
- Version: 0.4.1
3
+ Version: 0.5.0
4
4
  Summary: A home for new additions to PyMC, which may include unusual probability distributions, advanced model fitting algorithms, or any code that may be inappropriate to include in the pymc repository, but may want to be made available to users.
5
5
  Project-URL: Documentation, https://pymc-extras.readthedocs.io/
6
6
  Project-URL: Repository, https://github.com/pymc-devs/pymc-extras.git
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.4.1'
32
- __version_tuple__ = version_tuple = (0, 4, 1)
31
+ __version__ = version = '0.5.0'
32
+ __version_tuple__ = version_tuple = (0, 5, 0)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -12,9 +12,16 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+ from pymc_extras.inference.dadvi.dadvi import fit_dadvi
15
16
  from pymc_extras.inference.fit import fit
16
17
  from pymc_extras.inference.laplace_approx.find_map import find_MAP
17
18
  from pymc_extras.inference.laplace_approx.laplace import fit_laplace
18
19
  from pymc_extras.inference.pathfinder.pathfinder import fit_pathfinder
19
20
 
20
- __all__ = ["find_MAP", "fit", "fit_laplace", "fit_pathfinder"]
21
+ __all__ = [
22
+ "find_MAP",
23
+ "fit",
24
+ "fit_laplace",
25
+ "fit_pathfinder",
26
+ "fit_dadvi",
27
+ ]
@@ -0,0 +1,261 @@
1
+ import arviz as az
2
+ import numpy as np
3
+ import pymc
4
+ import pytensor
5
+ import pytensor.tensor as pt
6
+ import xarray
7
+
8
+ from better_optimize import minimize
9
+ from better_optimize.constants import minimize_method
10
+ from pymc import DictToArrayBijection, Model, join_nonshared_inputs
11
+ from pymc.backends.arviz import (
12
+ PointFunc,
13
+ apply_function_over_dataset,
14
+ coords_and_dims_for_inferencedata,
15
+ )
16
+ from pymc.util import RandomSeed, get_default_varnames
17
+ from pytensor.tensor.variable import TensorVariable
18
+
19
+ from pymc_extras.inference.laplace_approx.laplace import unstack_laplace_draws
20
+ from pymc_extras.inference.laplace_approx.scipy_interface import (
21
+ _compile_functions_for_scipy_optimize,
22
+ )
23
+
24
+
25
+ def fit_dadvi(
26
+ model: Model | None = None,
27
+ n_fixed_draws: int = 30,
28
+ random_seed: RandomSeed = None,
29
+ n_draws: int = 1000,
30
+ keep_untransformed: bool = False,
31
+ optimizer_method: minimize_method = "trust-ncg",
32
+ use_grad: bool = True,
33
+ use_hessp: bool = True,
34
+ use_hess: bool = False,
35
+ **minimize_kwargs,
36
+ ) -> az.InferenceData:
37
+ """
38
+ Does inference using deterministic ADVI (automatic differentiation
39
+ variational inference), DADVI for short.
40
+
41
+ For full details see the paper cited in the references:
42
+ https://www.jmlr.org/papers/v25/23-1015.html
43
+
44
+ Parameters
45
+ ----------
46
+ model : pm.Model
47
+ The PyMC model to be fit. If None, the current model context is used.
48
+
49
+ n_fixed_draws : int
50
+ The number of fixed draws to use for the optimisation. More
51
+ draws will result in more accurate estimates, but also
52
+ increase inference time. Usually, the default of 30 is a good
53
+ tradeoff between speed and accuracy.
54
+
55
+ random_seed: int
56
+ The random seed to use for the fixed draws. Running the optimisation
57
+ twice with the same seed should arrive at the same result.
58
+
59
+ n_draws: int
60
+ The number of draws to return from the variational approximation.
61
+
62
+ keep_untransformed: bool
63
+ Whether or not to keep the unconstrained variables (such as
64
+ logs of positive-constrained parameters) in the output.
65
+
66
+ optimizer_method: str
67
+ Which optimization method to use. The function calls
68
+ ``scipy.optimize.minimize``, so any of the methods there can
69
+ be used. The default is trust-ncg, which uses second-order
70
+ information and is generally very reliable. Other methods such
71
+ as L-BFGS-B might be faster but potentially more brittle and
72
+ may not converge exactly to the optimum.
73
+
74
+ minimize_kwargs:
75
+ Additional keyword arguments to pass to the
76
+ ``scipy.optimize.minimize`` function. See the documentation of
77
+ that function for details.
78
+
79
+ use_grad:
80
+ If True, pass the gradient function to
81
+ `scipy.optimize.minimize` (where it is referred to as `jac`).
82
+
83
+ use_hessp:
84
+ If True, pass the hessian vector product to `scipy.optimize.minimize`.
85
+
86
+ use_hess:
87
+ If True, pass the hessian to `scipy.optimize.minimize`. Note that
88
+ this is generally not recommended since its computation can be slow
89
+ and memory-intensive if there are many parameters.
90
+
91
+ Returns
92
+ -------
93
+ :class:`~arviz.InferenceData`
94
+ The inference data containing the results of the DADVI algorithm.
95
+
96
+ References
97
+ ----------
98
+ Giordano, R., Ingram, M., & Broderick, T. (2024). Black Box
99
+ Variational Inference with a Deterministic Objective: Faster, More
100
+ Accurate, and Even More Black Box. Journal of Machine Learning
101
+ Research, 25(18), 1–39.
102
+ """
103
+
104
+ model = pymc.modelcontext(model) if model is None else model
105
+
106
+ initial_point_dict = model.initial_point()
107
+ n_params = DictToArrayBijection.map(initial_point_dict).data.shape[0]
108
+
109
+ var_params, objective = create_dadvi_graph(
110
+ model,
111
+ n_fixed_draws=n_fixed_draws,
112
+ random_seed=random_seed,
113
+ n_params=n_params,
114
+ )
115
+
116
+ f_fused, f_hessp = _compile_functions_for_scipy_optimize(
117
+ objective,
118
+ [var_params],
119
+ compute_grad=use_grad,
120
+ compute_hessp=use_hessp,
121
+ compute_hess=use_hess,
122
+ )
123
+
124
+ derivative_kwargs = {}
125
+
126
+ if use_grad:
127
+ derivative_kwargs["jac"] = True
128
+ if use_hessp:
129
+ derivative_kwargs["hessp"] = f_hessp
130
+ if use_hess:
131
+ derivative_kwargs["hess"] = True
132
+
133
+ result = minimize(
134
+ f_fused,
135
+ np.zeros(2 * n_params),
136
+ method=optimizer_method,
137
+ **derivative_kwargs,
138
+ **minimize_kwargs,
139
+ )
140
+
141
+ opt_var_params = result.x
142
+ opt_means, opt_log_sds = np.split(opt_var_params, 2)
143
+
144
+ # Make the draws:
145
+ generator = np.random.default_rng(seed=random_seed)
146
+ draws_raw = generator.standard_normal(size=(n_draws, n_params))
147
+
148
+ draws = opt_means + draws_raw * np.exp(opt_log_sds)
149
+ draws_arviz = unstack_laplace_draws(draws, model, chains=1, draws=n_draws)
150
+
151
+ transformed_draws = transform_draws(draws_arviz, model, keep_untransformed=keep_untransformed)
152
+
153
+ return transformed_draws
154
+
155
+
156
+ def create_dadvi_graph(
157
+ model: Model,
158
+ n_params: int,
159
+ n_fixed_draws: int = 30,
160
+ random_seed: RandomSeed = None,
161
+ ) -> tuple[TensorVariable, TensorVariable]:
162
+ """
163
+ Sets up the DADVI graph in pytensor and returns it.
164
+
165
+ Parameters
166
+ ----------
167
+ model : pm.Model
168
+ The PyMC model to be fit.
169
+
170
+ n_params: int
171
+ The total number of parameters in the model.
172
+
173
+ n_fixed_draws : int
174
+ The number of fixed draws to use.
175
+
176
+ random_seed: int
177
+ The random seed to use for the fixed draws.
178
+
179
+ Returns
180
+ -------
181
+ Tuple[TensorVariable, TensorVariable]
182
+ A tuple whose first element contains the variational parameters,
183
+ and whose second contains the DADVI objective.
184
+ """
185
+
186
+ # Make the fixed draws
187
+ generator = np.random.default_rng(seed=random_seed)
188
+ draws = generator.standard_normal(size=(n_fixed_draws, n_params))
189
+
190
+ inputs = model.continuous_value_vars + model.discrete_value_vars
191
+ initial_point_dict = model.initial_point()
192
+ logp = model.logp()
193
+
194
+ # Graph in terms of a flat input
195
+ [logp], flat_input = join_nonshared_inputs(
196
+ point=initial_point_dict, outputs=[logp], inputs=inputs
197
+ )
198
+
199
+ var_params = pt.vector(name="eta", shape=(2 * n_params,))
200
+
201
+ means, log_sds = pt.split(var_params, axis=0, splits_size=[n_params, n_params], n_splits=2)
202
+
203
+ draw_matrix = pt.constant(draws)
204
+ samples = means + pt.exp(log_sds) * draw_matrix
205
+
206
+ logp_vectorized_draws = pytensor.graph.vectorize_graph(logp, replace={flat_input: samples})
207
+
208
+ mean_log_density = pt.mean(logp_vectorized_draws)
209
+ entropy = pt.sum(log_sds)
210
+
211
+ objective = -mean_log_density - entropy
212
+
213
+ return var_params, objective
214
+
215
+
216
+ def transform_draws(
217
+ unstacked_draws: xarray.Dataset,
218
+ model: Model,
219
+ keep_untransformed: bool = False,
220
+ ):
221
+ """
222
+ Transforms the unconstrained draws back into the constrained space.
223
+
224
+ Parameters
225
+ ----------
226
+ unstacked_draws : xarray.Dataset
227
+ The draws to constrain back into the original space.
228
+
229
+ model : Model
230
+ The PyMC model the variables were derived from.
231
+
232
+ n_draws: int
233
+ The number of draws to return from the variational approximation.
234
+
235
+ keep_untransformed: bool
236
+ Whether or not to keep the unconstrained variables in the output.
237
+
238
+ Returns
239
+ -------
240
+ :class:`~arviz.InferenceData`
241
+ Draws from the original constrained parameters.
242
+ """
243
+
244
+ filtered_var_names = model.unobserved_value_vars
245
+ vars_to_sample = list(
246
+ get_default_varnames(filtered_var_names, include_transformed=keep_untransformed)
247
+ )
248
+ fn = pytensor.function(model.value_vars, vars_to_sample)
249
+ point_func = PointFunc(fn)
250
+
251
+ coords, dims = coords_and_dims_for_inferencedata(model)
252
+
253
+ transformed_result = apply_function_over_dataset(
254
+ point_func,
255
+ unstacked_draws,
256
+ output_var_names=[x.name for x in vars_to_sample],
257
+ coords=coords,
258
+ dims=dims,
259
+ )
260
+
261
+ return transformed_result
@@ -40,3 +40,8 @@ def fit(method: str, **kwargs) -> az.InferenceData:
40
40
  from pymc_extras.inference import fit_laplace
41
41
 
42
42
  return fit_laplace(**kwargs)
43
+
44
+ if method == "dadvi":
45
+ from pymc_extras.inference import fit_dadvi
46
+
47
+ return fit_dadvi(**kwargs)
@@ -198,6 +198,7 @@ def find_MAP(
198
198
  include_transformed: bool = True,
199
199
  gradient_backend: GradientBackend = "pytensor",
200
200
  compile_kwargs: dict | None = None,
201
+ compute_hessian: bool = False,
201
202
  **optimizer_kwargs,
202
203
  ) -> (
203
204
  dict[str, np.ndarray]
@@ -239,6 +240,10 @@ def find_MAP(
239
240
  Whether to include transformed variable values in the returned dictionary. Defaults to True.
240
241
  gradient_backend: str, default "pytensor"
241
242
  Which backend to use to compute gradients. Must be one of "pytensor" or "jax".
243
+ compute_hessian: bool
244
+ If True, the inverse Hessian matrix at the optimum will be computed and included in the returned
245
+ InferenceData object. This is needed for the Laplace approximation, but can be computationally expensive for
246
+ high-dimensional problems. Defaults to False.
242
247
  compile_kwargs: dict, optional
243
248
  Additional options to pass to the ``pytensor.function`` function when compiling loss functions.
244
249
  **optimizer_kwargs
@@ -316,14 +321,17 @@ def find_MAP(
316
321
  **optimizer_kwargs,
317
322
  )
318
323
 
319
- H_inv = _compute_inverse_hessian(
320
- optimizer_result=optimizer_result,
321
- optimal_point=None,
322
- f_fused=f_fused,
323
- f_hessp=f_hessp,
324
- use_hess=use_hess,
325
- method=method,
326
- )
324
+ if compute_hessian:
325
+ H_inv = _compute_inverse_hessian(
326
+ optimizer_result=optimizer_result,
327
+ optimal_point=None,
328
+ f_fused=f_fused,
329
+ f_hessp=f_hessp,
330
+ use_hess=use_hess,
331
+ method=method,
332
+ )
333
+ else:
334
+ H_inv = None
327
335
 
328
336
  raveled_optimized = RaveledVars(optimizer_result.x, initial_params.point_map_info)
329
337
  unobserved_vars = get_default_varnames(model.unobserved_value_vars, include_transformed=True)
@@ -136,7 +136,10 @@ def map_results_to_inference_data(
136
136
 
137
137
 
138
138
  def add_fit_to_inference_data(
139
- idata: az.InferenceData, mu: RaveledVars, H_inv: np.ndarray, model: pm.Model | None = None
139
+ idata: az.InferenceData,
140
+ mu: RaveledVars,
141
+ H_inv: np.ndarray | None,
142
+ model: pm.Model | None = None,
140
143
  ) -> az.InferenceData:
141
144
  """
142
145
  Add the mean vector and covariance matrix of the Laplace approximation to an InferenceData object.
@@ -147,7 +150,7 @@ def add_fit_to_inference_data(
147
150
  An InferenceData object containing the approximated posterior samples.
148
151
  mu: RaveledVars
149
152
  The MAP estimate of the model parameters.
150
- H_inv: np.ndarray
153
+ H_inv: np.ndarray, optional
151
154
  The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
152
155
  model: Model, optional
153
156
  A PyMC model. If None, the model is taken from the current model context.
@@ -389,6 +389,7 @@ def fit_laplace(
389
389
  include_transformed=include_transformed,
390
390
  gradient_backend=gradient_backend,
391
391
  compile_kwargs=compile_kwargs,
392
+ compute_hessian=True,
392
393
  **optimizer_kwargs,
393
394
  )
394
395