arviz 0.23.1__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. arviz/__init__.py +52 -357
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.1.dist-info/METADATA +0 -263
  184. arviz-0.23.1.dist-info/RECORD +0 -183
  185. arviz-0.23.1.dist-info/top_level.txt +0 -1
arviz/data/io_numpyro.py DELETED
@@ -1,497 +0,0 @@
1
- """NumPyro-specific conversion code."""
2
-
3
- from collections import defaultdict
4
- import logging
5
- from typing import Any, Callable, Optional, Dict, List, Tuple
6
-
7
- import numpy as np
8
-
9
- from .. import utils
10
- from ..rcparams import rcParams
11
- from .base import dict_to_dataset, requires
12
- from .inference_data import InferenceData
13
-
14
- _log = logging.getLogger(__name__)
15
-
16
-
17
def _add_dims(dims_a: Dict[str, List[str]], dims_b: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """Concatenate two variable-name -> dim-names mappings.

    Keys appear in first-seen order (all of ``dims_a``, then new keys from ``dims_b``).
    For a key present in both, the dims from ``dims_a`` come first, followed by those
    from ``dims_b``. The inputs are not modified; every value in the result is a new list.
    """
    combined: Dict[str, List[str]] = {}
    for source in (dims_a, dims_b):
        for var_name, var_dims in source.items():
            combined.setdefault(var_name, []).extend(var_dims)
    return combined
28
-
29
-
30
def infer_dims(
    model: Callable,
    model_args: Optional[Tuple[Any, ...]] = None,
    model_kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, List[str]]:
    """Infer dim names for the sample and deterministic sites of a NumPyro model.

    The model is traced once, seeded with 0 and with ``init_to_sample`` substitution.
    The trace runs inside ``jax.eval_shape`` so no array computation is performed.
    For every ``sample`` or ``deterministic`` site, the inferred dims are:

    * the names of the frames in the site's ``cond_indep_stack`` (plates), sorted by
      their ``dim`` attribute, followed by
    * any names supplied through the site's ``infer={"event_dims": [...]}`` metadata.

    Parameters
    ----------
    model : callable
        NumPyro model function.
    model_args : tuple, optional
        Positional arguments passed to ``model``. ``None`` means no positional arguments.
    model_kwargs : dict, optional
        Keyword arguments passed to ``model``. ``None`` means no keyword arguments.

    Returns
    -------
    dict[str, list[str]]
        Map of site name to dim names. Sites without any batch or event dims are omitted.
    """
    from numpyro import handlers, distributions as dist
    from numpyro.ops.pytree import PytreeTrace
    from numpyro.infer.initialization import init_to_sample
    import jax

    model_args = () if model_args is None else model_args
    # Test ``model_kwargs`` itself (not ``model_args``, which can no longer be None here),
    # so that the default ``None`` becomes an empty dict instead of reaching ``**None``.
    model_kwargs = {} if model_kwargs is None else model_kwargs

    def _get_dist_name(fn):
        # Unwrap wrapper distributions so the reported name is the underlying distribution.
        if isinstance(fn, (dist.Independent, dist.ExpandedDistribution, dist.MaskedDistribution)):
            return _get_dist_name(fn.base_dist)
        return type(fn).__name__

    def get_trace():
        # We use `init_to_sample` to get around ImproperUniform distribution,
        # which does not have `sample` method.
        subs_model = handlers.substitute(
            handlers.seed(model, 0),
            substitute_fn=init_to_sample,
        )
        trace = handlers.trace(subs_model).get_trace(*model_args, **model_kwargs)
        # Work around an issue where jax.eval_shape does not work
        # for distribution output (e.g. the function `lambda: dist.Normal(0, 1)`)
        # Here we will remove `fn` and store its name in the trace.
        for _, site in trace.items():
            if site["type"] == "sample":
                site["fn_name"] = _get_dist_name(site.pop("fn"))
            elif site["type"] == "deterministic":
                site["fn_name"] = "Deterministic"
        return PytreeTrace(trace)

    # We use eval_shape to avoid any array computation.
    trace = jax.eval_shape(get_trace).trace

    named_dims = {}

    for name, site in trace.items():
        batch_dims = [frame.name for frame in sorted(site["cond_indep_stack"], key=lambda x: x.dim)]
        event_dims = list(site.get("infer", {}).get("event_dims", []))
        if site["type"] in ["sample", "deterministic"] and (batch_dims or event_dims):
            named_dims[name] = batch_dims + event_dims

    return named_dims
79
-
80
-
81
class NumPyroConverter:
    """Encapsulate NumPyro specific logic.

    Collects posterior/prior/predictive samples and related data from NumPyro objects and
    converts each available group to an xarray dataset; ``to_inference_data`` assembles
    them into an ``InferenceData`` object.
    """

    # pylint: disable=too-many-instance-attributes

    model = None  # type: Optional[Callable]
    nchains = None  # type: int
    ndraws = None  # type: int

    def __init__(
        self,
        *,
        posterior=None,
        prior=None,
        posterior_predictive=None,
        predictions=None,
        constant_data=None,
        predictions_constant_data=None,
        log_likelihood=None,
        index_origin=None,
        coords=None,
        dims=None,
        pred_dims=None,
        extra_event_dims=None,
        num_chains=1,
    ):
        """Convert NumPyro data into an InferenceData object.

        Parameters
        ----------
        posterior : numpyro.mcmc.MCMC
            Fitted MCMC object from NumPyro
        prior: dict
            Prior samples from a NumPyro model
        posterior_predictive : dict
            Posterior predictive samples for the posterior
        predictions: dict
            Out of sample predictions
        constant_data: dict
            Dictionary containing constant data variables mapped to their values.
        predictions_constant_data: dict
            Constant data used for out-of-sample predictions.
        log_likelihood : bool, optional
            Whether to build the log_likelihood group. Defaults to
            ``rcParams["data.log_likelihood"]`` when ``None``.
        index_origin : int, optional
            Defaults to ``rcParams["data.index_origin"]`` when ``None``.
        coords : dict[str] -> list[str]
            Map of dimensions to coordinates
        dims : dict[str] -> list[str]
            Map variable names to their coordinates. Will be inferred if they are not provided.
        pred_dims: dict
            Dims for predictions data. Map variable names to their coordinates.
        extra_event_dims: dict
            Map of variable names to additional dim names, appended after the inferred
            dims. Use it for event dims that cannot be inferred, e.g. on deterministic sites.
        num_chains: int
            Number of chains used for sampling. Ignored if posterior is present.
        """
        import jax
        import numpyro

        self.posterior = posterior
        self.prior = jax.device_get(prior)
        self.posterior_predictive = jax.device_get(posterior_predictive)
        # Fetch predictions to host as well, consistent with prior/posterior_predictive.
        self.predictions = jax.device_get(predictions)
        self.constant_data = constant_data
        self.predictions_constant_data = predictions_constant_data
        self.log_likelihood = (
            rcParams["data.log_likelihood"] if log_likelihood is None else log_likelihood
        )
        self.index_origin = rcParams["data.index_origin"] if index_origin is None else index_origin
        self.coords = coords
        self.dims = dims
        self.pred_dims = pred_dims
        self.extra_event_dims = extra_event_dims
        self.numpyro = numpyro

        def arbitrary_element(dct):
            return next(iter(dct.values()))

        if posterior is not None:
            samples = jax.device_get(self.posterior.get_samples(group_by_chain=True))
            if hasattr(samples, "_asdict"):
                # In case it is easy to convert to a dictionary, as in the case of namedtuples
                samples = samples._asdict()
            if not isinstance(samples, dict):
                # handle the case we run MCMC with a general potential_fn
                # (instead of a NumPyro model) whose args is not a dictionary
                # (e.g. f(x) = x ** 2)
                tree_flatten_samples = jax.tree_util.tree_flatten(samples)[0]
                samples = {
                    f"Param:{i}": jax.device_get(v) for i, v in enumerate(tree_flatten_samples)
                }
            self._samples = samples
            self.nchains, self.ndraws = (
                posterior.num_chains,
                posterior.num_samples // posterior.thinning,
            )
            self.model = self.posterior.sampler.model
            # model arguments and keyword arguments
            self._args = self.posterior._args  # pylint: disable=protected-access
            self._kwargs = self.posterior._kwargs  # pylint: disable=protected-access
            self.dims = self.dims if self.dims is not None else self.infer_dims()
            self.pred_dims = (
                self.pred_dims if self.pred_dims is not None else self.infer_pred_dims()
            )
        else:
            self.nchains = num_chains
            # Without a posterior, the number of draws is taken from the leading axis of
            # the first available sample group, divided by the number of chains.
            get_from = None
            if predictions is not None:
                get_from = predictions
            elif posterior_predictive is not None:
                get_from = posterior_predictive
            elif prior is not None:
                get_from = prior
            if get_from is None and constant_data is None and predictions_constant_data is None:
                raise ValueError(
                    "When constructing InferenceData must have at least"
                    " one of posterior, prior, posterior_predictive or predictions."
                )
            if get_from is not None:
                aelem = arbitrary_element(get_from)
                self.ndraws = aelem.shape[0] // self.nchains

        observations = {}
        if self.model is not None:
            # we need to use an init strategy to generate random samples for ImproperUniform sites
            seeded_model = numpyro.handlers.substitute(
                numpyro.handlers.seed(self.model, jax.random.PRNGKey(0)),
                substitute_fn=numpyro.infer.init_to_sample,
            )
            trace = numpyro.handlers.trace(seeded_model).get_trace(*self._args, **self._kwargs)
            observations = {
                name: site["value"]
                for name, site in trace.items()
                if site["type"] == "sample" and site["is_observed"]
            }
        self.observations = observations if observations else None

    @requires("posterior")
    def posterior_to_xarray(self):
        """Convert the posterior to an xarray dataset."""
        data = self._samples
        return dict_to_dataset(
            data,
            library=self.numpyro,
            coords=self.coords,
            dims=self.dims,
            index_origin=self.index_origin,
        )

    @requires("posterior")
    def sample_stats_to_xarray(self):
        """Extract sample_stats from NumPyro posterior.

        Reads the MCMC extra fields grouped by chain, skipping nested (dict/tuple) fields,
        and renames the fields listed in ``rename_key``. ``potential_energy`` is stored with
        its sign flipped as ``lp``. When ``num_steps`` is present, a ``tree_depth`` variable
        computed as ``int(log2(num_steps)) + 1`` is added too.
        """
        rename_key = {
            "potential_energy": "lp",
            "adapt_state.step_size": "step_size",
            "num_steps": "n_steps",
            "accept_prob": "acceptance_rate",
        }
        data = {}
        for stat, value in self.posterior.get_extra_fields(group_by_chain=True).items():
            if isinstance(value, (dict, tuple)):
                continue
            name = rename_key.get(stat, stat)
            value = value.copy()
            if stat == "potential_energy":
                data[name] = -value
            else:
                data[name] = value
            if stat == "num_steps":
                data["tree_depth"] = np.log2(value).astype(int) + 1
        return dict_to_dataset(
            data,
            library=self.numpyro,
            dims=None,
            coords=self.coords,
            index_origin=self.index_origin,
        )

    @requires("posterior")
    @requires("model")
    def log_likelihood_to_xarray(self):
        """Extract log likelihood from NumPyro posterior.

        Returns ``None`` when ``self.log_likelihood`` is falsy. Otherwise, if the model has
        observed sites, evaluates ``numpyro.infer.log_likelihood`` on the flat (not
        chain-grouped) posterior samples. Each site's result is reshaped to
        ``(nchains, ndraws, ...)``.
        """
        if not self.log_likelihood:
            return None
        data = {}
        if self.observations is not None:
            samples = self.posterior.get_samples(group_by_chain=False)
            if hasattr(samples, "_asdict"):
                samples = samples._asdict()
            log_likelihood_dict = self.numpyro.infer.log_likelihood(
                self.model, samples, *self._args, **self._kwargs
            )
            for obs_name, log_like in log_likelihood_dict.items():
                shape = (self.nchains, self.ndraws) + log_like.shape[1:]
                data[obs_name] = np.reshape(np.asarray(log_like), shape)
        return dict_to_dataset(
            data,
            library=self.numpyro,
            dims=self.dims,
            coords=self.coords,
            index_origin=self.index_origin,
            skip_event_dims=True,
        )

    def translate_posterior_predictive_dict_to_xarray(self, dct, dims):
        """Convert posterior_predictive or prediction samples to xarray.

        Each array is used as-is if its leading axes are already ``(nchains, ndraws)``. If
        its leading axis has length ``nchains * ndraws``, it is reshaped to
        ``(nchains, ndraws, ...)``. Otherwise it is passed through ``utils.expand_dims`` and a
        warning is logged.
        """
        data = {}
        for k, ary in dct.items():
            shape = tuple(ary.shape)
            # Compare shape prefixes instead of indexing ``shape[0]``/``shape[1]`` directly:
            # a 1-D array whose length equals ``nchains`` (or a 0-d array) would otherwise
            # raise IndexError instead of reaching the fallback branch.
            if shape[:2] == (self.nchains, self.ndraws):
                data[k] = ary
            elif shape[:1] == (self.nchains * self.ndraws,):
                data[k] = ary.reshape((self.nchains, self.ndraws, *shape[1:]))
            else:
                data[k] = utils.expand_dims(ary)
                _log.warning(
                    "posterior predictive shape not compatible with number of chains and draws. "
                    "This can mean that some draws or even whole chains are not represented."
                )
        return dict_to_dataset(
            data,
            library=self.numpyro,
            coords=self.coords,
            dims=dims,
            index_origin=self.index_origin,
        )

    @requires("posterior_predictive")
    def posterior_predictive_to_xarray(self):
        """Convert posterior_predictive samples to xarray."""
        return self.translate_posterior_predictive_dict_to_xarray(
            self.posterior_predictive, self.dims
        )

    @requires("predictions")
    def predictions_to_xarray(self):
        """Convert predictions to xarray."""
        return self.translate_posterior_predictive_dict_to_xarray(self.predictions, self.pred_dims)

    def priors_to_xarray(self):
        """Convert prior samples (and if possible prior predictive too) to xarray.

        With a posterior, prior variables are those present in the posterior samples and
        every other prior key goes to ``prior_predictive``. Without a posterior, all prior
        keys go to ``prior`` and ``prior_predictive`` is ``None``.
        """
        if self.prior is None:
            return {"prior": None, "prior_predictive": None}
        if self.posterior is not None:
            prior_vars = list(self._samples.keys())
            prior_predictive_vars = [key for key in self.prior.keys() if key not in prior_vars]
        else:
            prior_vars = self.prior.keys()
            prior_predictive_vars = None
        priors_dict = {
            group: (
                None
                if var_names is None
                else dict_to_dataset(
                    {k: utils.expand_dims(self.prior[k]) for k in var_names},
                    library=self.numpyro,
                    coords=self.coords,
                    dims=self.dims,
                    index_origin=self.index_origin,
                )
            )
            for group, var_names in zip(
                ("prior", "prior_predictive"), (prior_vars, prior_predictive_vars)
            )
        }
        return priors_dict

    @requires("observations")
    @requires("model")
    def observed_data_to_xarray(self):
        """Convert observed data to xarray."""
        return dict_to_dataset(
            self.observations,
            library=self.numpyro,
            dims=self.dims,
            coords=self.coords,
            default_dims=[],
            index_origin=self.index_origin,
        )

    @requires("constant_data")
    def constant_data_to_xarray(self):
        """Convert constant_data to xarray."""
        return dict_to_dataset(
            self.constant_data,
            library=self.numpyro,
            dims=self.dims,
            coords=self.coords,
            default_dims=[],
            index_origin=self.index_origin,
        )

    @requires("predictions_constant_data")
    def predictions_constant_data_to_xarray(self):
        """Convert predictions_constant_data to xarray."""
        return dict_to_dataset(
            self.predictions_constant_data,
            library=self.numpyro,
            dims=self.pred_dims,
            coords=self.coords,
            default_dims=[],
            index_origin=self.index_origin,
        )

    def to_inference_data(self):
        """Convert all available data to an InferenceData object.

        Note that if groups can not be created (i.e., there is no `trace`, so
        the `posterior` and `sample_stats` can not be extracted), then the InferenceData
        will not have those groups.
        """
        return InferenceData(
            **{
                "posterior": self.posterior_to_xarray(),
                "sample_stats": self.sample_stats_to_xarray(),
                "log_likelihood": self.log_likelihood_to_xarray(),
                "posterior_predictive": self.posterior_predictive_to_xarray(),
                "predictions": self.predictions_to_xarray(),
                **self.priors_to_xarray(),
                "observed_data": self.observed_data_to_xarray(),
                "constant_data": self.constant_data_to_xarray(),
                "predictions_constant_data": self.predictions_constant_data_to_xarray(),
            }
        )

    @requires("posterior")
    @requires("model")
    def infer_dims(self) -> Dict[str, List[str]]:
        """Infer dims from the model's plates and event-dim metadata.

        Uses the module-level ``infer_dims`` with the model's stored args/kwargs, then
        appends ``extra_event_dims`` if given.
        """
        dims = infer_dims(self.model, self._args, self._kwargs)
        if self.extra_event_dims:
            dims = _add_dims(dims, self.extra_event_dims)
        return dims

    @requires("posterior")
    @requires("model")
    @requires("predictions")
    def infer_pred_dims(self) -> Dict[str, List[str]]:
        """Infer dims for the predictions group.

        Uses the same inference as the ``infer_dims`` method. This returns ``None``
        when there are no predictions.
        """
        return self.infer_dims()
421
-
422
-
423
def from_numpyro(
    posterior=None,
    *,
    prior=None,
    posterior_predictive=None,
    predictions=None,
    constant_data=None,
    predictions_constant_data=None,
    log_likelihood=None,
    index_origin=None,
    coords=None,
    dims=None,
    pred_dims=None,
    extra_event_dims=None,
    num_chains=1,
):
    """Convert NumPyro data into an InferenceData object.

    If no dims are provided, this will infer batch dim names from NumPyro model plates.
    For event dim names, such as with the ZeroSumNormal, `infer={"event_dims":dim_names}`
    can be provided in numpyro.sample, i.e.::

        # equivalent to dims entry, {"gamma": ["groups"]}
        gamma = numpyro.sample(
            "gamma",
            dist.ZeroSumNormal(1, event_shape=(n_groups,)),
            infer={"event_dims":["groups"]}
        )

    There is also an additional `extra_event_dims` input to cover any edge cases, for instance
    deterministic sites with event dims (which dont have an `infer` argument to provide metadata).

    For a usage example read the
    :ref:`Creating InferenceData section on from_numpyro <creating_InferenceData>`

    Parameters
    ----------
    posterior : numpyro.mcmc.MCMC
        Fitted MCMC object from NumPyro
    prior: dict
        Prior samples from a NumPyro model
    posterior_predictive : dict
        Posterior predictive samples for the posterior
    predictions: dict
        Out of sample predictions
    constant_data: dict
        Dictionary containing constant data variables mapped to their values.
    predictions_constant_data: dict
        Constant data used for out-of-sample predictions.
    log_likelihood : bool, optional
        Whether to build the log_likelihood group. Defaults to
        ``rcParams["data.log_likelihood"]`` when ``None``.
    index_origin : int, optional
        Defaults to ``rcParams["data.index_origin"]`` when ``None``.
    coords : dict[str] -> list[str]
        Map of dimensions to coordinates
    dims : dict[str] -> list[str]
        Map variable names to their coordinates. Will be inferred if they are not provided.
    pred_dims: dict
        Dims for predictions data. Map variable names to their coordinates. Default behavior is to
        infer dims if this is not provided
    extra_event_dims: dict
        Map of variable names to additional dim names, appended after the inferred dims.
        Use it for event dims that cannot be inferred, e.g. on deterministic sites.
    num_chains: int
        Number of chains used for sampling. Ignored if posterior is present.

    Returns
    -------
    InferenceData
    """
    return NumPyroConverter(
        posterior=posterior,
        prior=prior,
        posterior_predictive=posterior_predictive,
        predictions=predictions,
        constant_data=constant_data,
        predictions_constant_data=predictions_constant_data,
        log_likelihood=log_likelihood,
        index_origin=index_origin,
        coords=coords,
        dims=dims,
        pred_dims=pred_dims,
        extra_event_dims=extra_event_dims,
        num_chains=num_chains,
    ).to_inference_data()