arviz 0.23.1__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. arviz/__init__.py +52 -357
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.1.dist-info/METADATA +0 -263
  184. arviz-0.23.1.dist-info/RECORD +0 -183
  185. arviz-0.23.1.dist-info/top_level.txt +0 -1
arviz/data/io_cmdstan.py DELETED
@@ -1,1036 +0,0 @@
- # pylint: disable=too-many-lines
- """CmdStan-specific conversion code."""
- try:
-     import ujson as json
- except ImportError:
-     # ujson is not installed; fall back to the standard-library json
-     # mypy struggles with conditional imports expressed as catching ImportError:
-     # https://github.com/python/mypy/issues/1153
-     import json  # type: ignore
- import logging
- import os
- import re
- from collections import defaultdict
- from glob import glob
- from pathlib import Path
- from typing import Dict, List, Optional, Union
-
- import numpy as np
-
- from .. import utils
- from ..rcparams import rcParams
- from .base import CoordSpec, DimSpec, dict_to_dataset, infer_stan_dtypes, requires
- from .inference_data import InferenceData
-
- _log = logging.getLogger(__name__)
-
-
- def check_glob(path, group, disable_glob):
-     """Find files with glob."""
-     if isinstance(path, str) and (not disable_glob):
-         path_glob = glob(path)
-         if path_glob:
-             path = sorted(path_glob)
-             msg = "\n".join(f"{i}: {os.path.normpath(fpath)}" for i, fpath in enumerate(path, 1))
-             len_p = len(path)
-             _log.info("glob found %d files for '%s':\n%s", len_p, group, msg)
-     return path
-
-
- class CmdStanConverter:
-     """Encapsulate CmdStan specific logic."""
-
-     # pylint: disable=too-many-instance-attributes
-
-     def __init__(
-         self,
-         *,
-         posterior=None,
-         posterior_predictive=None,
-         predictions=None,
-         prior=None,
-         prior_predictive=None,
-         observed_data=None,
-         observed_data_var=None,
-         constant_data=None,
-         constant_data_var=None,
-         predictions_constant_data=None,
-         predictions_constant_data_var=None,
-         log_likelihood=None,
-         index_origin=None,
-         coords=None,
-         dims=None,
-         disable_glob=False,
-         save_warmup=None,
-         dtypes=None,
-     ):
-         self.posterior_ = check_glob(posterior, "posterior", disable_glob)
-         self.posterior_predictive = check_glob(
-             posterior_predictive, "posterior_predictive", disable_glob
-         )
-         self.predictions = check_glob(predictions, "predictions", disable_glob)
-         self.prior_ = check_glob(prior, "prior", disable_glob)
-         self.prior_predictive = check_glob(prior_predictive, "prior_predictive", disable_glob)
-         self.log_likelihood = check_glob(log_likelihood, "log_likelihood", disable_glob)
-         self.observed_data = observed_data
-         self.observed_data_var = observed_data_var
-         self.constant_data = constant_data
-         self.constant_data_var = constant_data_var
-         self.predictions_constant_data = predictions_constant_data
-         self.predictions_constant_data_var = predictions_constant_data_var
-         self.coords = coords if coords is not None else {}
-         self.dims = dims if dims is not None else {}
-
-         self.posterior = None
-         self.prior = None
-         self.attrs = None
-         self.attrs_prior = None
-
-         self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup
-         self.index_origin = index_origin
-
-         if dtypes is None:
-             dtypes = {}
-         elif isinstance(dtypes, str):
-             dtypes_path = Path(dtypes)
-             if dtypes_path.exists():
-                 with dtypes_path.open("r", encoding="UTF-8") as f_obj:
-                     model_code = f_obj.read()
-             else:
-                 model_code = dtypes
-
-             dtypes = infer_stan_dtypes(model_code)
-
-         self.dtypes = dtypes
-
-         # populate posterior and sample_stats
-         self._parse_posterior()
-         self._parse_prior()
-
-         if (
-             self.log_likelihood is None
-             and self.posterior_ is not None
-             and any(name.split(".")[0] == "log_lik" for name in self.posterior_columns)
-         ):
-             self.log_likelihood = ["log_lik"]
-         elif isinstance(self.log_likelihood, bool):
-             self.log_likelihood = None
-
-     @requires("posterior_")
-     def _parse_posterior(self):
-         """Read csv paths to list of ndarrays."""
-         paths = self.posterior_
-         if isinstance(paths, str):
-             paths = [paths]
-
-         chain_data = []
-         columns = None
-         for path in paths:
-             output_data = _read_output(path)
-             chain_data.append(output_data)
-             if columns is None:
-                 columns = output_data
-
-         self.posterior = (
-             [item["sample"] for item in chain_data],
-             [item["sample_warmup"] for item in chain_data],
-         )
-         self.posterior_columns = columns["sample_columns"]
-         self.sample_stats_columns = columns["sample_stats_columns"]
-
-         attrs = {}
-         for item in chain_data:
-             for key, value in item["configuration_info"].items():
-                 if key not in attrs:
-                     attrs[key] = []
-                 attrs[key].append(value)
-         self.attrs = attrs
-
-     @requires("prior_")
-     def _parse_prior(self):
-         """Read csv paths to list of ndarrays."""
-         paths = self.prior_
-         if isinstance(paths, str):
-             paths = [paths]
-
-         chain_data = []
-         columns = None
-         for path in paths:
-             output_data = _read_output(path)
-             chain_data.append(output_data)
-             if columns is None:
-                 columns = output_data
-
-         self.prior = (
-             [item["sample"] for item in chain_data],
-             [item["sample_warmup"] for item in chain_data],
-         )
-         self.prior_columns = columns["sample_columns"]
-         self.sample_stats_prior_columns = columns["sample_stats_columns"]
-
-         attrs = {}
-         for item in chain_data:
-             for key, value in item["configuration_info"].items():
-                 if key not in attrs:
-                     attrs[key] = []
-                 attrs[key].append(value)
-         self.attrs_prior = attrs
-
-     @requires("posterior")
-     def posterior_to_xarray(self):
-         """Extract posterior samples from output csv."""
-         columns = self.posterior_columns
-
-         # filter posterior_predictive, predictions and log_likelihood
-         posterior_predictive = self.posterior_predictive
-         if posterior_predictive is None or (
-             isinstance(posterior_predictive, str) and posterior_predictive.lower().endswith(".csv")
-         ):
-             posterior_predictive = []
-         elif isinstance(posterior_predictive, str):
-             posterior_predictive = [
-                 col for col in columns if posterior_predictive == col.split(".")[0]
-             ]
-         else:
-             posterior_predictive = [
-                 col
-                 for col in columns
-                 if any(item == col.split(".")[0] for item in posterior_predictive)
-             ]
-
-         predictions = self.predictions
-         if predictions is None or (
-             isinstance(predictions, str) and predictions.lower().endswith(".csv")
-         ):
-             predictions = []
-         elif isinstance(predictions, str):
-             predictions = [col for col in columns if predictions == col.split(".")[0]]
-         else:
-             predictions = [
-                 col for col in columns if any(item == col.split(".")[0] for item in predictions)
-             ]
-
-         log_likelihood = self.log_likelihood
-         if log_likelihood is None or (
-             isinstance(log_likelihood, str) and log_likelihood.lower().endswith(".csv")
-         ):
-             log_likelihood = []
-         elif isinstance(log_likelihood, str):
-             log_likelihood = [col for col in columns if log_likelihood == col.split(".")[0]]
-         elif isinstance(log_likelihood, dict):
-             log_likelihood = [
-                 col
-                 for col in columns
-                 if any(item == col.split(".")[0] for item in log_likelihood.values())
-             ]
-         else:
-             log_likelihood = [
-                 col for col in columns if any(item == col.split(".")[0] for item in log_likelihood)
-             ]
-
-         invalid_cols = posterior_predictive + predictions + log_likelihood
-         valid_cols = {col: idx for col, idx in columns.items() if col not in invalid_cols}
-         data = _unpack_ndarrays(self.posterior[0], valid_cols, self.dtypes)
-         data_warmup = _unpack_ndarrays(self.posterior[1], valid_cols, self.dtypes)
-         return (
-             dict_to_dataset(
-                 data,
-                 coords=self.coords,
-                 dims=self.dims,
-                 attrs=self.attrs,
-                 index_origin=self.index_origin,
-             ),
-             dict_to_dataset(
-                 data_warmup,
-                 coords=self.coords,
-                 dims=self.dims,
-                 attrs=self.attrs,
-                 index_origin=self.index_origin,
-             ),
-         )
-
-     @requires("posterior")
-     @requires("sample_stats_columns")
-     def sample_stats_to_xarray(self):
-         """Extract sample_stats from fit."""
-         dtypes = {"diverging": bool, "n_steps": np.int64, "tree_depth": np.int64, **self.dtypes}
-         rename_dict = {
-             "divergent": "diverging",
-             "n_leapfrog": "n_steps",
-             "treedepth": "tree_depth",
-             "stepsize": "step_size",
-             "accept_stat": "acceptance_rate",
-         }
-
-         columns_new = {}
-         for key, idx in self.sample_stats_columns.items():
-             name = re.sub("__$", "", key)
-             name = rename_dict.get(name, name)
-             columns_new[name] = idx
-
-         data = _unpack_ndarrays(self.posterior[0], columns_new, dtypes)
-         data_warmup = _unpack_ndarrays(self.posterior[1], columns_new, dtypes)
-         return (
-             dict_to_dataset(
-                 data,
-                 coords=self.coords,
-                 dims=self.dims,
-                 attrs={item: key for key, item in rename_dict.items()},
-                 index_origin=self.index_origin,
-             ),
-             dict_to_dataset(
-                 data_warmup,
-                 coords=self.coords,
-                 dims=self.dims,
-                 attrs={item: key for key, item in rename_dict.items()},
-                 index_origin=self.index_origin,
-             ),
-         )
-
-     @requires("posterior")
-     @requires("posterior_predictive")
-     def posterior_predictive_to_xarray(self):
-         """Convert posterior_predictive samples to xarray."""
-         posterior_predictive = self.posterior_predictive
-
-         if (
-             isinstance(posterior_predictive, (tuple, list))
-             and posterior_predictive[0].endswith(".csv")
-         ) or (isinstance(posterior_predictive, str) and posterior_predictive.endswith(".csv")):
-             if isinstance(posterior_predictive, str):
-                 posterior_predictive = [posterior_predictive]
-             chain_data = []
-             chain_data_warmup = []
-             columns = None
-             attrs = {}
-             for path in posterior_predictive:
-                 parsed_output = _read_output(path)
-                 chain_data.append(parsed_output["sample"])
-                 chain_data_warmup.append(parsed_output["sample_warmup"])
-                 if columns is None:
-                     columns = parsed_output["sample_columns"]
-
-                 for key, value in parsed_output["configuration_info"].items():
-                     if key not in attrs:
-                         attrs[key] = []
-                     attrs[key].append(value)
-
-             data = _unpack_ndarrays(chain_data, columns, self.dtypes)
-             data_warmup = _unpack_ndarrays(chain_data_warmup, columns, self.dtypes)
-
-         else:
-             if isinstance(posterior_predictive, str):
-                 posterior_predictive = [posterior_predictive]
-             columns = {
-                 col: idx
-                 for col, idx in self.posterior_columns.items()
-                 if any(item == col.split(".")[0] for item in posterior_predictive)
-             }
-             data = _unpack_ndarrays(self.posterior[0], columns, self.dtypes)
-             data_warmup = _unpack_ndarrays(self.posterior[1], columns, self.dtypes)
-
-             attrs = None
-         return (
-             dict_to_dataset(
-                 data,
-                 coords=self.coords,
-                 dims=self.dims,
-                 attrs=attrs,
-                 index_origin=self.index_origin,
-             ),
-             dict_to_dataset(
-                 data_warmup,
-                 coords=self.coords,
-                 dims=self.dims,
-                 attrs=attrs,
-                 index_origin=self.index_origin,
-             ),
-         )
-
-     @requires("posterior")
-     @requires("predictions")
-     def predictions_to_xarray(self):
-         """Convert out of sample predictions samples to xarray."""
-         predictions = self.predictions
-
-         if (isinstance(predictions, (tuple, list)) and predictions[0].endswith(".csv")) or (
-             isinstance(predictions, str) and predictions.endswith(".csv")
-         ):
-             if isinstance(predictions, str):
-                 predictions = [predictions]
-             chain_data = []
-             chain_data_warmup = []
-             columns = None
-             attrs = {}
-             for path in predictions:
-                 parsed_output = _read_output(path)
-                 chain_data.append(parsed_output["sample"])
-                 chain_data_warmup.append(parsed_output["sample_warmup"])
-                 if columns is None:
-                     columns = parsed_output["sample_columns"]
-
-                 for key, value in parsed_output["configuration_info"].items():
-                     if key not in attrs:
-                         attrs[key] = []
-                     attrs[key].append(value)
-
-             data = _unpack_ndarrays(chain_data, columns, self.dtypes)
-             data_warmup = _unpack_ndarrays(chain_data_warmup, columns, self.dtypes)
-         else:
-             if isinstance(predictions, str):
-                 predictions = [predictions]
-             columns = {
-                 col: idx
-                 for col, idx in self.posterior_columns.items()
-                 if any(item == col.split(".")[0] for item in predictions)
-             }
-             data = _unpack_ndarrays(self.posterior[0], columns, self.dtypes)
-             data_warmup = _unpack_ndarrays(self.posterior[1], columns, self.dtypes)
-
-             attrs = None
-         return (
-             dict_to_dataset(
-                 data,
-                 coords=self.coords,
-                 dims=self.dims,
-                 attrs=attrs,
-                 index_origin=self.index_origin,
-             ),
-             dict_to_dataset(
-                 data_warmup,
-                 coords=self.coords,
-                 dims=self.dims,
-                 attrs=attrs,
-                 index_origin=self.index_origin,
-             ),
-         )
-
-     @requires("posterior")
-     @requires("log_likelihood")
-     def log_likelihood_to_xarray(self):
-         """Convert elementwise log_likelihood samples to xarray."""
-         log_likelihood = self.log_likelihood
-
-         if (isinstance(log_likelihood, (tuple, list)) and log_likelihood[0].endswith(".csv")) or (
-             isinstance(log_likelihood, str) and log_likelihood.endswith(".csv")
-         ):
-             if isinstance(log_likelihood, str):
-                 log_likelihood = [log_likelihood]
-
-             chain_data = []
-             chain_data_warmup = []
-             columns = None
-             attrs = {}
-             for path in log_likelihood:
-                 parsed_output = _read_output(path)
-                 chain_data.append(parsed_output["sample"])
-                 chain_data_warmup.append(parsed_output["sample_warmup"])
-
-                 if columns is None:
-                     columns = parsed_output["sample_columns"]
-
-                 for key, value in parsed_output["configuration_info"].items():
-                     if key not in attrs:
-                         attrs[key] = []
-                     attrs[key].append(value)
-             data = _unpack_ndarrays(chain_data, columns, self.dtypes)
-             data_warmup = _unpack_ndarrays(chain_data_warmup, columns, self.dtypes)
-         else:
-             if isinstance(log_likelihood, dict):
-                 log_lik_to_obs_name = {v: k for k, v in log_likelihood.items()}
-                 columns = {
-                     col.replace(col_name, log_lik_to_obs_name[col_name]): idx
-                     for col, col_name, idx in (
-                         (col, col.split(".")[0], idx) for col, idx in self.posterior_columns.items()
-                     )
-                     if any(item == col_name for item in log_likelihood.values())
-                 }
-             else:
-                 if isinstance(log_likelihood, str):
-                     log_likelihood = [log_likelihood]
-                 columns = {
-                     col: idx
-                     for col, idx in self.posterior_columns.items()
-                     if any(item == col.split(".")[0] for item in log_likelihood)
-                 }
-             data = _unpack_ndarrays(self.posterior[0], columns, self.dtypes)
-             data_warmup = _unpack_ndarrays(self.posterior[1], columns, self.dtypes)
-             attrs = None
-         return (
-             dict_to_dataset(
-                 data,
-                 coords=self.coords,
-                 dims=self.dims,
-                 attrs=attrs,
-                 index_origin=self.index_origin,
-                 skip_event_dims=True,
-             ),
-             dict_to_dataset(
-                 data_warmup,
-                 coords=self.coords,
-                 dims=self.dims,
-                 attrs=attrs,
-                 index_origin=self.index_origin,
-                 skip_event_dims=True,
-             ),
-         )
-
-     @requires("prior")
-     def prior_to_xarray(self):
-         """Convert prior samples to xarray."""
-         # filter prior_predictive
-         prior_predictive = self.prior_predictive
-
-         columns = self.prior_columns
-
-         if prior_predictive is None or (
-             isinstance(prior_predictive, str) and prior_predictive.lower().endswith(".csv")
-         ):
-             prior_predictive = []
-         elif isinstance(prior_predictive, str):
-             prior_predictive = [col for col in columns if prior_predictive == col.split(".")[0]]
-         else:
-             prior_predictive = [
-                 col
-                 for col in columns
-                 if any(item == col.split(".")[0] for item in prior_predictive)
-             ]
-
-         invalid_cols = prior_predictive
-         valid_cols = {col: idx for col, idx in columns.items() if col not in invalid_cols}
-         data = _unpack_ndarrays(self.prior[0], valid_cols, self.dtypes)
-         data_warmup = _unpack_ndarrays(self.prior[1], valid_cols, self.dtypes)
-         return (
-             dict_to_dataset(
-                 data,
-                 coords=self.coords,
-                 dims=self.dims,
-                 attrs=self.attrs_prior,
-                 index_origin=self.index_origin,
-             ),
-             dict_to_dataset(
-                 data_warmup,
-                 coords=self.coords,
-                 dims=self.dims,
-                 attrs=self.attrs_prior,
-                 index_origin=self.index_origin,
-             ),
-         )
-
-     @requires("prior")
-     @requires("sample_stats_prior_columns")
-     def sample_stats_prior_to_xarray(self):
-         """Extract sample_stats from fit."""
-         dtypes = {"diverging": bool, "n_steps": np.int64, "tree_depth": np.int64, **self.dtypes}
-         rename_dict = {
-             "divergent": "diverging",
-             "n_leapfrog": "n_steps",
-             "treedepth": "tree_depth",
-             "stepsize": "step_size",
-             "accept_stat": "acceptance_rate",
-         }
-
-         columns_new = {}
-         for key, idx in self.sample_stats_prior_columns.items():
-             name = re.sub("__$", "", key)
-             name = rename_dict.get(name, name)
-             columns_new[name] = idx
-
-         data = _unpack_ndarrays(self.prior[0], columns_new, dtypes)
-         data_warmup = _unpack_ndarrays(self.prior[1], columns_new, dtypes)
-         return (
-             dict_to_dataset(
-                 data,
-                 coords=self.coords,
-                 dims=self.dims,
-                 attrs={item: key for key, item in rename_dict.items()},
-                 index_origin=self.index_origin,
-             ),
-             dict_to_dataset(
-                 data_warmup,
-                 coords=self.coords,
-                 dims=self.dims,
-                 attrs={item: key for key, item in rename_dict.items()},
-                 index_origin=self.index_origin,
-             ),
-         )
-
-     @requires("prior")
-     @requires("prior_predictive")
-     def prior_predictive_to_xarray(self):
-         """Convert prior_predictive samples to xarray."""
-         prior_predictive = self.prior_predictive
-
-         if (
-             isinstance(prior_predictive, (tuple, list)) and prior_predictive[0].endswith(".csv")
-         ) or (isinstance(prior_predictive, str) and prior_predictive.endswith(".csv")):
-             if isinstance(prior_predictive, str):
-                 prior_predictive = [prior_predictive]
-             chain_data = []
-             chain_data_warmup = []
-             columns = None
-             attrs = {}
-             for path in prior_predictive:
-                 parsed_output = _read_output(path)
-                 chain_data.append(parsed_output["sample"])
-                 chain_data_warmup.append(parsed_output["sample_warmup"])
-                 if columns is None:
-                     columns = parsed_output["sample_columns"]
-                 for key, value in parsed_output["configuration_info"].items():
-                     if key not in attrs:
-                         attrs[key] = []
-                     attrs[key].append(value)
-             data = _unpack_ndarrays(chain_data, columns, self.dtypes)
-             data_warmup = _unpack_ndarrays(chain_data_warmup, columns, self.dtypes)
-         else:
-             if isinstance(prior_predictive, str):
-                 prior_predictive = [prior_predictive]
-             columns = {
-                 col: idx
-                 for col, idx in self.prior_columns.items()
-                 if any(item == col.split(".")[0] for item in prior_predictive)
-             }
-             data = _unpack_ndarrays(self.prior[0], columns, self.dtypes)
-             data_warmup = _unpack_ndarrays(self.prior[1], columns, self.dtypes)
-             attrs = None
-         return (
-             dict_to_dataset(
-                 data,
-                 coords=self.coords,
-                 dims=self.dims,
-                 attrs=attrs,
-                 index_origin=self.index_origin,
-             ),
-             dict_to_dataset(
-                 data_warmup,
-                 coords=self.coords,
-                 dims=self.dims,
-                 attrs=attrs,
-                 index_origin=self.index_origin,
-             ),
-         )
-
-     @requires("observed_data")
-     def observed_data_to_xarray(self):
-         """Convert observed data to xarray."""
-         observed_data_raw = _read_data(self.observed_data)
-         variables = self.observed_data_var
-         if isinstance(variables, str):
-             variables = [variables]
-         observed_data = {
-             key: utils.one_de(vals)
-             for key, vals in observed_data_raw.items()
-             if variables is None or key in variables
-         }
-         return dict_to_dataset(
-             observed_data,
-             coords=self.coords,
-             dims=self.dims,
-             default_dims=[],
-             index_origin=self.index_origin,
-         )
-
-     @requires("constant_data")
-     def constant_data_to_xarray(self):
-         """Convert constant data to xarray."""
-         constant_data_raw = _read_data(self.constant_data)
-         variables = self.constant_data_var
-         if isinstance(variables, str):
-             variables = [variables]
-         constant_data = {
-             key: utils.one_de(vals)
-             for key, vals in constant_data_raw.items()
-             if variables is None or key in variables
-         }
-         return dict_to_dataset(
-             constant_data,
-             coords=self.coords,
-             dims=self.dims,
-             default_dims=[],
-             index_origin=self.index_origin,
-         )
-
-     @requires("predictions_constant_data")
-     def predictions_constant_data_to_xarray(self):
-         """Convert predictions constant data to xarray."""
-         predictions_constant_data_raw = _read_data(self.predictions_constant_data)
-         variables = self.predictions_constant_data_var
-         if isinstance(variables, str):
-             variables = [variables]
-         predictions_constant_data = {}
-         for key, vals in predictions_constant_data_raw.items():
-             if variables is not None and key not in variables:
-                 continue
-             vals = utils.one_de(vals)
-             predictions_constant_data[key] = vals
-         return dict_to_dataset(
-             predictions_constant_data,
-             coords=self.coords,
-             dims=self.dims,
-             default_dims=[],
-             index_origin=self.index_origin,
-         )
-
-     def to_inference_data(self):
-         """Convert all available data to an InferenceData object.
-
-         Note that if groups cannot be created (i.e., there is no `output`, so
-         the `posterior` and `sample_stats` cannot be extracted), then the InferenceData
-         will not have those groups.
-         """
-         return InferenceData(
-             save_warmup=self.save_warmup,
-             **{
-                 "posterior": self.posterior_to_xarray(),
-                 "sample_stats": self.sample_stats_to_xarray(),
-                 "log_likelihood": self.log_likelihood_to_xarray(),
-                 "posterior_predictive": self.posterior_predictive_to_xarray(),
-                 "prior": self.prior_to_xarray(),
-                 "sample_stats_prior": self.sample_stats_prior_to_xarray(),
-                 "prior_predictive": self.prior_predictive_to_xarray(),
-                 "observed_data": self.observed_data_to_xarray(),
-                 "constant_data": self.constant_data_to_xarray(),
-                 "predictions": self.predictions_to_xarray(),
-                 "predictions_constant_data": self.predictions_constant_data_to_xarray(),
-             },
-         )
-
-
- def _process_configuration(comments):
-     """Extract sampling information."""
-     results = {
-         "comments": "\n".join(comments),
-         "stan_version": {},
-     }
-
-     comments_gen = iter(comments)
-
-     for comment in comments_gen:
-         comment = re.sub(r"^\s*#\s*|\s*\(Default\)\s*$", "", comment).strip()
-         if comment.startswith("stan_version_"):
-             key, val = re.sub(r"^\s*stan_version_", "", comment).split("=")
-             results["stan_version"][key.strip()] = val.strip()
-         elif comment.startswith("Step size"):
-             _, val = comment.split("=")
-             results["step_size"] = float(val.strip())
-         elif "inverse mass matrix" in comment:
-             comment = re.sub(r"^\s*#\s*", "", next(comments_gen)).strip()
-             results["inverse_mass_matrix"] = [float(item) for item in comment.split(",")]
-         elif ("seconds" in comment) and any(
-             item in comment for item in ("(Warm-up)", "(Sampling)", "(Total)")
-         ):
-             value = re.sub(
-                 (
-                     r"^Elapsed\s*Time:\s*|"
-                     r"\s*seconds\s*\(Warm-up\)\s*|"
-                     r"\s*seconds\s*\(Sampling\)\s*|"
-                     r"\s*seconds\s*\(Total\)\s*"
-                 ),
-                 "",
-                 comment,
-             )
-             key = (
-                 "warmup_time_seconds"
-                 if "(Warm-up)" in comment
-                 else "sampling_time_seconds" if "(Sampling)" in comment else "total_time_seconds"
-             )
-             results[key] = float(value)
-         elif "=" in comment:
-             match_int = re.search(r"^(\S+)\s*=\s*([-+]?[0-9]+)$", comment)
-             match_float = re.search(r"^(\S+)\s*=\s*([-+]?[0-9]+\.[0-9]+)$", comment)
-             match_str_bool = re.search(r"^(\S+)\s*=\s*(true|false)$", comment)
-             match_str = re.search(r"^(\S+)\s*=\s*(\S+)$", comment)
-             match_empty = re.search(r"^(\S+)\s*=\s*$", comment)
-             if match_int:
-                 key, value = match_int.group(1), match_int.group(2)
-                 results[key] = int(value)
-             elif match_float:
-                 key, value = match_float.group(1), match_float.group(2)
-                 results[key] = float(value)
-             elif match_str_bool:
-                 key, value = match_str_bool.group(1), match_str_bool.group(2)
-                 results[key] = int(value == "true")
-             elif match_str:
-                 key, value = match_str.group(1), match_str.group(2)
-                 results[key] = value
-             elif match_empty:
-                 key = match_empty.group(1)
-                 results[key] = None
-
-     results = {key: str(results[key]) for key in sorted(results)}
-     return results
-
-
- def _read_output_file(path):
-     """Read Stan csv file to ndarray."""
-     comments = []
-     data = []
-     columns = None
-     with open(path, "rb") as f_obj:
-         # read header
-         for line in f_obj:
-             if line.startswith(b"#"):
-                 comments.append(line.strip().decode("utf-8"))
-                 continue
-             columns = {key: idx for idx, key in enumerate(line.strip().decode("utf-8").split(","))}
-             break
-         # read data
-         for line in f_obj:
-             line = line.strip()
-             if line.startswith(b"#"):
-                 comments.append(line.decode("utf-8"))
-                 continue
-             if line:
-                 data.append(np.array(line.split(b","), dtype=np.float64))
-
-     return columns, np.array(data, dtype=np.float64), comments
-
-
- def _read_output(path):
-     """Read CmdStan output csv file.
-
-     Parameters
-     ----------
-     path : str
-
-     Returns
-     -------
-     Dict[str, Any]
-     """
-     # Read data
-     columns, data, comments = _read_output_file(path)
-
-     pconf = _process_configuration(comments)
-
-     # split dataframe to warmup and draws
-     saved_warmup = (
-         int(pconf.get("save_warmup", 0))
-         * int(pconf.get("num_warmup", 0))
-         // int(pconf.get("thin", 1))
-     )
-
-     data_warmup = data[:saved_warmup]
-     data = data[saved_warmup:]
-
-     # Split data to sample_stats and sample
-     sample_stats_columns = {col: idx for col, idx in columns.items() if col.endswith("__")}
-     sample_columns = {col: idx for col, idx in columns.items() if col not in sample_stats_columns}
-
-     return {
-         "sample": data,
-         "sample_warmup": data_warmup,
-         "sample_columns": sample_columns,
-         "sample_stats_columns": sample_stats_columns,
-         "configuration_info": pconf,
-     }
-
-
- def _process_data_var(string):
-     """Transform datastring to key, values pair.
-
-     All values are transformed to floating point values.
-
-     Parameters
-     ----------
-     string : str
-
-     Returns
-     -------
-     Tuple[str, str]
-         key, values pair
-     """
-     key, var = string.split("<-")
-     if "structure" in var:
-         var, dim = var.replace("structure(", "").replace(",", "").split(".Dim")
-         # dtype = int if '.' not in var and 'e' not in var.lower() else float
-         dtype = float
-         var = var.replace("c(", "").replace(")", "").strip().split()
-         dim = dim.replace("=", "").replace("c(", "").replace(")", "").strip().split()
-         dim = tuple(map(int, dim))
-         var = np.fromiter(map(dtype, var), dtype).reshape(dim, order="F")
-     elif "c(" in var:
-         # dtype = int if '.' not in var and 'e' not in var.lower() else float
-         dtype = float
-         var = var.replace("c(", "").replace(")", "").split(",")
-         var = np.fromiter(map(dtype, var), dtype)
-     else:
-         # dtype = int if '.' not in var and 'e' not in var.lower() else float
-         dtype = float
-         var = dtype(var)
-     return key.strip(), var
-
-
- def _read_data(path):
-     """Read Rdump or JSON output to dictionary.
-
-     Parameters
-     ----------
-     path : str
-
-     Returns
-     -------
-     Dict
-         key, values pairs from Rdump/JSON formatted data.
-     """
-     data = {}
-     with open(path, "r", encoding="utf8") as f_obj:
-         if path.lower().endswith(".json"):
-             return json.load(f_obj)
-         var = ""
-         for line in f_obj:
-             if "<-" in line:
-                 if len(var):
-                     key, var = _process_data_var(var)
-                     data[key] = var
-                 var = ""
-             var += f" {line.strip()}"
-         if len(var):
-             key, var = _process_data_var(var)
-             data[key] = var
-     return data
-
-
- def _unpack_ndarrays(arrays, columns, dtypes=None):
-     """Transform a list of ndarrays to dictionary containing ndarrays.
-
-     Parameters
-     ----------
-     arrays : List[np.ndarray]
-     columns : Dict[str, int]
-     dtypes : Dict[str, Any]
-
-     Returns
-     -------
-     Dict
-         key, values pairs. Values are formatted to shape = (nchain, ndraws, *shape)
-     """
-     col_groups = defaultdict(list)
-     for col, col_idx in columns.items():
-         key, *loc = col.split(".")
-         loc = tuple(int(i) - 1 for i in loc)
-         col_groups[key].append((col_idx, loc))
-
-     chains = len(arrays)
-     draws = len(arrays[0])
-     sample = {}
-     if draws:
-         for key, cols_locs in col_groups.items():
-             ndim = np.array([loc for _, loc in cols_locs]).max(0) + 1
-             dtype = dtypes.get(key, np.float64)
-             sample[key] = np.zeros((chains, draws, *ndim), dtype=dtype)
-             for col, loc in cols_locs:
-                 for chain_id, arr in enumerate(arrays):
-                     draw = arr[:, col]
-                     if loc == ():
-                         sample[key][chain_id, :] = draw
-                     else:
-                         axis1_all = range(sample[key].shape[1])
-                         slicer = (chain_id, axis1_all, *loc)
-                         sample[key][slicer] = draw
-     return sample
-
-
- def from_cmdstan(
-     posterior: Optional[Union[str, List[str]]] = None,
-     *,
-     posterior_predictive: Optional[Union[str, List[str]]] = None,
-     predictions: Optional[Union[str, List[str]]] = None,
-     prior: Optional[Union[str, List[str]]] = None,
-     prior_predictive: Optional[Union[str, List[str]]] = None,
-     observed_data: Optional[str] = None,
-     observed_data_var: Optional[Union[str, List[str]]] = None,
-     constant_data: Optional[str] = None,
-     constant_data_var: Optional[Union[str, List[str]]] = None,
-     predictions_constant_data: Optional[str] = None,
-     predictions_constant_data_var: Optional[Union[str, List[str]]] = None,
-     log_likelihood: Optional[Union[str, List[str]]] = None,
-     index_origin: Optional[int] = None,
-     coords: Optional[CoordSpec] = None,
-     dims: Optional[DimSpec] = None,
-     disable_glob: Optional[bool] = False,
-     save_warmup: Optional[bool] = None,
-     dtypes: Optional[Dict] = None,
- ) -> InferenceData:
-     """Convert CmdStan data into an InferenceData object.
-
-     For a usage example read the
-     :ref:`Creating InferenceData section on from_cmdstan <creating_InferenceData>`
-
-     Parameters
-     ----------
-     posterior : str or list of str, optional
-         List of paths to output.csv files.
-     posterior_predictive : str or list of str, optional
-         Posterior predictive samples for the fit. If it ends with ".csv", a file path is assumed.
-     predictions : str or list of str, optional
-         Out of sample prediction samples for the fit. If it ends with ".csv", a file path is assumed.
-     prior : str or list of str, optional
-         List of paths to output.csv files.
-     prior_predictive : str or list of str, optional
-         Prior predictive samples for the fit. If it ends with ".csv", a file path is assumed.
-     observed_data : str, optional
-         Observed data used in the sampling. Path to data file in Rdump or JSON format.
-     observed_data_var : str or list of str, optional
-         Variable(s) used for slicing observed_data. If not defined, all
-         data variables are imported.
-     constant_data : str, optional
-         Constant data used in the sampling. Path to data file in Rdump or JSON format.
-     constant_data_var : str or list of str, optional
-         Variable(s) used for slicing constant_data. If not defined, all
-         data variables are imported.
-     predictions_constant_data : str, optional
-         Constant data for predictions used in the sampling.
-         Path to data file in Rdump or JSON format.
-     predictions_constant_data_var : str or list of str, optional
-         Variable(s) used for slicing predictions_constant_data.
-         If not defined, all data variables are imported.
-     log_likelihood : dict of {str: str}, list of str or str, optional
-         Pointwise log_likelihood for the data. log_likelihood is extracted from the
-         posterior. It is recommended to use this argument as a dictionary whose keys
-         are observed variable names and its values are the variables storing log
-         likelihood arrays in the Stan code. In other cases, a dictionary with keys
-         equal to its values is used. By default, if a variable ``log_lik`` is
-         present in the Stan model, it will be retrieved as pointwise log
-         likelihood values. Use ``False`` to avoid this behaviour.
-     index_origin : int, optional
-         Starting value of integer coordinate values. Defaults to the value in rcParam
-         ``data.index_origin``.
-     coords : dict of {str: array_like}, optional
-         A dictionary containing the values that are used as index. The key
-         is the name of the dimension, the values are the index values.
-     dims : dict of {str: list of str}, optional
-         A mapping from variables to a list of coordinate names for the variable.
-     disable_glob : bool
-         Don't use glob for string input. This means that all string input is
-         assumed to be variable names (samples) or a path (data).
-     save_warmup : bool
-         Save warmup iterations into InferenceData object, if found in the input files.
-         If not defined, use the default defined by the rcParams.
-     dtypes : dict or str
-         A dictionary containing dtype information (int, float) for parameters.
-         If input is a string, it is assumed to be model code or a path to a model code file.
-
-     Returns
-     -------
-     InferenceData object
-     """
-     return CmdStanConverter(
-         posterior=posterior,
-         posterior_predictive=posterior_predictive,
-         predictions=predictions,
-         prior=prior,
-         prior_predictive=prior_predictive,
-         observed_data=observed_data,
-         observed_data_var=observed_data_var,
-         constant_data=constant_data,
-         constant_data_var=constant_data_var,
-         predictions_constant_data=predictions_constant_data,
-         predictions_constant_data_var=predictions_constant_data_var,
-         log_likelihood=log_likelihood,
-         index_origin=index_origin,
-         coords=coords,
-         dims=dims,
-         disable_glob=disable_glob,
-         save_warmup=save_warmup,
-         dtypes=dtypes,
-     ).to_inference_data()
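
For context on the removed module: its public entry point, ``from_cmdstan``, globbed the per-chain CSV paths, split each file into draws, warmup draws, and ``__``-suffixed sampler statistics, and assembled the resulting groups into an InferenceData object. A minimal usage sketch follows; the file paths, the ``school`` coordinate, and the ``y``/``log_lik`` names are hypothetical placeholders, not values taken from this diff:

    import arviz as az

    # Hypothetical CmdStan outputs: one CSV per chain. The glob pattern is
    # expanded by check_glob() unless disable_glob=True is passed.
    idata = az.from_cmdstan(
        posterior="output_*.csv",
        observed_data="data.json",  # Rdump or JSON data file
        log_likelihood={"y": "log_lik"},  # observed variable -> log-likelihood variable
        coords={"school": ["A", "B", "C"]},
        dims={"theta": ["school"]},
    )
    print(idata)

Column names such as ``theta.1.2`` in the CSV header are split on ``.`` by ``_unpack_ndarrays`` and gathered into arrays of shape ``(chain, draw, *shape)``, which is what lets ``dims`` and ``coords`` be attached per variable.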