arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185) hide show
  1. arviz/__init__.py +52 -367
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.3.dist-info/METADATA +0 -264
  184. arviz-0.23.3.dist-info/RECORD +0 -183
  185. arviz-0.23.3.dist-info/top_level.txt +0 -1
arviz/data/io_pyjags.py DELETED
@@ -1,378 +0,0 @@
1
- """Convert PyJAGS sample dictionaries to ArviZ inference data objects."""
2
-
3
- import typing as tp
4
- from collections import OrderedDict
5
- from collections.abc import Iterable
6
-
7
- import numpy as np
8
- import xarray
9
-
10
- from .inference_data import InferenceData
11
-
12
- from ..rcparams import rcParams
13
- from .base import dict_to_dataset
14
-
15
-
16
class PyJAGSConverter:
    """Turn PyJAGS sample dictionaries into the groups of an InferenceData object."""

    def __init__(
        self,
        *,
        posterior: tp.Optional[tp.Mapping[str, np.ndarray]] = None,
        prior: tp.Optional[tp.Mapping[str, np.ndarray]] = None,
        log_likelihood: tp.Optional[
            tp.Union[str, tp.List[str], tp.Tuple[str, ...], tp.Mapping[str, str]]
        ] = None,
        coords=None,
        dims=None,
        save_warmup: tp.Optional[bool] = None,
        warmup_iterations: int = 0,
    ) -> None:
        """Store the conversion inputs and separate the log likelihood draws.

        When both ``posterior`` and ``log_likelihood`` are given, the variables
        named by ``log_likelihood`` are popped out of a shallow copy of
        ``posterior`` (the caller's mapping is not modified) and stored in
        ``self.log_likelihood`` keyed by observed variable name. ``save_warmup``
        falls back to ``rcParams["data.save_warmup"]`` when it is None.
        """
        self.posterior: tp.Optional[tp.Mapping[str, np.ndarray]]
        self.log_likelihood: tp.Optional[tp.Dict[str, np.ndarray]]

        if posterior is None or log_likelihood is None:
            self.posterior = posterior
            self.log_likelihood = None
        else:
            remaining = dict(posterior)

            # Normalise every accepted spelling to {observed name: log-lik name}.
            if isinstance(log_likelihood, str):
                log_likelihood = [log_likelihood]
            if isinstance(log_likelihood, (list, tuple)):
                log_likelihood = dict(zip(log_likelihood, log_likelihood))

            extracted = {}
            for obs_var_name, log_like_name in log_likelihood.items():
                extracted[obs_var_name] = remaining.pop(log_like_name)

            self.log_likelihood = extracted
            self.posterior = remaining

        self.prior = prior
        self.coords = coords
        self.dims = dims
        if save_warmup is None:
            self.save_warmup = rcParams["data.save_warmup"]
        else:
            self.save_warmup = save_warmup
        self.warmup_iterations = warmup_iterations

        # Imported here rather than at module level; the module object is kept
        # so it can be handed to dict_to_dataset as ``library``.
        import pyjags  # pylint: disable=import-error

        self.pyjags = pyjags

    def _pyjags_samples_to_xarray(
        self, pyjags_samples: tp.Mapping[str, np.ndarray]
    ) -> tp.Tuple[xarray.Dataset, xarray.Dataset]:
        """Convert one samples dictionary to a ``(draws, warmup draws)`` dataset pair."""
        draws, warmup_draws = get_draws(
            pyjags_samples=pyjags_samples,
            warmup_iterations=self.warmup_iterations,
            warmup=self.save_warmup,
        )

        shared_kwargs = {"library": self.pyjags, "coords": self.coords, "dims": self.dims}
        draws_dataset = dict_to_dataset(draws, **shared_kwargs)
        warmup_dataset = dict_to_dataset(warmup_draws, **shared_kwargs)
        return draws_dataset, warmup_dataset

    def posterior_to_xarray(self) -> tp.Optional[tp.Tuple[xarray.Dataset, xarray.Dataset]]:
        """Extract posterior samples from fit, or return None when there are none."""
        if self.posterior is not None:
            return self._pyjags_samples_to_xarray(self.posterior)
        return None

    def prior_to_xarray(self) -> tp.Optional[tp.Tuple[xarray.Dataset, xarray.Dataset]]:
        """Extract prior samples, or return None when there are none."""
        if self.prior is not None:
            return self._pyjags_samples_to_xarray(self.prior)
        return None

    def log_likelihood_to_xarray(self) -> tp.Optional[tp.Tuple[xarray.Dataset, xarray.Dataset]]:
        """Extract log likelihood samples from fit, or return None when there are none."""
        if self.log_likelihood is not None:
            return self._pyjags_samples_to_xarray(self.log_likelihood)
        return None

    def to_inference_data(self):
        """Convert all available data to an InferenceData object."""
        # Warmup is only flagged when it was requested and there are warmup draws.
        keep_warmup = self.save_warmup and self.warmup_iterations > 0

        return InferenceData(
            posterior=self.posterior_to_xarray(),
            prior=self.prior_to_xarray(),
            log_likelihood=self.log_likelihood_to_xarray(),
            save_warmup=keep_warmup,
        )
