arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. arviz/__init__.py +52 -367
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.3.dist-info/METADATA +0 -264
  184. arviz-0.23.3.dist-info/RECORD +0 -183
  185. arviz-0.23.3.dist-info/top_level.txt +0 -1
@@ -1,1233 +0,0 @@
1
- # pylint: disable=too-many-lines
2
- """CmdStanPy-specific conversion code."""
3
- import logging
4
- import re
5
- from collections import defaultdict
6
- from copy import deepcopy
7
- from pathlib import Path
8
-
9
- import numpy as np
10
-
11
- from ..rcparams import rcParams
12
- from .base import dict_to_dataset, infer_stan_dtypes, make_attrs, requires
13
- from .inference_data import InferenceData
14
-
15
# Logger for this module, named after the module itself.
_log = logging.getLogger(name=__name__)
16
-
17
-
18
- class CmdStanPyConverter:
19
- """Encapsulate CmdStanPy specific logic."""
20
-
21
- # pylint: disable=too-many-instance-attributes
22
-
23
def __init__(
    self,
    *,
    posterior=None,
    posterior_predictive=None,
    predictions=None,
    prior=None,
    prior_predictive=None,
    observed_data=None,
    constant_data=None,
    predictions_constant_data=None,
    log_likelihood=None,
    index_origin=None,
    coords=None,
    dims=None,
    save_warmup=None,
    dtypes=None,
):
    """Store the CmdStanPy fits and the conversion options.

    Parameters
    ----------
    posterior : CmdStanPy CmdStanMCMC object, optional
        Fit providing the posterior, sample_stats, predictions and log_likelihood groups.
    prior : optional
        Fit providing the prior and sample_stats_prior groups (presumably a CmdStanPy
        fit object as well -- TODO confirm).
    posterior_predictive, predictions, prior_predictive : optional
        Variable-name specifications (as consumed by ``_as_set`` / ``_filter``) routed
        to the corresponding groups.
    observed_data, constant_data, predictions_constant_data : optional
        Data converted with ``dict_to_dataset``.
    log_likelihood : bool, str, list or dict, optional
        Log-likelihood variables. ``None`` falls back to ``rcParams["data.log_likelihood"]``.
        ``True`` selects ``["log_lik"]`` when the posterior fit contains a ``log_lik``
        variable. Any boolean still left afterwards is normalised to ``None``. A dict maps
        observed-variable names to log-likelihood variable names.
    index_origin, coords, dims : optional
        Forwarded to ``dict_to_dataset``.
    save_warmup : bool, optional
        Defaults to ``rcParams["data.save_warmup"]``.
    dtypes : dict, str or cmdstanpy.CmdStanModel, optional
        Per-variable dtypes. A ``CmdStanModel``, or a string that is either a path to a
        Stan file or Stan source code, is parsed with ``infer_stan_dtypes``.
    """
    self.posterior = posterior  # CmdStanPy CmdStanMCMC object
    self.posterior_predictive = posterior_predictive
    self.predictions = predictions
    self.prior = prior
    self.prior_predictive = prior_predictive
    self.observed_data = observed_data
    self.constant_data = constant_data
    self.predictions_constant_data = predictions_constant_data
    self.log_likelihood = (
        rcParams["data.log_likelihood"] if log_likelihood is None else log_likelihood
    )
    self.index_origin = index_origin
    self.coords = coords
    self.dims = dims

    self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup

    import cmdstanpy  # pylint: disable=import-error

    if dtypes is None:
        dtypes = {}
    elif isinstance(dtypes, cmdstanpy.model.CmdStanModel):
        dtypes = infer_stan_dtypes(dtypes.code())
    elif isinstance(dtypes, str):
        dtypes_path = Path(dtypes)
        try:
            is_existing_path = dtypes_path.exists()
        except (OSError, ValueError):
            # Inline Stan source is often not a valid path at all (e.g. a path
            # component longer than the OS limit raises ENAMETOOLONG), so treat
            # it as model code instead of crashing.
            is_existing_path = False
        if is_existing_path:
            model_code = dtypes_path.read_text(encoding="UTF-8")
        else:
            model_code = dtypes
        dtypes = infer_stan_dtypes(model_code)

    self.dtypes = dtypes

    if self.log_likelihood is True:
        # Auto-detect a ``log_lik`` variable. The first attribute layout the fit
        # exposes decides; the layouts differ between CmdStanPy versions.
        posterior = self.posterior
        metadata = getattr(posterior, "metadata", None)
        if hasattr(metadata, "stan_vars_cols"):
            has_log_lik = "log_lik" in metadata.stan_vars_cols
        elif hasattr(posterior, "stan_vars_cols"):
            has_log_lik = "log_lik" in posterior.stan_vars_cols
        elif hasattr(metadata, "stan_vars"):
            has_log_lik = "log_lik" in metadata.stan_vars
        elif posterior is not None and hasattr(posterior, "column_names"):
            has_log_lik = any(name.split("[")[0] == "log_lik" for name in posterior.column_names)
        else:
            has_log_lik = False
        if has_log_lik:
            self.log_likelihood = ["log_lik"]

    # A boolean that was not resolved to variable names means "no log likelihood".
    if isinstance(self.log_likelihood, bool):
        self.log_likelihood = None

    self.cmdstanpy = cmdstanpy
