arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. arviz/__init__.py +52 -367
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.3.dist-info/METADATA +0 -264
  184. arviz-0.23.3.dist-info/RECORD +0 -183
  185. arviz-0.23.3.dist-info/top_level.txt +0 -1
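Taken together, the deleted modules above are the whole of the legacy package: the `arviz.data` converters and example datasets, both plotting backends under `arviz.plots`, and the `arviz.stats` diagnostics and summaries. As a rough, illustrative sketch only (not taken from this diff), a typical 0.23.x session exercising those modules looked like this:

```python
import arviz as az  # arviz 0.23.x

# datasets.py backed load_arviz_data, plots/traceplot.py backed plot_trace,
# stats/stats.py backed summary; all of them are removed in 1.0.0rc0 above.
idata = az.load_arviz_data("centered_eight")   # ships in example_data/ above
az.plot_trace(idata, var_names=["mu", "tau"])  # matplotlib backend by default
print(az.summary(idata, var_names=["mu", "tau"]))
```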
arviz/data/io_pystan.py DELETED
@@ -1,1095 +0,0 @@
- # pylint: disable=too-many-instance-attributes,too-many-lines
- """PyStan-specific conversion code."""
- import re
- from collections import OrderedDict
- from copy import deepcopy
- from math import ceil
-
- import numpy as np
- import xarray as xr
-
- from .. import _log
- from ..rcparams import rcParams
- from .base import dict_to_dataset, generate_dims_coords, infer_stan_dtypes, make_attrs, requires
- from .inference_data import InferenceData
-
- try:
-     import ujson as json
- except ImportError:
-     # Can't find ujson using json
-     # mypy struggles with conditional imports expressed as catching ImportError:
-     # https://github.com/python/mypy/issues/1153
-     import json  # type: ignore
-
-
- class PyStanConverter:
-     """Encapsulate PyStan specific logic."""
-
-     def __init__(
-         self,
-         *,
-         posterior=None,
-         posterior_predictive=None,
-         predictions=None,
-         prior=None,
-         prior_predictive=None,
-         observed_data=None,
-         constant_data=None,
-         predictions_constant_data=None,
-         log_likelihood=None,
-         coords=None,
-         dims=None,
-         save_warmup=None,
-         dtypes=None,
-     ):
-         self.posterior = posterior
-         self.posterior_predictive = posterior_predictive
-         self.predictions = predictions
-         self.prior = prior
-         self.prior_predictive = prior_predictive
-         self.observed_data = observed_data
-         self.constant_data = constant_data
-         self.predictions_constant_data = predictions_constant_data
-         self.log_likelihood = (
-             rcParams["data.log_likelihood"] if log_likelihood is None else log_likelihood
-         )
-         self.coords = coords
-         self.dims = dims
-         self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup
-         self.dtypes = dtypes
-
-         if (
-             self.log_likelihood is True
-             and self.posterior is not None
-             and "log_lik" in self.posterior.sim["pars_oi"]
-         ):
-             self.log_likelihood = ["log_lik"]
-         elif isinstance(self.log_likelihood, bool):
-             self.log_likelihood = None
-
-         import pystan  # pylint: disable=import-error
-
-         self.pystan = pystan
-
-     @requires("posterior")
-     def posterior_to_xarray(self):
-         """Extract posterior samples from fit."""
-         posterior = self.posterior
-         # filter posterior_predictive and log_likelihood
-         posterior_predictive = self.posterior_predictive
-         if posterior_predictive is None:
-             posterior_predictive = []
-         elif isinstance(posterior_predictive, str):
-             posterior_predictive = [posterior_predictive]
-         predictions = self.predictions
-         if predictions is None:
-             predictions = []
-         elif isinstance(predictions, str):
-             predictions = [predictions]
-         log_likelihood = self.log_likelihood
-         if log_likelihood is None:
-             log_likelihood = []
-         elif isinstance(log_likelihood, str):
-             log_likelihood = [log_likelihood]
-         elif isinstance(log_likelihood, dict):
-             log_likelihood = list(log_likelihood.values())
-
-         ignore = posterior_predictive + predictions + log_likelihood + ["lp__"]
-
-         data, data_warmup = get_draws(
-             posterior, ignore=ignore, warmup=self.save_warmup, dtypes=self.dtypes
-         )
-         attrs = get_attrs(posterior)
-         return (
-             dict_to_dataset(
-                 data, library=self.pystan, attrs=attrs, coords=self.coords, dims=self.dims
-             ),
-             dict_to_dataset(
-                 data_warmup, library=self.pystan, attrs=attrs, coords=self.coords, dims=self.dims
-             ),
-         )
-
-     @requires("posterior")
-     def sample_stats_to_xarray(self):
-         """Extract sample_stats from posterior."""
-         posterior = self.posterior
-
-         data, data_warmup = get_sample_stats(posterior, warmup=self.save_warmup)
-
-         # lp__
-         stat_lp, stat_lp_warmup = get_draws(
-             posterior, variables="lp__", warmup=self.save_warmup, dtypes=self.dtypes
-         )
-         data["lp"] = stat_lp["lp__"]
-         if stat_lp_warmup:
-             data_warmup["lp"] = stat_lp_warmup["lp__"]
-
-         attrs = get_attrs(posterior)
-         return (
-             dict_to_dataset(
-                 data, library=self.pystan, attrs=attrs, coords=self.coords, dims=self.dims
-             ),
-             dict_to_dataset(
-                 data_warmup, library=self.pystan, attrs=attrs, coords=self.coords, dims=self.dims
-             ),
-         )
-
-     @requires("posterior")
-     @requires("log_likelihood")
-     def log_likelihood_to_xarray(self):
-         """Store log_likelihood data in log_likelihood group."""
-         fit = self.posterior
-
-         # log_likelihood values
-         log_likelihood = self.log_likelihood
-         if isinstance(log_likelihood, str):
-             log_likelihood = [log_likelihood]
-         if isinstance(log_likelihood, (list, tuple)):
-             log_likelihood = {name: name for name in log_likelihood}
-         log_likelihood_draws, log_likelihood_draws_warmup = get_draws(
-             fit,
-             variables=list(log_likelihood.values()),
-             warmup=self.save_warmup,
-             dtypes=self.dtypes,
-         )
-         data = {
-             obs_var_name: log_likelihood_draws[log_like_name]
-             for obs_var_name, log_like_name in log_likelihood.items()
-             if log_like_name in log_likelihood_draws
-         }
-
-         data_warmup = {
-             obs_var_name: log_likelihood_draws_warmup[log_like_name]
-             for obs_var_name, log_like_name in log_likelihood.items()
-             if log_like_name in log_likelihood_draws_warmup
-         }
-
-         return (
-             dict_to_dataset(
-                 data, library=self.pystan, coords=self.coords, dims=self.dims, skip_event_dims=True
-             ),
-             dict_to_dataset(
-                 data_warmup,
-                 library=self.pystan,
-                 coords=self.coords,
-                 dims=self.dims,
-                 skip_event_dims=True,
-             ),
-         )
-
-     @requires("posterior")
-     @requires("posterior_predictive")
-     def posterior_predictive_to_xarray(self):
-         """Convert posterior_predictive samples to xarray."""
-         posterior = self.posterior
-         posterior_predictive = self.posterior_predictive
-         data, data_warmup = get_draws(
-             posterior, variables=posterior_predictive, warmup=self.save_warmup, dtypes=self.dtypes
-         )
-         return (
-             dict_to_dataset(data, library=self.pystan, coords=self.coords, dims=self.dims),
-             dict_to_dataset(data_warmup, library=self.pystan, coords=self.coords, dims=self.dims),
-         )
-
-     @requires("posterior")
-     @requires("predictions")
-     def predictions_to_xarray(self):
-         """Convert predictions samples to xarray."""
-         posterior = self.posterior
-         predictions = self.predictions
-         data, data_warmup = get_draws(
-             posterior, variables=predictions, warmup=self.save_warmup, dtypes=self.dtypes
-         )
-         return (
-             dict_to_dataset(data, library=self.pystan, coords=self.coords, dims=self.dims),
-             dict_to_dataset(data_warmup, library=self.pystan, coords=self.coords, dims=self.dims),
-         )
-
-     @requires("prior")
-     def prior_to_xarray(self):
-         """Convert prior samples to xarray."""
-         prior = self.prior
-         # filter posterior_predictive and log_likelihood
-         prior_predictive = self.prior_predictive
-         if prior_predictive is None:
-             prior_predictive = []
-         elif isinstance(prior_predictive, str):
-             prior_predictive = [prior_predictive]
-
-         ignore = prior_predictive + ["lp__"]
-
-         data, _ = get_draws(prior, ignore=ignore, warmup=False, dtypes=self.dtypes)
-         attrs = get_attrs(prior)
-         return dict_to_dataset(
-             data, library=self.pystan, attrs=attrs, coords=self.coords, dims=self.dims
-         )
-
-     @requires("prior")
-     def sample_stats_prior_to_xarray(self):
-         """Extract sample_stats_prior from prior."""
-         prior = self.prior
-         data, _ = get_sample_stats(prior, warmup=False)
-
-         # lp__
-         stat_lp, _ = get_draws(prior, variables="lp__", warmup=False, dtypes=self.dtypes)
-         data["lp"] = stat_lp["lp__"]
-
-         attrs = get_attrs(prior)
-         return dict_to_dataset(
-             data, library=self.pystan, attrs=attrs, coords=self.coords, dims=self.dims
-         )
-
-     @requires("prior")
-     @requires("prior_predictive")
-     def prior_predictive_to_xarray(self):
-         """Convert prior_predictive samples to xarray."""
-         prior = self.prior
-         prior_predictive = self.prior_predictive
-         data, _ = get_draws(prior, variables=prior_predictive, warmup=False, dtypes=self.dtypes)
-         return dict_to_dataset(data, library=self.pystan, coords=self.coords, dims=self.dims)
-
-     @requires("posterior")
-     @requires(["observed_data", "constant_data", "predictions_constant_data"])
-     def data_to_xarray(self):
-         """Convert observed, constant data and predictions constant data to xarray."""
-         posterior = self.posterior
-         dims = {} if self.dims is None else self.dims
-         obs_const_dict = {}
-         for group_name in ("observed_data", "constant_data", "predictions_constant_data"):
-             names = getattr(self, group_name)
-             if names is None:
-                 continue
-             names = [names] if isinstance(names, str) else names
-             data = OrderedDict()
-             for key in names:
-                 vals = np.atleast_1d(posterior.data[key])
-                 val_dims = dims.get(key)
-                 val_dims, coords = generate_dims_coords(
-                     vals.shape, key, dims=val_dims, coords=self.coords
-                 )
-                 data[key] = xr.DataArray(vals, dims=val_dims, coords=coords)
-             obs_const_dict[group_name] = xr.Dataset(
-                 data_vars=data, attrs=make_attrs(library=self.pystan)
-             )
-         return obs_const_dict
-
-     def to_inference_data(self):
-         """Convert all available data to an InferenceData object.
-
-         Note that if groups can not be created (i.e., there is no `fit`, so
-         the `posterior` and `sample_stats` can not be extracted), then the InferenceData
-         will not have those groups.
-         """
-         data_dict = self.data_to_xarray()
-         return InferenceData(
-             save_warmup=self.save_warmup,
-             **{
-                 "posterior": self.posterior_to_xarray(),
-                 "sample_stats": self.sample_stats_to_xarray(),
-                 "log_likelihood": self.log_likelihood_to_xarray(),
-                 "posterior_predictive": self.posterior_predictive_to_xarray(),
-                 "predictions": self.predictions_to_xarray(),
-                 "prior": self.prior_to_xarray(),
-                 "sample_stats_prior": self.sample_stats_prior_to_xarray(),
-                 "prior_predictive": self.prior_predictive_to_xarray(),
-                 **({} if data_dict is None else data_dict),
-             },
-         )
-
-
- class PyStan3Converter:
-     """Encapsulate PyStan3 specific logic."""
-
-     # pylint: disable=too-many-instance-attributes
-     def __init__(
-         self,
-         *,
-         posterior=None,
-         posterior_model=None,
-         posterior_predictive=None,
-         predictions=None,
-         prior=None,
-         prior_model=None,
-         prior_predictive=None,
-         observed_data=None,
-         constant_data=None,
-         predictions_constant_data=None,
-         log_likelihood=None,
-         coords=None,
-         dims=None,
-         save_warmup=None,
-         dtypes=None,
-     ):
-         self.posterior = posterior
-         self.posterior_model = posterior_model
-         self.posterior_predictive = posterior_predictive
-         self.predictions = predictions
-         self.prior = prior
-         self.prior_model = prior_model
-         self.prior_predictive = prior_predictive
-         self.observed_data = observed_data
-         self.constant_data = constant_data
-         self.predictions_constant_data = predictions_constant_data
-         self.log_likelihood = (
-             rcParams["data.log_likelihood"] if log_likelihood is None else log_likelihood
-         )
-         self.coords = coords
-         self.dims = dims
-         self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup
-         self.dtypes = dtypes
-
-         if (
-             self.log_likelihood is True
-             and self.posterior is not None
-             and "log_lik" in self.posterior.param_names
-         ):
-             self.log_likelihood = ["log_lik"]
-         elif isinstance(self.log_likelihood, bool):
-             self.log_likelihood = None
-
-         import stan  # pylint: disable=import-error
-
-         self.stan = stan
-
-     @requires("posterior")
-     def posterior_to_xarray(self):
-         """Extract posterior samples from fit."""
-         posterior = self.posterior
-         posterior_model = self.posterior_model
-         # filter posterior_predictive and log_likelihood
-         posterior_predictive = self.posterior_predictive
-         if posterior_predictive is None:
-             posterior_predictive = []
-         elif isinstance(posterior_predictive, str):
-             posterior_predictive = [posterior_predictive]
-         predictions = self.predictions
-         if predictions is None:
-             predictions = []
-         elif isinstance(predictions, str):
-             predictions = [predictions]
-         log_likelihood = self.log_likelihood
-         if log_likelihood is None:
-             log_likelihood = []
-         elif isinstance(log_likelihood, str):
-             log_likelihood = [log_likelihood]
-         elif isinstance(log_likelihood, dict):
-             log_likelihood = list(log_likelihood.values())
-
-         ignore = posterior_predictive + predictions + log_likelihood
-
-         data, data_warmup = get_draws_stan3(
-             posterior,
-             model=posterior_model,
-             ignore=ignore,
-             warmup=self.save_warmup,
-             dtypes=self.dtypes,
-         )
-         attrs = get_attrs_stan3(posterior, model=posterior_model)
-         return (
-             dict_to_dataset(
-                 data, library=self.stan, attrs=attrs, coords=self.coords, dims=self.dims
-             ),
-             dict_to_dataset(
-                 data_warmup, library=self.stan, attrs=attrs, coords=self.coords, dims=self.dims
-             ),
-         )
-
-     @requires("posterior")
-     def sample_stats_to_xarray(self):
-         """Extract sample_stats from posterior."""
-         posterior = self.posterior
-         posterior_model = self.posterior_model
-         data, data_warmup = get_sample_stats_stan3(
-             posterior, ignore="lp__", warmup=self.save_warmup, dtypes=self.dtypes
-         )
-         data_lp, data_warmup_lp = get_sample_stats_stan3(
-             posterior, variables="lp__", warmup=self.save_warmup
-         )
-         data["lp"] = data_lp["lp"]
-         if data_warmup_lp:
-             data_warmup["lp"] = data_warmup_lp["lp"]
-
-         attrs = get_attrs_stan3(posterior, model=posterior_model)
-         return (
-             dict_to_dataset(
-                 data, library=self.stan, attrs=attrs, coords=self.coords, dims=self.dims
-             ),
-             dict_to_dataset(
-                 data_warmup, library=self.stan, attrs=attrs, coords=self.coords, dims=self.dims
-             ),
-         )
-
-     @requires("posterior")
-     @requires("log_likelihood")
-     def log_likelihood_to_xarray(self):
-         """Store log_likelihood data in log_likelihood group."""
-         fit = self.posterior
-
-         log_likelihood = self.log_likelihood
-         model = self.posterior_model
-         if isinstance(log_likelihood, str):
-             log_likelihood = [log_likelihood]
-         if isinstance(log_likelihood, (list, tuple)):
-             log_likelihood = {name: name for name in log_likelihood}
-         log_likelihood_draws, log_likelihood_draws_warmup = get_draws_stan3(
-             fit,
-             model=model,
-             variables=list(log_likelihood.values()),
-             warmup=self.save_warmup,
-             dtypes=self.dtypes,
-         )
-         data = {
-             obs_var_name: log_likelihood_draws[log_like_name]
-             for obs_var_name, log_like_name in log_likelihood.items()
-             if log_like_name in log_likelihood_draws
-         }
-         data_warmup = {
-             obs_var_name: log_likelihood_draws_warmup[log_like_name]
-             for obs_var_name, log_like_name in log_likelihood.items()
-             if log_like_name in log_likelihood_draws_warmup
-         }
-
-         return (
-             dict_to_dataset(data, library=self.stan, coords=self.coords, dims=self.dims),
-             dict_to_dataset(data_warmup, library=self.stan, coords=self.coords, dims=self.dims),
-         )
-
-     @requires("posterior")
-     @requires("posterior_predictive")
-     def posterior_predictive_to_xarray(self):
-         """Convert posterior_predictive samples to xarray."""
-         posterior = self.posterior
-         posterior_model = self.posterior_model
-         posterior_predictive = self.posterior_predictive
-         data, data_warmup = get_draws_stan3(
-             posterior,
-             model=posterior_model,
-             variables=posterior_predictive,
-             warmup=self.save_warmup,
-             dtypes=self.dtypes,
-         )
-         return (
-             dict_to_dataset(data, library=self.stan, coords=self.coords, dims=self.dims),
-             dict_to_dataset(data_warmup, library=self.stan, coords=self.coords, dims=self.dims),
-         )
-
-     @requires("posterior")
-     @requires("predictions")
-     def predictions_to_xarray(self):
-         """Convert predictions samples to xarray."""
-         posterior = self.posterior
-         posterior_model = self.posterior_model
-         predictions = self.predictions
-         data, data_warmup = get_draws_stan3(
-             posterior,
-             model=posterior_model,
-             variables=predictions,
-             warmup=self.save_warmup,
-             dtypes=self.dtypes,
-         )
-         return (
-             dict_to_dataset(data, library=self.stan, coords=self.coords, dims=self.dims),
-             dict_to_dataset(data_warmup, library=self.stan, coords=self.coords, dims=self.dims),
-         )
-
-     @requires("prior")
-     def prior_to_xarray(self):
-         """Convert prior samples to xarray."""
-         prior = self.prior
-         prior_model = self.prior_model
-         # filter posterior_predictive and log_likelihood
-         prior_predictive = self.prior_predictive
-         if prior_predictive is None:
-             prior_predictive = []
-         elif isinstance(prior_predictive, str):
-             prior_predictive = [prior_predictive]
-
-         ignore = prior_predictive
-
-         data, data_warmup = get_draws_stan3(
-             prior, model=prior_model, ignore=ignore, warmup=self.save_warmup, dtypes=self.dtypes
-         )
-         attrs = get_attrs_stan3(prior, model=prior_model)
-         return (
-             dict_to_dataset(
-                 data, library=self.stan, attrs=attrs, coords=self.coords, dims=self.dims
-             ),
-             dict_to_dataset(
-                 data_warmup, library=self.stan, attrs=attrs, coords=self.coords, dims=self.dims
-             ),
-         )
-
-     @requires("prior")
-     def sample_stats_prior_to_xarray(self):
-         """Extract sample_stats_prior from prior."""
-         prior = self.prior
-         prior_model = self.prior_model
-         data, data_warmup = get_sample_stats_stan3(
-             prior, warmup=self.save_warmup, dtypes=self.dtypes
-         )
-         attrs = get_attrs_stan3(prior, model=prior_model)
-         return (
-             dict_to_dataset(
-                 data, library=self.stan, attrs=attrs, coords=self.coords, dims=self.dims
-             ),
-             dict_to_dataset(
-                 data_warmup, library=self.stan, attrs=attrs, coords=self.coords, dims=self.dims
-             ),
-         )
-
-     @requires("prior")
-     @requires("prior_predictive")
-     def prior_predictive_to_xarray(self):
-         """Convert prior_predictive samples to xarray."""
-         prior = self.prior
-         prior_model = self.prior_model
-         prior_predictive = self.prior_predictive
-         data, data_warmup = get_draws_stan3(
-             prior,
-             model=prior_model,
-             variables=prior_predictive,
-             warmup=self.save_warmup,
-             dtypes=self.dtypes,
-         )
-         return (
-             dict_to_dataset(data, library=self.stan, coords=self.coords, dims=self.dims),
-             dict_to_dataset(data_warmup, library=self.stan, coords=self.coords, dims=self.dims),
-         )
-
-     @requires("posterior_model")
-     @requires(["observed_data", "constant_data"])
-     def observed_and_constant_data_to_xarray(self):
-         """Convert observed data to xarray."""
-         posterior_model = self.posterior_model
-         dims = {} if self.dims is None else self.dims
-         obs_const_dict = {}
-         for group_name in ("observed_data", "constant_data"):
-             names = getattr(self, group_name)
-             if names is None:
-                 continue
-             names = [names] if isinstance(names, str) else names
-             data = OrderedDict()
-             for key in names:
-                 vals = np.atleast_1d(posterior_model.data[key])
-                 val_dims = dims.get(key)
-                 val_dims, coords = generate_dims_coords(
-                     vals.shape, key, dims=val_dims, coords=self.coords
-                 )
-                 data[key] = xr.DataArray(vals, dims=val_dims, coords=coords)
-             obs_const_dict[group_name] = xr.Dataset(
-                 data_vars=data, attrs=make_attrs(library=self.stan)
-             )
-         return obs_const_dict
-
-     @requires("posterior_model")
-     @requires("predictions_constant_data")
-     def predictions_constant_data_to_xarray(self):
-         """Convert observed data to xarray."""
-         posterior_model = self.posterior_model
-         dims = {} if self.dims is None else self.dims
-         names = self.predictions_constant_data
-         names = [names] if isinstance(names, str) else names
-         data = OrderedDict()
-         for key in names:
-             vals = np.atleast_1d(posterior_model.data[key])
-             val_dims = dims.get(key)
-             val_dims, coords = generate_dims_coords(
-                 vals.shape, key, dims=val_dims, coords=self.coords
-             )
-             data[key] = xr.DataArray(vals, dims=val_dims, coords=coords)
-         return xr.Dataset(data_vars=data, attrs=make_attrs(library=self.stan))
-
-     def to_inference_data(self):
-         """Convert all available data to an InferenceData object.
-
-         Note that if groups can not be created (i.e., there is no `fit`, so
-         the `posterior` and `sample_stats` can not be extracted), then the InferenceData
-         will not have those groups.
-         """
-         obs_const_dict = self.observed_and_constant_data_to_xarray()
-         predictions_const_data = self.predictions_constant_data_to_xarray()
-         return InferenceData(
-             save_warmup=self.save_warmup,
-             **{
-                 "posterior": self.posterior_to_xarray(),
-                 "sample_stats": self.sample_stats_to_xarray(),
-                 "log_likelihood": self.log_likelihood_to_xarray(),
-                 "posterior_predictive": self.posterior_predictive_to_xarray(),
-                 "predictions": self.predictions_to_xarray(),
-                 "prior": self.prior_to_xarray(),
-                 "sample_stats_prior": self.sample_stats_prior_to_xarray(),
-                 "prior_predictive": self.prior_predictive_to_xarray(),
-                 **({} if obs_const_dict is None else obs_const_dict),
-                 **(
-                     {}
-                     if predictions_const_data is None
-                     else {"predictions_constant_data": predictions_const_data}
-                 ),
-             },
-         )
-
-
- def get_draws(fit, variables=None, ignore=None, warmup=False, dtypes=None):
-     """Extract draws from PyStan fit."""
-     if ignore is None:
-         ignore = []
-     if fit.mode == 1:
-         msg = "Model in mode 'test_grad'. Sampling is not conducted."
-         raise AttributeError(msg)
-
-     if fit.mode == 2 or fit.sim.get("samples") is None:
-         msg = "Fit doesn't contain samples."
-         raise AttributeError(msg)
-
-     if dtypes is None:
-         dtypes = {}
-
-     dtypes = {**infer_dtypes(fit), **dtypes}
-
-     if variables is None:
-         variables = fit.sim["pars_oi"]
-     elif isinstance(variables, str):
-         variables = [variables]
-     variables = list(variables)
-
-     for var, dim in zip(fit.sim["pars_oi"], fit.sim["dims_oi"]):
-         if var in variables and np.prod(dim) == 0:
-             del variables[variables.index(var)]
-
-     ndraws_warmup = fit.sim["warmup2"]
-     if max(ndraws_warmup) == 0:
-         warmup = False
-     ndraws = [s - w for s, w in zip(fit.sim["n_save"], ndraws_warmup)]
-     nchain = len(fit.sim["samples"])
-
-     # check if the values are in 0-based (<=2.17) or 1-based indexing (>=2.18)
-     shift = 1
-     if any(dim and np.prod(dim) != 0 for dim in fit.sim["dims_oi"]):
-         # choose variable with lowest number of dims > 1
-         par_idx = min(
-             (dim, i) for i, dim in enumerate(fit.sim["dims_oi"]) if (dim and np.prod(dim) != 0)
-         )[1]
-         offset = int(sum(map(np.prod, fit.sim["dims_oi"][:par_idx])))
-         par_offset = int(np.prod(fit.sim["dims_oi"][par_idx]))
-         par_keys = fit.sim["fnames_oi"][offset : offset + par_offset]
-         shift = len(par_keys)
-         for item in par_keys:
-             _, shape = item.replace("]", "").split("[")
-             shape_idx_min = min(int(shape_value) for shape_value in shape.split(","))
-             shift = min(shift, shape_idx_min)
-         # If shift is higher than 1, this will probably mean that Stan
-         # has implemented sparse structure (saves only non-zero parts),
-         # but let's hope that dims are still corresponding to the full shape
-         shift = int(min(shift, 1))
-
-     var_keys = OrderedDict((var, []) for var in fit.sim["pars_oi"])
-     for key in fit.sim["fnames_oi"]:
-         var, *tails = key.split("[")
-         loc = [Ellipsis]
-         for tail in tails:
-             loc = []
-             for i in tail[:-1].split(","):
-                 loc.append(int(i) - shift)
-         var_keys[var].append((key, loc))
-
-     shapes = dict(zip(fit.sim["pars_oi"], fit.sim["dims_oi"]))
-
-     variables = [var for var in variables if var not in ignore]
-
-     data = OrderedDict()
-     data_warmup = OrderedDict()
-
-     for var in variables:
-         if var in data:
-             continue
-         keys_locs = var_keys.get(var, [(var, [Ellipsis])])
-         shape = shapes.get(var, [])
-         dtype = dtypes.get(var)
-
-         ndraw = max(ndraws)
-         ary_shape = [nchain, ndraw] + shape
-         ary = np.empty(ary_shape, dtype=dtype, order="F")
-
-         if warmup:
-             nwarmup = max(ndraws_warmup)
-             ary_warmup_shape = [nchain, nwarmup] + shape
-             ary_warmup = np.empty(ary_warmup_shape, dtype=dtype, order="F")
-
-         for chain, (pyholder, ndraw, ndraw_warmup) in enumerate(
-             zip(fit.sim["samples"], ndraws, ndraws_warmup)
-         ):
-             axes = [chain, slice(None)]
-             for key, loc in keys_locs:
-                 ary_slice = tuple(axes + loc)
-                 ary[ary_slice] = pyholder.chains[key][-ndraw:]
-                 if warmup:
-                     ary_warmup[ary_slice] = pyholder.chains[key][:ndraw_warmup]
-         data[var] = ary
-         if warmup:
-             data_warmup[var] = ary_warmup
-     return data, data_warmup
-
-
- def get_sample_stats(fit, warmup=False, dtypes=None):
-     """Extract sample stats from PyStan fit."""
-     if dtypes is None:
-         dtypes = {}
-     dtypes = {"divergent__": bool, "n_leapfrog__": np.int64, "treedepth__": np.int64, **dtypes}
-
-     rename_dict = {
-         "divergent": "diverging",
-         "n_leapfrog": "n_steps",
-         "treedepth": "tree_depth",
-         "stepsize": "step_size",
-         "accept_stat": "acceptance_rate",
-     }
-
-     ndraws_warmup = fit.sim["warmup2"]
-     if max(ndraws_warmup) == 0:
-         warmup = False
-     ndraws = [s - w for s, w in zip(fit.sim["n_save"], ndraws_warmup)]
-
-     extraction = OrderedDict()
-     extraction_warmup = OrderedDict()
-     for chain, (pyholder, ndraw, ndraw_warmup) in enumerate(
-         zip(fit.sim["samples"], ndraws, ndraws_warmup)
-     ):
-         if chain == 0:
-             for key in pyholder["sampler_param_names"]:
-                 extraction[key] = []
-                 if warmup:
-                     extraction_warmup[key] = []
-         for key, values in zip(pyholder["sampler_param_names"], pyholder["sampler_params"]):
-             extraction[key].append(values[-ndraw:])
-             if warmup:
-                 extraction_warmup[key].append(values[:ndraw_warmup])
-
-     data = OrderedDict()
-     for key, values in extraction.items():
-         values = np.stack(values, axis=0)
-         dtype = dtypes.get(key)
-         values = values.astype(dtype)
-         name = re.sub("__$", "", key)
-         name = rename_dict.get(name, name)
-         data[name] = values
-
-     data_warmup = OrderedDict()
-     if warmup:
-         for key, values in extraction_warmup.items():
-             values = np.stack(values, axis=0)
-             values = values.astype(dtypes.get(key))
-             name = re.sub("__$", "", key)
-             name = rename_dict.get(name, name)
-             data_warmup[name] = values
-
-     return data, data_warmup
-
-
- def get_attrs(fit):
-     """Get attributes from PyStan fit object."""
-     attrs = {}
-
-     try:
-         attrs["args"] = [deepcopy(holder.args) for holder in fit.sim["samples"]]
-     except Exception as exp:  # pylint: disable=broad-except
-         _log.warning("Failed to fetch args from fit: %s", exp)
-     if "args" in attrs:
-         for arg in attrs["args"]:
-             if isinstance(arg["init"], bytes):
-                 arg["init"] = arg["init"].decode("utf-8")
-         attrs["args"] = json.dumps(attrs["args"])
-     try:
-         attrs["inits"] = [holder.inits for holder in fit.sim["samples"]]
-     except Exception as exp:  # pylint: disable=broad-except
-         _log.warning("Failed to fetch `args` from fit: %s", exp)
-     else:
-         attrs["inits"] = json.dumps(attrs["inits"])
-
-     attrs["step_size"] = []
-     attrs["metric"] = []
-     attrs["inv_metric"] = []
-     for holder in fit.sim["samples"]:
-         try:
-             step_size = float(
-                 re.search(
-                     r"step\s*size\s*=\s*([0-9]+.?[0-9]+)\s*",
-                     holder.adaptation_info,
-                     flags=re.IGNORECASE,
-                 ).group(1)
-             )
-         except AttributeError:
-             step_size = np.nan
-         attrs["step_size"].append(step_size)
-
-         inv_metric_match = re.search(
-             r"mass matrix:\s*(.*)\s*$", holder.adaptation_info, flags=re.DOTALL
-         )
-         if inv_metric_match:
-             inv_metric_str = inv_metric_match.group(1)
-             if "Diagonal elements of inverse mass matrix" in holder.adaptation_info:
-                 metric = "diag_e"
-                 inv_metric = [float(item) for item in inv_metric_str.strip(" #\n").split(",")]
-             else:
-                 metric = "dense_e"
-                 inv_metric = [
-                     list(map(float, item.split(",")))
-                     for item in re.sub(r"#\s", "", inv_metric_str).splitlines()
-                 ]
-         else:
-             metric = "unit_e"
-             inv_metric = None
-
-         attrs["metric"].append(metric)
-         attrs["inv_metric"].append(inv_metric)
-     attrs["inv_metric"] = json.dumps(attrs["inv_metric"])
-
-     if not attrs["step_size"]:
-         del attrs["step_size"]
-
-     attrs["adaptation_info"] = fit.get_adaptation_info()
-     attrs["stan_code"] = fit.get_stancode()
-
-     return attrs
-
-
- def get_draws_stan3(fit, model=None, variables=None, ignore=None, warmup=False, dtypes=None):
-     """Extract draws from PyStan3 fit."""
-     if ignore is None:
-         ignore = []
-
-     if dtypes is None:
-         dtypes = {}
-
-     if model is not None:
-         dtypes = {**infer_dtypes(fit, model), **dtypes}
-
-     if not fit.save_warmup:
-         warmup = False
-
-     num_warmup = ceil((fit.num_warmup * fit.save_warmup) / fit.num_thin)
-
-     if variables is None:
-         variables = fit.param_names
-     elif isinstance(variables, str):
-         variables = [variables]
-     variables = list(variables)
-
-     data = OrderedDict()
-     data_warmup = OrderedDict()
-
-     for var in variables:
-         if var in ignore:
-             continue
-         if var in data:
-             continue
-         dtype = dtypes.get(var)
-
-         new_shape = (*fit.dims[fit.param_names.index(var)], -1, fit.num_chains)
-         if 0 in new_shape:
-             continue
-         values = fit._draws[fit._parameter_indexes(var), :]  # pylint: disable=protected-access
-         values = values.reshape(new_shape, order="F")
-         values = np.moveaxis(values, [-2, -1], [1, 0])
-         values = values.astype(dtype)
-         if warmup:
-             data_warmup[var] = values[:, :num_warmup]
-         data[var] = values[:, num_warmup:]
-
-     return data, data_warmup
-
-
- def get_sample_stats_stan3(fit, variables=None, ignore=None, warmup=False, dtypes=None):
-     """Extract sample stats from PyStan3 fit."""
-     if dtypes is None:
-         dtypes = {}
-     dtypes = {"divergent__": bool, "n_leapfrog__": np.int64, "treedepth__": np.int64, **dtypes}
-
-     rename_dict = {
-         "divergent": "diverging",
-         "n_leapfrog": "n_steps",
-         "treedepth": "tree_depth",
-         "stepsize": "step_size",
-         "accept_stat": "acceptance_rate",
-     }
-
-     if isinstance(variables, str):
-         variables = [variables]
-     if isinstance(ignore, str):
-         ignore = [ignore]
-
-     if not fit.save_warmup:
-         warmup = False
-
-     num_warmup = ceil((fit.num_warmup * fit.save_warmup) / fit.num_thin)
-
-     data = OrderedDict()
-     data_warmup = OrderedDict()
-     for key in fit.sample_and_sampler_param_names:
-         if (variables and key not in variables) or (ignore and key in ignore):
-             continue
-         new_shape = -1, fit.num_chains
-         values = fit._draws[fit._parameter_indexes(key)]  # pylint: disable=protected-access
-         values = values.reshape(new_shape, order="F")
-         values = np.moveaxis(values, [-2, -1], [1, 0])
-         dtype = dtypes.get(key)
-         values = values.astype(dtype)
-         name = re.sub("__$", "", key)
-         name = rename_dict.get(name, name)
-         if warmup:
-             data_warmup[name] = values[:, :num_warmup]
-         data[name] = values[:, num_warmup:]
-
-     return data, data_warmup
-
-
- def get_attrs_stan3(fit, model=None):
-     """Get attributes from PyStan3 fit and model object."""
-     attrs = {}
-     for key in ["num_chains", "num_samples", "num_thin", "num_warmup", "save_warmup"]:
-         try:
-             attrs[key] = getattr(fit, key)
-         except AttributeError as exp:
-             _log.warning("Failed to access attribute %s in fit object %s", key, exp)
-
-     if model is not None:
-         for key in ["model_name", "program_code", "random_seed"]:
-             try:
-                 attrs[key] = getattr(model, key)
-             except AttributeError as exp:
-                 _log.warning("Failed to access attribute %s in model object %s", key, exp)
-
-     return attrs
-
-
- def infer_dtypes(fit, model=None):
-     """Infer dtypes from Stan model code.
-
-     Function strips out generated quantities block and searches for `int`
-     dtypes after stripping out comments inside the block.
-     """
-     if model is None:
-         stan_code = fit.get_stancode()
-         model_pars = fit.model_pars
-     else:
-         stan_code = model.program_code
-         model_pars = fit.param_names
-
-     dtypes = {key: item for key, item in infer_stan_dtypes(stan_code).items() if key in model_pars}
-     return dtypes
-
-
- # pylint disable=too-many-instance-attributes
- def from_pystan(
-     posterior=None,
-     *,
-     posterior_predictive=None,
-     predictions=None,
-     prior=None,
-     prior_predictive=None,
-     observed_data=None,
-     constant_data=None,
-     predictions_constant_data=None,
-     log_likelihood=None,
-     coords=None,
-     dims=None,
-     posterior_model=None,
-     prior_model=None,
-     save_warmup=None,
-     dtypes=None,
- ):
-     """Convert PyStan data into an InferenceData object.
-
-     For a usage example read the
-     :ref:`Creating InferenceData section on from_pystan <creating_InferenceData>`
-
-     Parameters
-     ----------
-     posterior : StanFit4Model or stan.fit.Fit
-         PyStan fit object for posterior.
-     posterior_predictive : str, a list of str
-         Posterior predictive samples for the posterior.
-     predictions : str, a list of str
-         Out-of-sample predictions for the posterior.
-     prior : StanFit4Model or stan.fit.Fit
-         PyStan fit object for prior.
-     prior_predictive : str, a list of str
-         Posterior predictive samples for the prior.
-     observed_data : str or a list of str
-         observed data used in the sampling.
-         Observed data is extracted from the `posterior.data`.
-         PyStan3 needs model object for the extraction.
-         See `posterior_model`.
-     constant_data : str or list of str
-         Constants relevant to the model (i.e. x values in a linear
-         regression).
-     predictions_constant_data : str or list of str
-         Constants relevant to the model predictions (i.e. new x values in a linear
-         regression).
-     log_likelihood : dict of {str: str}, list of str or str, optional
-         Pointwise log_likelihood for the data. log_likelihood is extracted from the
-         posterior. It is recommended to use this argument as a dictionary whose keys
-         are observed variable names and its values are the variables storing log
-         likelihood arrays in the Stan code. In other cases, a dictionary with keys
-         equal to its values is used. By default, if a variable ``log_lik`` is
-         present in the Stan model, it will be retrieved as pointwise log
-         likelihood values. Use ``False`` or set ``data.log_likelihood`` to
-         false to avoid this behaviour.
-     coords : dict[str, iterable]
-         A dictionary containing the values that are used as index. The key
-         is the name of the dimension, the values are the index values.
-     dims : dict[str, List(str)]
-         A mapping from variables to a list of coordinate names for the variable.
-     posterior_model : stan.model.Model
-         PyStan3 specific model object. Needed for automatic dtype parsing
-         and for the extraction of observed data.
-     prior_model : stan.model.Model
-         PyStan3 specific model object. Needed for automatic dtype parsing.
-     save_warmup : bool
-         Save warmup iterations into InferenceData object. If not defined, use default
-         defined by the rcParams.
-     dtypes: dict
-         A dictionary containing dtype information (int, float) for parameters.
-         By default dtype information is extracted from the model code.
-         Model code is extracted from fit object in PyStan 2 and from model object
-         in PyStan 3.
-
-     Returns
-     -------
-     InferenceData object
-     """
-     check_posterior = (posterior is not None) and (type(posterior).__module__ == "stan.fit")
-     check_prior = (prior is not None) and (type(prior).__module__ == "stan.fit")
-     if check_posterior or check_prior:
-         return PyStan3Converter(
-             posterior=posterior,
-             posterior_model=posterior_model,
-             posterior_predictive=posterior_predictive,
-             predictions=predictions,
-             prior=prior,
-             prior_model=prior_model,
-             prior_predictive=prior_predictive,
-             observed_data=observed_data,
-             constant_data=constant_data,
-             predictions_constant_data=predictions_constant_data,
-             log_likelihood=log_likelihood,
-             coords=coords,
-             dims=dims,
-             save_warmup=save_warmup,
-             dtypes=dtypes,
-         ).to_inference_data()
-     else:
-         return PyStanConverter(
-             posterior=posterior,
-             posterior_predictive=posterior_predictive,
-             predictions=predictions,
-             prior=prior,
-             prior_predictive=prior_predictive,
-             observed_data=observed_data,
-             constant_data=constant_data,
-             predictions_constant_data=predictions_constant_data,
-             log_likelihood=log_likelihood,
-             coords=coords,
-             dims=dims,
-             save_warmup=save_warmup,
-             dtypes=dtypes,
-         ).to_inference_data()
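The `from_pystan` docstring preserved above documents the converter's calling convention: fit objects plus string names for predictive, constant-data and log-likelihood variables, with optional `coords`/`dims` mappings. A minimal sketch of a typical call, where the fit object and the variable names are illustrative placeholders rather than anything defined in this diff:

```python
import arviz as az  # arviz 0.23.x top-level API, which re-exported from_pystan

# `fit` is assumed to be an existing PyStan 2 StanFit4Model or PyStan 3 stan.fit.Fit;
# "y", "y_hat", "log_lik" and the "school" coordinate are placeholder names.
idata = az.from_pystan(
    posterior=fit,
    posterior_predictive="y_hat",
    observed_data=["y"],
    log_likelihood={"y": "log_lik"},
    coords={"school": ["A", "B", "C"]},
    dims={"y": ["school"], "y_hat": ["school"], "log_lik": ["school"]},
)
```

With PyStan 3, the docstring additionally asks for `posterior_model=model` so that observed data and integer dtypes can be extracted from the model object.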