115
-
116
-
117
def get_draws(
    pyjags_samples: tp.Mapping[str, np.ndarray],
    variables: tp.Optional[tp.Union[str, tp.Iterable[str]]] = None,
    warmup: bool = False,
    warmup_iterations: int = 0,
) -> tp.Tuple[tp.Mapping[str, np.ndarray], tp.Mapping[str, np.ndarray]]:
    """
    Bring a PyJAGS samples dictionary into ArviZ layout, separating warmup draws.

    Parameters
    ----------
    pyjags_samples: a dictionary mapping variable names to NumPy arrays of MCMC
                    chains of samples with shape
                    (parameter_dimension, chain_length, number_of_chains)

    variables: a variable name or an iterable of variable names to extract;
               every key of pyjags_samples is used when None
    warmup: whether the warmup draws are converted and returned in data_warmup
    warmup_iterations: the number of warmup iterations, if any

    Returns
    -------
    A tuple ``(data, data_warmup)`` of samples dictionaries in ArviZ format.
    ``data_warmup`` is an empty OrderedDict unless ``warmup`` is true and
    ``warmup_iterations`` is positive.

    Raises
    ------
    TypeError
        If ``variables`` is neither None, a str nor an iterable.
    """
    if variables is None:
        requested = list(pyjags_samples.keys())
    elif isinstance(variables, str):
        requested = [variables]
    else:
        requested = variables

    if not isinstance(requested, Iterable):
        raise TypeError("variables must be of type Sequence or str")

    names = tuple(requested)
    warmup_draws: tp.Mapping[str, np.ndarray] = OrderedDict()

    # Nothing to split off: convert the samples as they are.
    if not warmup_iterations > 0:
        posterior_draws = _convert_pyjags_dict_to_arviz_dict(
            samples=pyjags_samples, variable_names=names
        )
        return posterior_draws, warmup_draws

    burn_in, kept = _split_pyjags_dict_in_warmup_and_actual_samples(
        pyjags_samples=pyjags_samples,
        warmup_iterations=warmup_iterations,
        variable_names=names,
    )
    posterior_draws = _convert_pyjags_dict_to_arviz_dict(samples=kept, variable_names=names)
    if warmup:
        warmup_draws = _convert_pyjags_dict_to_arviz_dict(samples=burn_in, variable_names=names)

    return posterior_draws, warmup_draws