104
-
105
@requires("posterior")
def posterior_to_xarray(self):
    """Extract posterior samples from output csv.

    Dispatches on the attributes the CmdStanPy fit exposes to select the
    version-specific extraction routine. For the current layout
    (``metadata.stan_vars``), variables routed to posterior_predictive,
    predictions or log_likelihood are removed before unpacking.

    Returns
    -------
    tuple
        ``(posterior, warmup_posterior)`` datasets built by ``dict_to_dataset``.
    """
    if not (hasattr(self.posterior, "metadata") or hasattr(self.posterior, "stan_vars_cols")):
        return self.posterior_to_xarray_pre_v_0_9_68()
    if (
        hasattr(self.posterior, "metadata")
        and hasattr(self.posterior.metadata, "stan_vars_cols")
    ) or hasattr(self.posterior, "stan_vars_cols"):
        return self.posterior_to_xarray_pre_v_1_0_0()
    # NOTE(review): every fit satisfying the check below already satisfied the
    # pre_v_1_0_0 check above, so this branch is unreachable. Kept unchanged
    # pending confirmation of the intended version discriminator.
    if hasattr(self.posterior, "metadata") and hasattr(
        self.posterior.metadata, "stan_vars_cols"
    ):
        return self.posterior_to_xarray_pre_v_1_2_0()

    items = list(self.posterior.metadata.stan_vars)
    for spec in (self.posterior_predictive, self.predictions, self.log_likelihood):
        if spec is None:
            continue
        try:
            items = _filter(items, spec)
        except ValueError:
            # A ValueError from _filter is deliberately ignored (best-effort filtering).
            pass

    data, data_warmup = _unpack_fit(
        self.posterior,
        items,
        self.save_warmup,
        self.dtypes,
    )

    # Deep copies so the shared coords/dims are never mutated downstream.
    dims = deepcopy(self.dims) if self.dims is not None else {}
    coords = deepcopy(self.coords) if self.coords is not None else {}

    return (
        dict_to_dataset(
            data,
            library=self.cmdstanpy,
            coords=coords,
            dims=dims,
            index_origin=self.index_origin,
        ),
        dict_to_dataset(
            data_warmup,
            library=self.cmdstanpy,
            coords=coords,
            dims=dims,
            index_origin=self.index_origin,
        ),
    )
169
-
170
@requires("posterior")
def sample_stats_to_xarray(self):
    """Extract sample_stats from posterior fit.

    Delegates to ``stats_to_xarray`` with ``self.posterior`` and returns its
    ``(sample_stats, warmup_sample_stats)`` tuple.
    """
    return self.stats_to_xarray(self.posterior)
174
-
175
@requires("prior")
def sample_stats_prior_to_xarray(self):
    """Build the sample_stats_prior group from the prior fit.

    Thin wrapper around ``stats_to_xarray`` applied to ``self.prior``.
    """
    prior_fit = self.prior
    return self.stats_to_xarray(prior_fit)
179
-
180
def stats_to_xarray(self, fit):
    """Extract sampler diagnostics (sample_stats) from ``fit``.

    Older CmdStanPy layouts are handed to the version-specific helpers.
    For the current layout, every method variable is unpacked, its trailing
    ``__`` is stripped, a few names are mapped to ArviZ conventions, and
    values are cast: ``divergent__`` to bool, ``n_leapfrog__``/``treedepth__``
    to int64, user ``dtypes`` overriding those, and everything else to float.

    Returns
    -------
    tuple
        ``(sample_stats, warmup_sample_stats)`` datasets.
    """
    has_metadata = hasattr(fit, "metadata")
    if not has_metadata and not hasattr(fit, "sampler_vars_cols"):
        return self.sample_stats_to_xarray_pre_v_0_9_68(fit)
    if (has_metadata and hasattr(fit.metadata, "stan_vars_cols")) or hasattr(
        fit, "stan_vars_cols"
    ):
        return self.sample_stats_to_xarray_pre_v_1_0_0(fit)
    # NOTE(review): unreachable -- any fit passing this check already passed the
    # pre_v_1_0_0 check above. Kept as-is pending confirmation.
    if has_metadata and hasattr(fit.metadata, "stan_vars_cols"):
        return self.sample_stats_to_xarray_pre_v_1_2_0(fit)

    cast_to = {
        "divergent__": bool,
        "n_leapfrog__": np.int64,
        "treedepth__": np.int64,
        **self.dtypes,
    }
    arviz_names = {
        "divergent": "diverging",
        "n_leapfrog": "n_steps",
        "treedepth": "tree_depth",
        "stepsize": "step_size",
        "accept_stat": "acceptance_rate",
    }

    method_vars = list(fit.method_variables())
    data, data_warmup = _unpack_fit(fit, method_vars, self.save_warmup, self.dtypes)

    for raw_name in method_vars:
        target_dtype = cast_to.get(raw_name, float)
        clean_name = re.sub("__$", "", raw_name)
        clean_name = arviz_names.get(clean_name, clean_name)
        data[clean_name] = data.pop(raw_name).astype(target_dtype)
        if data_warmup:
            data_warmup[clean_name] = data_warmup.pop(raw_name).astype(target_dtype)

    return tuple(
        dict_to_dataset(
            group,
            library=self.cmdstanpy,
            coords=self.coords,
            dims=self.dims,
            index_origin=self.index_origin,
        )
        for group in (data, data_warmup)
    )
237
-
238
@requires("posterior")
@requires("posterior_predictive")
def posterior_predictive_to_xarray(self):
    """Build the posterior_predictive group.

    The variables named by ``self.posterior_predictive`` are read from the
    posterior fit via ``predictive_to_xarray``.
    """
    names, fit = self.posterior_predictive, self.posterior
    return self.predictive_to_xarray(names, fit)
243
-
244
@requires("prior")
@requires("prior_predictive")
def prior_predictive_to_xarray(self):
    """Build the prior_predictive group.

    The variables named by ``self.prior_predictive`` are read from the prior
    fit via ``predictive_to_xarray``.
    """
    names, fit = self.prior_predictive, self.prior
    return self.predictive_to_xarray(names, fit)
249
-
250
def predictive_to_xarray(self, names, fit):
    """Convert the predictive variables ``names`` stored in ``fit`` to xarray.

    ``names`` is normalised with ``_as_set``. The unpacking helper is chosen
    from the attributes the CmdStanPy fit exposes (version dispatch).

    Returns
    -------
    tuple
        ``(group, warmup_group)`` datasets.
    """
    wanted = _as_set(names)
    has_metadata = hasattr(fit, "metadata")

    if not has_metadata and not hasattr(fit, "stan_vars_cols"):
        # pre_v_0_9_68: only flat CSV column names are available
        columns = fit.column_names
        selected = _filter_columns(columns, wanted)
        data, data_warmup = _unpack_frame(fit, columns, selected, self.save_warmup, self.dtypes)
    elif (has_metadata and hasattr(fit.metadata, "sample_vars_cols")) or hasattr(
        fit, "stan_vars_cols"
    ):
        # pre_v_1_0_0
        # NOTE(review): posterior_to_xarray/stats_to_xarray test
        # ``metadata.stan_vars_cols`` at this step, not ``sample_vars_cols``;
        # confirm which attribute is intended.
        data, data_warmup = _unpack_fit_pre_v_1_0_0(fit, wanted, self.save_warmup, self.dtypes)
    elif has_metadata and hasattr(fit.metadata, "stan_vars_cols"):
        # pre_v_1_2_0
        data, data_warmup = _unpack_fit_pre_v_1_2_0(fit, wanted, self.save_warmup, self.dtypes)
    else:
        data, data_warmup = _unpack_fit(fit, wanted, self.save_warmup, self.dtypes)

    return tuple(
        dict_to_dataset(
            group,
            library=self.cmdstanpy,
            coords=self.coords,
            dims=self.dims,
            index_origin=self.index_origin,
        )
        for group in (data, data_warmup)
    )
303
-
304
@requires("posterior")
@requires("predictions")
def predictions_to_xarray(self):
    """Convert out of sample predictions samples to xarray.

    Predictions are stored in the posterior fit. The version dispatch,
    unpacking and dataset construction are therefore exactly those of
    ``predictive_to_xarray``, so this method delegates to it instead of
    duplicating that logic.

    Returns
    -------
    tuple
        ``(predictions, warmup_predictions)`` datasets.
    """
    return self.predictive_to_xarray(self.predictions, self.posterior)
367
-
368
@requires("posterior")
@requires("log_likelihood")
def log_likelihood_to_xarray(self):
    """Convert elementwise log likelihood samples to xarray.

    The log-likelihood variables are unpacked from the posterior fit using the
    version-appropriate helper. When ``self.log_likelihood`` is a dict
    (observed name -> log-likelihood variable name), the resulting data is
    re-keyed by the observed names. Datasets are built with
    ``skip_event_dims=True``.

    Returns
    -------
    tuple
        ``(log_likelihood, warmup_log_likelihood)`` datasets.
    """
    fit = self.posterior
    wanted = _as_set(self.log_likelihood)
    has_metadata = hasattr(fit, "metadata")

    if not has_metadata and not hasattr(fit, "stan_vars_cols"):  # pre_v_0_9_68
        columns = fit.column_names
        selected = _filter_columns(columns, wanted)
        data, data_warmup = _unpack_frame(fit, columns, selected, self.save_warmup, self.dtypes)
    elif (has_metadata and hasattr(fit.metadata, "sample_vars_cols")) or hasattr(
        fit, "stan_vars_cols"
    ):  # pre_v_1_0_0
        data, data_warmup = _unpack_fit_pre_v_1_0_0(fit, wanted, self.save_warmup, self.dtypes)
    elif has_metadata and hasattr(fit.metadata, "stan_vars_cols"):  # pre_v_1_2_0
        data, data_warmup = _unpack_fit_pre_v_1_2_0(fit, wanted, self.save_warmup, self.dtypes)
    else:
        data, data_warmup = _unpack_fit(fit, wanted, self.save_warmup, self.dtypes)

    if isinstance(self.log_likelihood, dict):
        pairs = list(self.log_likelihood.items())
        data = {obs_name: data[lik_name] for obs_name, lik_name in pairs}
        if data_warmup:
            data_warmup = {obs_name: data_warmup[lik_name] for obs_name, lik_name in pairs}

    return tuple(
        dict_to_dataset(
            group,
            library=self.cmdstanpy,
            coords=self.coords,
            dims=self.dims,
            index_origin=self.index_origin,
            skip_event_dims=True,
        )
        for group in (data, data_warmup)
    )
440
-
441
@requires("prior")
def prior_to_xarray(self):
    """Convert prior samples to xarray.

    Variables named by ``self.prior_predictive`` are excluded because they
    belong to the prior_predictive group. For the oldest CmdStanPy layout,
    sampler columns ending in ``__`` are also excluded.

    Returns
    -------
    tuple
        ``(prior, warmup_prior)`` datasets.
    """
    fit = self.prior
    has_metadata = hasattr(fit, "metadata")

    if not has_metadata and not hasattr(fit, "stan_vars_cols"):  # pre_v_0_9_68
        columns = fit.column_names
        predictive_cols = _filter_columns(columns, _as_set(self.prior_predictive))
        excluded = set(predictive_cols + [col for col in columns if col.endswith("__")])
        kept = [col for col in columns if col not in excluded]
        data, data_warmup = _unpack_frame(fit, columns, kept, self.save_warmup, self.dtypes)
    else:
        if (has_metadata and hasattr(fit.metadata, "sample_vars_cols")) or hasattr(
            fit, "stan_vars_cols"
        ):  # pre_v_1_0_0
            variables = fit.metadata.stan_vars_cols if has_metadata else fit.stan_vars_cols
            unpack = _unpack_fit_pre_v_1_0_0
        elif has_metadata and hasattr(fit.metadata, "stan_vars_cols"):  # pre_v_1_2_0
            variables = fit.metadata.stan_vars_cols
            unpack = _unpack_fit_pre_v_1_2_0
        else:
            variables = fit.metadata.stan_vars
            unpack = _unpack_fit

        items = list(variables.keys())
        if self.prior_predictive is not None:
            try:
                items = _filter(items, self.prior_predictive)
            except ValueError:
                pass
        data, data_warmup = unpack(fit, items, self.save_warmup, self.dtypes)

    return tuple(
        dict_to_dataset(
            group,
            library=self.cmdstanpy,
            coords=self.coords,
            dims=self.dims,
            index_origin=self.index_origin,
        )
        for group in (data, data_warmup)
    )