172
-
173
-
174
def _split_pyjags_dict_in_warmup_and_actual_samples(
    pyjags_samples: tp.Mapping[str, np.ndarray],
    warmup_iterations: int,
    variable_names: tp.Optional[tp.Tuple[str, ...]] = None,
) -> tp.Tuple[tp.Mapping[str, np.ndarray], tp.Mapping[str, np.ndarray]]:
    """
    Split a PyJAGS samples dictionary into warmup samples and actual samples.

    Parameters
    ----------
    pyjags_samples: a dictionary mapping variable names to NumPy arrays of MCMC
                    chains of samples with shape
                    (parameter_dimension, chain_length, number_of_chains);
                    arrays with several leading parameter axes are split along
                    their second-to-last axis as well

    warmup_iterations: the number of draws to be split off for warmup
    variable_names: the variables in the dictionary to use; if None use all

    Returns
    -------
    A tuple ``(warmup_samples, actual_samples)`` of two samples dictionaries
    in PyJAGS format
    """
    if variable_names is None:
        variable_names = tuple(pyjags_samples.keys())

    warmup_samples: tp.Dict[str, np.ndarray] = {}
    actual_samples: tp.Dict[str, np.ndarray] = {}

    for variable_name, chains in pyjags_samples.items():
        if variable_name in variable_names:
            # Index from the right so the cut always lands on the iteration
            # axis (second to last), however many parameter axes precede it.
            warmup_samples[variable_name] = chains[..., :warmup_iterations, :]
            actual_samples[variable_name] = chains[..., warmup_iterations:, :]

    return warmup_samples, actual_samples
207
-
208
-
209
def _convert_pyjags_dict_to_arviz_dict(
    samples: tp.Mapping[str, np.ndarray],
    variable_names: tp.Optional[tp.Tuple[str, ...]] = None,
) -> tp.Mapping[str, np.ndarray]:
    """
    Convert a PyJAGS dictionary to an ArviZ dictionary.

    Takes a python dictionary of samples that has been generated by the sample
    method of a model instance and returns a dictionary of samples in ArviZ
    format.

    Parameters
    ----------
    samples: a dictionary mapping variable names to NumPy arrays with shape
             (parameter_dimension, chain_length, number_of_chains); arrays with
             several leading parameter axes are accepted as well
    variable_names: the variables in the dictionary to convert; if None use all

    Returns
    -------
    a dictionary mapping variable names to NumPy arrays with shape
    (number_of_chains, chain_length, parameter_dimension). A 3-D input whose
    parameter dimension is 1 is returned without that axis, i.e. with shape
    (number_of_chains, chain_length).

    Raises
    ------
    ValueError
        If an array selected for conversion has fewer than three dimensions.
    """
    if variable_names is None:
        variable_names = tuple(samples.keys())

    variable_name_to_samples_map = {}

    for variable_name, chains in samples.items():
        if variable_name not in variable_names:
            continue

        if chains.ndim < 3:
            raise ValueError(
                f"samples for '{variable_name}' must have at least 3 dimensions "
                f"(parameter dimension(s), chain length, number of chains), got {chains.ndim}"
            )

        if chains.ndim == 3 and chains.shape[0] == 1:
            # Scalar parameter: drop the length-one parameter axis.
            variable_name_to_samples_map[variable_name] = chains[0, :, :].transpose()
        else:
            # Bring (chain_length, number_of_chains) from the back to the front
            # as (number_of_chains, chain_length); parameter axes keep their
            # order. For 3-D input this equals np.swapaxes(chains, 0, 2).
            variable_name_to_samples_map[variable_name] = np.moveaxis(chains, [-1, -2], [0, 1])

    return variable_name_to_samples_map
249
-
250
-
251
def _extract_arviz_dict_from_inference_data(
    idata,
) -> tp.Mapping[str, np.ndarray]:
    """
    Pull the posterior samples out of an ArviZ inference data object.

    The posterior group is serialised with ``to_dict()`` and the ``"data"``
    entry of every data variable is turned into a NumPy array.

    Parameters
    ----------
    idata: InferenceData

    Returns
    -------
    a dictionary mapping variable names to NumPy arrays with shape
    (number_of_chains, chain_length, parameter_dimension)

    """
    data_vars = idata.posterior.to_dict()["data_vars"]

    samples_by_variable = {}
    for name, entry in data_vars.items():
        samples_by_variable[name] = np.array(entry["data"])

    return samples_by_variable