526
-
527
@requires("observed_data")
def observed_data_to_xarray(self):
    """Convert the observed data mapping to an xarray dataset.

    ``default_dims=[]`` is passed, presumably so that no sampling (chain/draw)
    dimensions are added to the data -- confirm against ``dict_to_dataset``.
    """
    options = {
        "library": self.cmdstanpy,
        "coords": self.coords,
        "dims": self.dims,
        "default_dims": [],
        "index_origin": self.index_origin,
    }
    return dict_to_dataset(self.observed_data, **options)
538
-
539
@requires("constant_data")
def constant_data_to_xarray(self):
    """Convert the constant data mapping to an xarray dataset.

    ``default_dims=[]`` is passed, presumably so that no sampling (chain/draw)
    dimensions are added to the data -- confirm against ``dict_to_dataset``.
    """
    options = {
        "library": self.cmdstanpy,
        "coords": self.coords,
        "dims": self.dims,
        "default_dims": [],
        "index_origin": self.index_origin,
    }
    return dict_to_dataset(self.constant_data, **options)
550
-
551
@requires("predictions_constant_data")
def predictions_constant_data_to_xarray(self):
    """Convert predictions constant data to xarray.

    Unlike ``constant_data_to_xarray``, explicit ``attrs`` built with
    ``make_attrs`` are passed here as well.
    """
    return dict_to_dataset(
        self.predictions_constant_data,
        library=self.cmdstanpy,
        coords=self.coords,
        dims=self.dims,
        attrs=make_attrs(library=self.cmdstanpy),
        default_dims=[],
        index_origin=self.index_origin,
    )
563
-
564
def to_inference_data(self):
    """Convert all available data to an InferenceData object.

    Every group converter is called in turn and the results are passed to
    ``InferenceData`` together with ``save_warmup``. Groups that can not be
    created (e.g. there is no posterior fit, so neither posterior nor
    sample_stats can be extracted) are not present in the result.
    """
    groups = {
        "posterior": self.posterior_to_xarray(),
        "sample_stats": self.sample_stats_to_xarray(),
        "posterior_predictive": self.posterior_predictive_to_xarray(),
        "predictions": self.predictions_to_xarray(),
        "prior": self.prior_to_xarray(),
        "sample_stats_prior": self.sample_stats_prior_to_xarray(),
        "prior_predictive": self.prior_predictive_to_xarray(),
        "observed_data": self.observed_data_to_xarray(),
        "constant_data": self.constant_data_to_xarray(),
        "predictions_constant_data": self.predictions_constant_data_to_xarray(),
        "log_likelihood": self.log_likelihood_to_xarray(),
    }
    return InferenceData(save_warmup=self.save_warmup, **groups)
587
-
588
def posterior_to_xarray_pre_v_1_2_0(self):
    """Extract posterior samples from a fit exposing ``metadata.stan_vars_cols``.

    Variables routed to posterior_predictive, predictions or log_likelihood
    are removed before unpacking with ``_unpack_fit_pre_v_1_2_0``.

    Returns
    -------
    tuple
        ``(posterior, warmup_posterior)`` datasets.
    """
    items = list(self.posterior.metadata.stan_vars_cols)
    for spec in (self.posterior_predictive, self.predictions, self.log_likelihood):
        if spec is None:
            continue
        try:
            items = _filter(items, spec)
        except ValueError:
            # A ValueError from _filter is deliberately ignored (best-effort filtering).
            pass

    data, data_warmup = _unpack_fit_pre_v_1_2_0(
        self.posterior,
        items,
        self.save_warmup,
        self.dtypes,
    )

    # Deep copies so the shared coords/dims are never mutated downstream.
    dims = deepcopy(self.dims) if self.dims is not None else {}
    coords = deepcopy(self.coords) if self.coords is not None else {}

    return (
        dict_to_dataset(
            data,
            library=self.cmdstanpy,
            coords=coords,
            dims=dims,
            index_origin=self.index_origin,
        ),
        dict_to_dataset(
            data_warmup,
            library=self.cmdstanpy,
            coords=coords,
            dims=dims,
            index_origin=self.index_origin,
        ),
    )
638
-
639
@requires("posterior")
def posterior_to_xarray_pre_v_1_0_0(self):
    """Extract posterior samples from a fit exposing ``stan_vars_cols``.

    The variable names come from ``metadata.stan_vars_cols`` when the fit has
    ``metadata``, otherwise from the fit's own ``stan_vars_cols``. Variables
    routed to posterior_predictive, predictions or log_likelihood are removed
    before unpacking with ``_unpack_fit_pre_v_1_0_0``.

    Returns
    -------
    tuple
        ``(posterior, warmup_posterior)`` datasets.
    """
    if hasattr(self.posterior, "metadata"):
        items = list(self.posterior.metadata.stan_vars_cols.keys())
    else:
        items = list(self.posterior.stan_vars_cols.keys())

    for spec in (self.posterior_predictive, self.predictions, self.log_likelihood):
        if spec is None:
            continue
        try:
            items = _filter(items, spec)
        except ValueError:
            # A ValueError from _filter is deliberately ignored (best-effort filtering).
            pass

    data, data_warmup = _unpack_fit_pre_v_1_0_0(
        self.posterior,
        items,
        self.save_warmup,
        self.dtypes,
    )

    # Deep copies so the shared coords/dims are never mutated downstream.
    dims = deepcopy(self.dims) if self.dims is not None else {}
    coords = deepcopy(self.coords) if self.coords is not None else {}

    return (
        dict_to_dataset(
            data,
            library=self.cmdstanpy,
            coords=coords,
            dims=dims,
            index_origin=self.index_origin,
        ),
        dict_to_dataset(
            data_warmup,
            library=self.cmdstanpy,
            coords=coords,
            dims=dims,
            index_origin=self.index_origin,
        ),
    )
694
-
695
@requires("posterior")
def posterior_to_xarray_pre_v_0_9_68(self):
    """Extract posterior samples from output csv (CmdStanPy < 0.9.68 layout).

    A column is excluded from the posterior when its base variable name (the
    text before the first ``[`` and the first ``.``) is selected by
    ``posterior_predictive``, ``predictions`` or ``log_likelihood``. Sampler
    columns ending in ``__`` are excluded too.

    ``log_likelihood`` may be a dict mapping observed names to log-likelihood
    variable names (as consumed by ``log_likelihood_to_xarray``). In that case
    the dict *values* are the Stan variables to exclude. Previously the keys
    were matched, which left the log-likelihood columns in the posterior.

    Returns
    -------
    tuple
        ``(posterior, warmup_posterior)`` datasets.
    """
    columns = self.posterior.column_names

    def matching_columns(spec):
        """Return the columns whose base variable name is selected by ``spec``."""
        if spec is None:
            return []
        targets = (spec,) if isinstance(spec, str) else tuple(spec)
        return [col for col in columns if col.split("[")[0].split(".")[0] in targets]

    log_likelihood = self.log_likelihood
    if isinstance(log_likelihood, dict):
        log_likelihood = list(log_likelihood.values())

    invalid_cols = set(
        matching_columns(self.posterior_predictive)
        + matching_columns(self.predictions)
        + matching_columns(log_likelihood)
        + [col for col in columns if col.endswith("__")]
    )
    valid_cols = [col for col in columns if col not in invalid_cols]
    data, data_warmup = _unpack_frame(
        self.posterior,
        columns,
        valid_cols,
        self.save_warmup,
        self.dtypes,
    )

    return (
        dict_to_dataset(
            data,
            library=self.cmdstanpy,
            coords=self.coords,
            dims=self.dims,
            index_origin=self.index_origin,
        ),
        dict_to_dataset(
            data_warmup,
            library=self.cmdstanpy,
            coords=self.coords,
            dims=self.dims,
            index_origin=self.index_origin,
        ),
    )
772
-
773
def sample_stats_to_xarray_pre_v_1_2_0(self, fit):
    """Extract sampler statistics from ``fit`` (cmdstanpy >= 1.0.0, < 1.2.0).

    A trailing ``__`` is stripped from every sampler variable name, and a few
    names are mapped to ArviZ conventions (``divergent`` -> ``diverging``,
    ``n_leapfrog`` -> ``n_steps``, ``treedepth`` -> ``tree_depth``,
    ``stepsize`` -> ``step_size``, ``accept_stat`` -> ``acceptance_rate``).
    ``divergent__`` is cast to bool and ``n_leapfrog__``/``treedepth__`` to
    int64 unless ``self.dtypes`` overrides them; all other stats become float.

    Returns
    -------
    tuple
        ``(sample_stats, warmup_sample_stats)``, each built by ``dict_to_dataset``.
    """
    stat_dtypes = {
        "divergent__": bool,
        "n_leapfrog__": np.int64,
        "treedepth__": np.int64,
        **self.dtypes,
    }
    stat_names = list(fit.metadata.method_vars_cols)
    arviz_names = {
        "divergent": "diverging",
        "n_leapfrog": "n_steps",
        "treedepth": "tree_depth",
        "stepsize": "step_size",
        "accept_stat": "acceptance_rate",
    }

    draws, warmup = _unpack_fit_pre_v_1_2_0(fit, stat_names, self.save_warmup, self.dtypes)
    for stan_name in stat_names:
        bare_name = re.sub("__$", "", stan_name)
        new_name = arviz_names.get(bare_name, bare_name)
        cast_to = stat_dtypes.get(stan_name, float)
        draws[new_name] = draws.pop(stan_name).astype(cast_to)
        # warmup is an empty dict when warmup draws were not kept
        if warmup:
            warmup[new_name] = warmup.pop(stan_name).astype(cast_to)

    return tuple(
        dict_to_dataset(
            group,
            library=self.cmdstanpy,
            coords=self.coords,
            dims=self.dims,
            index_origin=self.index_origin,
        )
        for group in (draws, warmup)
    )