277
-
278
-
279
def _convert_arviz_dict_to_pyjags_dict(
    samples: tp.Mapping[str, np.ndarray],
) -> tp.Mapping[str, np.ndarray]:
    """
    Convert an ArviZ dictionary to a PyJAGS dictionary.

    Takes a python dictionary of samples in ArviZ format and returns the samples
    as a dictionary in PyJAGS format.

    Parameters
    ----------
    samples: dict of {str : array_like}
        a dictionary mapping variable names to NumPy arrays with shape
        (number_of_chains, chain_length, parameter_dimension). 2-D arrays of
        shape (number_of_chains, chain_length) are treated as having a
        parameter dimension of 1; arrays with several trailing parameter axes
        are accepted as well.

    Returns
    -------
    a dictionary mapping variable names to NumPy arrays with shape
    (parameter_dimension, chain_length, number_of_chains)

    """
    # arviz stores samples with shape
    # (number_of_chains, chain_length, parameter_dimension)
    # but pyjags expects a dictionary of NumPy arrays with shape
    # (parameter_dimension, chain_length, number_of_chains)

    variable_name_to_samples_map = {}

    for variable_name, chains in samples.items():
        if chains.ndim == 2:
            # Scalar parameter: add the length-one parameter axis PyJAGS uses.
            chains = chains[:, :, np.newaxis]

        # Send (number_of_chains, chain_length) from the front to the back as
        # (chain_length, number_of_chains); parameter axes keep their order.
        # For 3-D input this equals np.swapaxes(chains, 0, 2).
        variable_name_to_samples_map[variable_name] = np.moveaxis(chains, [0, 1], [-1, -2])

    return variable_name_to_samples_map
315
-
316
-
317
def from_pyjags(
    posterior: tp.Optional[tp.Mapping[str, np.ndarray]] = None,
    prior: tp.Optional[tp.Mapping[str, np.ndarray]] = None,
    log_likelihood: tp.Optional[
        tp.Union[str, tp.List[str], tp.Tuple[str, ...], tp.Mapping[str, str]]
    ] = None,
    coords=None,
    dims=None,
    save_warmup: tp.Optional[bool] = None,
    warmup_iterations: int = 0,
) -> InferenceData:
    """
    Convert PyJAGS posterior samples to an ArviZ inference data object.

    Takes a python dictionary of samples that has been generated by the sample
    method of a model instance and returns an ArviZ inference data object.
    For a usage example read the
    :ref:`Creating InferenceData section on from_pyjags <creating_InferenceData>`

    Parameters
    ----------
    posterior: dict of {str : array_like}, optional
        a dictionary mapping variable names to NumPy arrays containing
        posterior samples with shape
        (parameter_dimension, chain_length, number_of_chains)

    prior: dict of {str : array_like}, optional
        a dictionary mapping variable names to NumPy arrays containing
        prior samples with shape
        (parameter_dimension, chain_length, number_of_chains)

    log_likelihood: dict of {str: str}, list of str or str, optional
        Pointwise log_likelihood for the data. log_likelihood is extracted from the
        posterior. It is recommended to use this argument as a dictionary whose keys
        are observed variable names and its values are the variables storing log
        likelihood arrays in the JAGS code. In other cases, a dictionary with keys
        equal to its values is used.

    coords: dict[str, iterable]
        A dictionary containing the values that are used as index. The key
        is the name of the dimension, the values are the index values.

    dims: dict[str, List(str)]
        A mapping from variables to a list of coordinate names for the variable.

    save_warmup : bool, optional
        Save warmup iterations in InferenceData. If not defined, use default defined by the rcParams.

    warmup_iterations: int, optional
        Number of warmup iterations

    Returns
    -------
    InferenceData
    """
    # The converter takes keyword-only arguments, so everything is passed by name.
    return PyJAGSConverter(
        posterior=posterior,
        prior=prior,
        log_likelihood=log_likelihood,
        dims=dims,
        coords=coords,
        save_warmup=save_warmup,
        warmup_iterations=warmup_iterations,
    ).to_inference_data()