820
-
821
def sample_stats_to_xarray_pre_v_1_0_0(self, fit):
    """Extract sample_stats from fit (cmdstanpy < 1.0.0).

    Sampler variable names come from ``fit.metadata._method_vars_cols`` when
    ``fit`` has a ``metadata`` attribute, otherwise from ``fit.sampler_vars_cols``.
    Names lose their trailing ``__`` and are mapped to ArviZ conventions
    (e.g. ``divergent`` -> ``diverging``). ``divergent__`` becomes bool and
    ``n_leapfrog__``/``treedepth__`` int64 unless ``self.dtypes`` overrides
    them; everything else becomes float.

    Returns
    -------
    tuple
        ``(sample_stats, warmup_sample_stats)``, each built by ``dict_to_dataset``.
    """
    cast_map = {
        "divergent__": bool,
        "n_leapfrog__": np.int64,
        "treedepth__": np.int64,
        **self.dtypes,
    }
    if hasattr(fit, "metadata"):
        vars_cols = fit.metadata._method_vars_cols  # pylint: disable=protected-access
    else:
        vars_cols = fit.sampler_vars_cols
    items = list(vars_cols.keys())
    aliases = {
        "divergent": "diverging",
        "n_leapfrog": "n_steps",
        "treedepth": "tree_depth",
        "stepsize": "step_size",
        "accept_stat": "acceptance_rate",
    }

    def arviz_name(stan_name):
        # strip the trailing "__" and apply the ArviZ alias, if any
        stripped = re.sub("__$", "", stan_name)
        return aliases.get(stripped, stripped)

    sample, sample_warmup = _unpack_fit_pre_v_1_0_0(fit, items, self.save_warmup, self.dtypes)
    for item in items:
        dtype = cast_map.get(item, float)
        sample[arviz_name(item)] = sample.pop(item).astype(dtype)
        if sample_warmup:
            sample_warmup[arviz_name(item)] = sample_warmup.pop(item).astype(dtype)

    common = {
        "library": self.cmdstanpy,
        "coords": self.coords,
        "dims": self.dims,
        "index_origin": self.index_origin,
    }
    return dict_to_dataset(sample, **common), dict_to_dataset(sample_warmup, **common)
869
-
870
def sample_stats_to_xarray_pre_v_0_9_68(self, fit):
    """Extract sample_stats from fit (cmdstanpy < 0.9.68).

    Every column whose name ends in ``__`` is treated as a sampler statistic.
    The trailing ``__`` is stripped and ``divergent`` is renamed to
    ``diverging``. ``divergent__`` is cast to bool and ``n_leapfrog__`` /
    ``treedepth__`` to int64; user-supplied ``self.dtypes`` entries take
    precedence (as in the newer ``sample_stats_to_xarray_*`` variants), and
    everything else becomes float.

    Returns
    -------
    tuple
        ``(sample_stats, warmup_sample_stats)``, each built by ``dict_to_dataset``.
    """
    # Merge user dtypes so they are not overwritten by the float default below.
    dtypes = {
        "divergent__": bool,
        "n_leapfrog__": np.int64,
        "treedepth__": np.int64,
        **self.dtypes,
    }
    columns = fit.column_names
    valid_cols = [col for col in columns if col.endswith("__")]
    data, data_warmup = _unpack_frame(
        fit,
        columns,
        valid_cols,
        self.save_warmup,
        self.dtypes,
    )
    for s_param in list(data.keys()):
        s_param_, *_ = s_param.split(".")
        name = re.sub("__$", "", s_param_)
        name = "diverging" if name == "divergent" else name
        dtype = dtypes.get(s_param, float)
        data[name] = data.pop(s_param).astype(dtype)
        if data_warmup:
            data_warmup[name] = data_warmup.pop(s_param).astype(dtype)
    return (
        dict_to_dataset(
            data,
            library=self.cmdstanpy,
            coords=self.coords,
            dims=self.dims,
            index_origin=self.index_origin,
        ),
        dict_to_dataset(
            data_warmup,
            library=self.cmdstanpy,
            coords=self.coords,
            dims=self.dims,
            index_origin=self.index_origin,
        ),
    )
905
-
906
-
907
- def _as_set(spec):
908
- """Uniform representation for args which be name or list of names."""
909
- if spec is None:
910
- return []
911
- if isinstance(spec, str):
912
- return [spec]
913
- try:
914
- return set(spec.values())
915
- except AttributeError:
916
- return set(spec)
917
-
918
-
919
- def _filter(names, spec):
920
- """Remove names from list of names."""
921
- if isinstance(spec, str):
922
- names.remove(spec)
923
- elif isinstance(spec, list):
924
- for item in spec:
925
- names.remove(item)
926
- elif isinstance(spec, dict):
927
- for item in spec.values():
928
- names.remove(item)
929
- return names
930
-
931
-
932
- def _filter_columns(columns, spec):
933
- """Parse variable name from column label, removing element index, if any."""
934
- return [col for col in columns if col.split("[")[0].split(".")[0] in spec]
935
-
936
-
937
- def _unpack_fit(fit, items, save_warmup, dtypes):
938
- num_warmup = 0
939
- if save_warmup:
940
- if not fit._save_warmup: # pylint: disable=protected-access
941
- save_warmup = False
942
- else:
943
- num_warmup = fit.num_draws_warmup
944
-
945
- nchains = fit.chains
946
- sample = {}
947
- sample_warmup = {}
948
- stan_vars_cols = list(fit.metadata.stan_vars)
949
- sampler_vars = fit.method_variables()
950
- for item in items:
951
- if item in stan_vars_cols:
952
- raw_draws = fit.stan_variable(item, inc_warmup=save_warmup)
953
- raw_draws = np.swapaxes(
954
- raw_draws.reshape((-1, nchains, *raw_draws.shape[1:]), order="F"), 0, 1
955
- )
956
- elif item in sampler_vars:
957
- raw_draws = np.swapaxes(sampler_vars[item], 0, 1)
958
- else:
959
- raise ValueError(f"fit data, unknown variable: {item}")
960
- raw_draws = raw_draws.astype(dtypes.get(item))
961
- if save_warmup:
962
- sample_warmup[item] = raw_draws[:, :num_warmup, ...]
963
- sample[item] = raw_draws[:, num_warmup:, ...]
964
- else:
965
- sample[item] = raw_draws
966
-
967
- return sample, sample_warmup
968
-
969
-
970
- def _unpack_fit_pre_v_1_2_0(fit, items, save_warmup, dtypes):
971
- num_warmup = 0
972
- if save_warmup:
973
- if not fit._save_warmup: # pylint: disable=protected-access
974
- save_warmup = False
975
- else:
976
- num_warmup = fit.num_draws_warmup
977
-
978
- nchains = fit.chains
979
- sample = {}
980
- sample_warmup = {}
981
- stan_vars_cols = list(fit.metadata.stan_vars_cols)
982
- sampler_vars = fit.method_variables()
983
- for item in items:
984
- if item in stan_vars_cols:
985
- raw_draws = fit.stan_variable(item, inc_warmup=save_warmup)
986
- raw_draws = np.swapaxes(
987
- raw_draws.reshape((-1, nchains, *raw_draws.shape[1:]), order="F"), 0, 1
988
- )
989
- elif item in sampler_vars:
990
- raw_draws = np.swapaxes(sampler_vars[item], 0, 1)
991
- else:
992
- raise ValueError(f"fit data, unknown variable: {item}")
993
- raw_draws = raw_draws.astype(dtypes.get(item))
994
- if save_warmup:
995
- sample_warmup[item] = raw_draws[:, :num_warmup, ...]
996
- sample[item] = raw_draws[:, num_warmup:, ...]
997
- else:
998
- sample[item] = raw_draws
999
-
1000
- return sample, sample_warmup
1001
-
1002
-
1003
- def _unpack_fit_pre_v_1_0_0(fit, items, save_warmup, dtypes):
1004
- """Transform fit to dictionary containing ndarrays.
1005
-
1006
- Parameters
1007
- ----------
1008
- data: cmdstanpy.CmdStanMCMC
1009
- items: list
1010
- save_warmup: bool
1011
- dtypes: dict
1012
-
1013
- Returns
1014
- -------
1015
- dict
1016
- key, values pairs. Values are formatted to shape = (chains, draws, *shape)
1017
- """
1018
- num_warmup = 0
1019
- if save_warmup:
1020
- if not fit._save_warmup: # pylint: disable=protected-access
1021
- save_warmup = False
1022
- else:
1023
- num_warmup = fit.num_draws_warmup
1024
-
1025
- nchains = fit.chains
1026
- draws = np.swapaxes(fit.draws(inc_warmup=save_warmup), 0, 1)
1027
- sample = {}
1028
- sample_warmup = {}
1029
-
1030
- stan_vars_cols = fit.metadata.stan_vars_cols if hasattr(fit, "metadata") else fit.stan_vars_cols
1031
- sampler_vars_cols = (
1032
- fit.metadata._method_vars_cols # pylint: disable=protected-access
1033
- if hasattr(fit, "metadata")
1034
- else fit.sampler_vars_cols
1035
- )
1036
- for item in items:
1037
- if item in stan_vars_cols:
1038
- col_idxs = stan_vars_cols[item]
1039
- raw_draws = fit.stan_variable(item, inc_warmup=save_warmup)
1040
- raw_draws = np.swapaxes(
1041
- raw_draws.reshape((-1, nchains, *raw_draws.shape[1:]), order="F"), 0, 1
1042
- )
1043
- elif item in sampler_vars_cols:
1044
- col_idxs = sampler_vars_cols[item]
1045
- raw_draws = draws[..., col_idxs[0]]
1046
- else:
1047
- raise ValueError(f"fit data, unknown variable: {item}")
1048
- raw_draws = raw_draws.astype(dtypes.get(item))
1049
- if save_warmup:
1050
- sample_warmup[item] = raw_draws[:, :num_warmup, ...]
1051
- sample[item] = raw_draws[:, num_warmup:, ...]
1052
- else:
1053
- sample[item] = raw_draws
1054
-
1055
- return sample, sample_warmup
1056
-
1057
-
1058
- def _unpack_frame(fit, columns, valid_cols, save_warmup, dtypes):
1059
- """Transform fit to dictionary containing ndarrays.
1060
-
1061
- Called when fit object created by cmdstanpy version < 0.9.68
1062
-
1063
- Parameters
1064
- ----------
1065
- data: cmdstanpy.CmdStanMCMC
1066
- columns: list
1067
- valid_cols: list
1068
- save_warmup: bool
1069
- dtypes: dict
1070
-
1071
- Returns
1072
- -------
1073
- dict
1074
- key, values pairs. Values are formatted to shape = (chains, draws, *shape)
1075
- """
1076
- if save_warmup and not fit._save_warmup: # pylint: disable=protected-access
1077
- save_warmup = False
1078
- if hasattr(fit, "draws"):
1079
- data = fit.draws(inc_warmup=save_warmup)
1080
- if save_warmup:
1081
- num_warmup = fit._draws_warmup # pylint: disable=protected-access
1082
- data_warmup = data[:num_warmup]
1083
- data = data[num_warmup:]
1084
- else:
1085
- data = fit.sample
1086
- if save_warmup:
1087
- data_warmup = fit.warmup[: data.shape[0]]
1088
-
1089
- draws, chains, *_ = data.shape
1090
- if save_warmup:
1091
- draws_warmup, *_ = data_warmup.shape
1092
-
1093
- column_groups = defaultdict(list)
1094
- column_locs = defaultdict(list)
1095
- # iterate flat column names
1096
- for i, col in enumerate(columns):
1097
- if "." in col:
1098
- # parse parameter names e.g. X.1.2 --> X, (1,2)
1099
- col_base, *col_tail = col.split(".")
1100
- else:
1101
- # parse parameter names e.g. X[1,2] --> X, (1,2)
1102
- col_base, *col_tail = col.replace("]", "").replace("[", ",").split(",")
1103
- if len(col_tail):
1104
- # gather nD array locations
1105
- column_groups[col_base].append(tuple(map(int, col_tail)))
1106
- # gather raw data locations for each parameter
1107
- column_locs[col_base].append(i)
1108
- # gather parameter dimensions (assumes dense arrays)
1109
- dims = {
1110
- colname: tuple(np.array(col_dims).max(0)) for colname, col_dims in column_groups.items()
1111
- }
1112
- sample = {}
1113
- sample_warmup = {}
1114
- valid_base_cols = []
1115
- # get list of parameters for extraction (basename) X.1.2 --> X
1116
- for col in valid_cols:
1117
- base_col = col.split("[")[0].split(".")[0]
1118
- if base_col not in valid_base_cols:
1119
- valid_base_cols.append(base_col)
1120
-
1121
- # extract each wanted parameter to ndarray with correct shape
1122
- for key in valid_base_cols:
1123
- ndim = dims.get(key, None)
1124
- shape_location = column_groups.get(key, None)
1125
- if ndim is not None:
1126
- sample[key] = np.full((chains, draws, *ndim), np.nan)
1127
- if save_warmup:
1128
- sample_warmup[key] = np.full((chains, draws_warmup, *ndim), np.nan)
1129
- if shape_location is None:
1130
- # reorder draw, chain -> chain, draw
1131
- (i,) = column_locs[key]
1132
- sample[key] = np.swapaxes(data[..., i], 0, 1)
1133
- if save_warmup:
1134
- sample_warmup[key] = np.swapaxes(data_warmup[..., i], 0, 1)
1135
- else:
1136
- for i, shape_loc in zip(column_locs[key], shape_location):
1137
- # location to insert extracted array
1138
- shape_loc = tuple([Ellipsis] + [j - 1 for j in shape_loc])
1139
- # reorder draw, chain -> chain, draw and insert to ndarray
1140
- sample[key][shape_loc] = np.swapaxes(data[..., i], 0, 1)
1141
- if save_warmup:
1142
- sample_warmup[key][shape_loc] = np.swapaxes(data_warmup[..., i], 0, 1)
1143
-
1144
- for key, dtype in dtypes.items():
1145
- if key in sample:
1146
- sample[key] = sample[key].astype(dtype)
1147
- if save_warmup and key in sample_warmup:
1148
- sample_warmup[key] = sample_warmup[key].astype(dtype)
1149
- return sample, sample_warmup
1150
-
1151
-
1152
def from_cmdstanpy(
    posterior=None,
    *,
    posterior_predictive=None,
    predictions=None,
    prior=None,
    prior_predictive=None,
    observed_data=None,
    constant_data=None,
    predictions_constant_data=None,
    log_likelihood=None,
    index_origin=None,
    coords=None,
    dims=None,
    save_warmup=None,
    dtypes=None,
):
    """Convert CmdStanPy data into an InferenceData object.

    For a usage example read the
    :ref:`Creating InferenceData section on from_cmdstanpy <creating_InferenceData>`

    Parameters
    ----------
    posterior : CmdStanMCMC object
        CmdStanPy CmdStanMCMC
    posterior_predictive : str, list of str
        Posterior predictive samples for the fit.
    predictions : str, list of str
        Out of sample prediction samples for the fit.
    prior : CmdStanMCMC
        CmdStanPy CmdStanMCMC
    prior_predictive : str, list of str
        Prior predictive samples for the fit.
    observed_data : dict
        Observed data used in the sampling.
    constant_data : dict
        Constant data used in the sampling.
    predictions_constant_data : dict
        Constant data for predictions used in the sampling.
    log_likelihood : str, list of str, dict of {str: str}, optional
        Pointwise log_likelihood for the data. If a dict, its keys should represent var_names
        from the corresponding observed data and its values the stan variable where the
        data is stored. By default, if a variable ``log_lik`` is present in the Stan model,
        it will be retrieved as pointwise log likelihood values. Use ``False``
        or set ``data.log_likelihood`` to false to avoid this behaviour.
    index_origin : int, optional
        Starting value of integer coordinate values. Defaults to the value in rcParam
        ``data.index_origin``.
    coords : dict of str or dict of iterable
        A dictionary containing the values that are used as index. The key
        is the name of the dimension, the values are the index values.
    dims : dict of str or list of str
        A mapping from variables to a list of coordinate names for the variable.
    save_warmup : bool
        Save warmup iterations into InferenceData object, if found in the input files.
        If not defined, use default defined by the rcParams.
    dtypes : dict or str or cmdstanpy.CmdStanModel
        A dictionary containing dtype information (int, float) for parameters.
        If input is a string, it is assumed to be model code or a path to a model code file.
        Model code can be extracted from a cmdstanpy.CmdStanModel object.

    Returns
    -------
    InferenceData object
    """
    return CmdStanPyConverter(
        posterior=posterior,
        posterior_predictive=posterior_predictive,
        predictions=predictions,
        prior=prior,
        prior_predictive=prior_predictive,
        observed_data=observed_data,
        constant_data=constant_data,
        predictions_constant_data=predictions_constant_data,
        log_likelihood=log_likelihood,
        index_origin=index_origin,
        coords=coords,
        dims=dims,
        save_warmup=save_warmup,
        dtypes=dtypes,
    ).to_inference_data()