arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- arviz/__init__.py +52 -367
- arviz-1.0.0rc0.dist-info/METADATA +182 -0
- arviz-1.0.0rc0.dist-info/RECORD +5 -0
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
- {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
- arviz/data/__init__.py +0 -55
- arviz/data/base.py +0 -596
- arviz/data/converters.py +0 -203
- arviz/data/datasets.py +0 -161
- arviz/data/example_data/code/radon/radon.json +0 -326
- arviz/data/example_data/data/centered_eight.nc +0 -0
- arviz/data/example_data/data/non_centered_eight.nc +0 -0
- arviz/data/example_data/data_local.json +0 -12
- arviz/data/example_data/data_remote.json +0 -58
- arviz/data/inference_data.py +0 -2386
- arviz/data/io_beanmachine.py +0 -112
- arviz/data/io_cmdstan.py +0 -1036
- arviz/data/io_cmdstanpy.py +0 -1233
- arviz/data/io_datatree.py +0 -23
- arviz/data/io_dict.py +0 -462
- arviz/data/io_emcee.py +0 -317
- arviz/data/io_json.py +0 -54
- arviz/data/io_netcdf.py +0 -68
- arviz/data/io_numpyro.py +0 -497
- arviz/data/io_pyjags.py +0 -378
- arviz/data/io_pyro.py +0 -333
- arviz/data/io_pystan.py +0 -1095
- arviz/data/io_zarr.py +0 -46
- arviz/data/utils.py +0 -139
- arviz/labels.py +0 -210
- arviz/plots/__init__.py +0 -61
- arviz/plots/autocorrplot.py +0 -171
- arviz/plots/backends/__init__.py +0 -223
- arviz/plots/backends/bokeh/__init__.py +0 -166
- arviz/plots/backends/bokeh/autocorrplot.py +0 -101
- arviz/plots/backends/bokeh/bfplot.py +0 -23
- arviz/plots/backends/bokeh/bpvplot.py +0 -193
- arviz/plots/backends/bokeh/compareplot.py +0 -167
- arviz/plots/backends/bokeh/densityplot.py +0 -239
- arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
- arviz/plots/backends/bokeh/distplot.py +0 -183
- arviz/plots/backends/bokeh/dotplot.py +0 -113
- arviz/plots/backends/bokeh/ecdfplot.py +0 -73
- arviz/plots/backends/bokeh/elpdplot.py +0 -203
- arviz/plots/backends/bokeh/energyplot.py +0 -155
- arviz/plots/backends/bokeh/essplot.py +0 -176
- arviz/plots/backends/bokeh/forestplot.py +0 -772
- arviz/plots/backends/bokeh/hdiplot.py +0 -54
- arviz/plots/backends/bokeh/kdeplot.py +0 -268
- arviz/plots/backends/bokeh/khatplot.py +0 -163
- arviz/plots/backends/bokeh/lmplot.py +0 -185
- arviz/plots/backends/bokeh/loopitplot.py +0 -211
- arviz/plots/backends/bokeh/mcseplot.py +0 -184
- arviz/plots/backends/bokeh/pairplot.py +0 -328
- arviz/plots/backends/bokeh/parallelplot.py +0 -81
- arviz/plots/backends/bokeh/posteriorplot.py +0 -324
- arviz/plots/backends/bokeh/ppcplot.py +0 -379
- arviz/plots/backends/bokeh/rankplot.py +0 -149
- arviz/plots/backends/bokeh/separationplot.py +0 -107
- arviz/plots/backends/bokeh/traceplot.py +0 -436
- arviz/plots/backends/bokeh/violinplot.py +0 -164
- arviz/plots/backends/matplotlib/__init__.py +0 -124
- arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
- arviz/plots/backends/matplotlib/bfplot.py +0 -78
- arviz/plots/backends/matplotlib/bpvplot.py +0 -177
- arviz/plots/backends/matplotlib/compareplot.py +0 -135
- arviz/plots/backends/matplotlib/densityplot.py +0 -194
- arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
- arviz/plots/backends/matplotlib/distplot.py +0 -178
- arviz/plots/backends/matplotlib/dotplot.py +0 -116
- arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
- arviz/plots/backends/matplotlib/elpdplot.py +0 -189
- arviz/plots/backends/matplotlib/energyplot.py +0 -113
- arviz/plots/backends/matplotlib/essplot.py +0 -180
- arviz/plots/backends/matplotlib/forestplot.py +0 -656
- arviz/plots/backends/matplotlib/hdiplot.py +0 -48
- arviz/plots/backends/matplotlib/kdeplot.py +0 -177
- arviz/plots/backends/matplotlib/khatplot.py +0 -241
- arviz/plots/backends/matplotlib/lmplot.py +0 -149
- arviz/plots/backends/matplotlib/loopitplot.py +0 -144
- arviz/plots/backends/matplotlib/mcseplot.py +0 -161
- arviz/plots/backends/matplotlib/pairplot.py +0 -355
- arviz/plots/backends/matplotlib/parallelplot.py +0 -58
- arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
- arviz/plots/backends/matplotlib/ppcplot.py +0 -478
- arviz/plots/backends/matplotlib/rankplot.py +0 -119
- arviz/plots/backends/matplotlib/separationplot.py +0 -97
- arviz/plots/backends/matplotlib/traceplot.py +0 -526
- arviz/plots/backends/matplotlib/tsplot.py +0 -121
- arviz/plots/backends/matplotlib/violinplot.py +0 -148
- arviz/plots/bfplot.py +0 -128
- arviz/plots/bpvplot.py +0 -308
- arviz/plots/compareplot.py +0 -177
- arviz/plots/densityplot.py +0 -284
- arviz/plots/distcomparisonplot.py +0 -197
- arviz/plots/distplot.py +0 -233
- arviz/plots/dotplot.py +0 -233
- arviz/plots/ecdfplot.py +0 -372
- arviz/plots/elpdplot.py +0 -174
- arviz/plots/energyplot.py +0 -147
- arviz/plots/essplot.py +0 -319
- arviz/plots/forestplot.py +0 -304
- arviz/plots/hdiplot.py +0 -211
- arviz/plots/kdeplot.py +0 -357
- arviz/plots/khatplot.py +0 -236
- arviz/plots/lmplot.py +0 -380
- arviz/plots/loopitplot.py +0 -224
- arviz/plots/mcseplot.py +0 -194
- arviz/plots/pairplot.py +0 -281
- arviz/plots/parallelplot.py +0 -204
- arviz/plots/plot_utils.py +0 -599
- arviz/plots/posteriorplot.py +0 -298
- arviz/plots/ppcplot.py +0 -369
- arviz/plots/rankplot.py +0 -232
- arviz/plots/separationplot.py +0 -167
- arviz/plots/styles/arviz-bluish.mplstyle +0 -1
- arviz/plots/styles/arviz-brownish.mplstyle +0 -1
- arviz/plots/styles/arviz-colors.mplstyle +0 -2
- arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
- arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
- arviz/plots/styles/arviz-doc.mplstyle +0 -88
- arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
- arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
- arviz/plots/styles/arviz-greenish.mplstyle +0 -1
- arviz/plots/styles/arviz-orangish.mplstyle +0 -1
- arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
- arviz/plots/styles/arviz-purplish.mplstyle +0 -1
- arviz/plots/styles/arviz-redish.mplstyle +0 -1
- arviz/plots/styles/arviz-royish.mplstyle +0 -1
- arviz/plots/styles/arviz-viridish.mplstyle +0 -1
- arviz/plots/styles/arviz-white.mplstyle +0 -40
- arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
- arviz/plots/traceplot.py +0 -273
- arviz/plots/tsplot.py +0 -440
- arviz/plots/violinplot.py +0 -192
- arviz/preview.py +0 -58
- arviz/py.typed +0 -0
- arviz/rcparams.py +0 -606
- arviz/sel_utils.py +0 -223
- arviz/static/css/style.css +0 -340
- arviz/static/html/icons-svg-inline.html +0 -15
- arviz/stats/__init__.py +0 -37
- arviz/stats/density_utils.py +0 -1013
- arviz/stats/diagnostics.py +0 -1013
- arviz/stats/ecdf_utils.py +0 -324
- arviz/stats/stats.py +0 -2422
- arviz/stats/stats_refitting.py +0 -119
- arviz/stats/stats_utils.py +0 -609
- arviz/tests/__init__.py +0 -1
- arviz/tests/base_tests/__init__.py +0 -1
- arviz/tests/base_tests/test_data.py +0 -1679
- arviz/tests/base_tests/test_data_zarr.py +0 -143
- arviz/tests/base_tests/test_diagnostics.py +0 -511
- arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
- arviz/tests/base_tests/test_helpers.py +0 -18
- arviz/tests/base_tests/test_labels.py +0 -69
- arviz/tests/base_tests/test_plot_utils.py +0 -342
- arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
- arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
- arviz/tests/base_tests/test_rcparams.py +0 -317
- arviz/tests/base_tests/test_stats.py +0 -925
- arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
- arviz/tests/base_tests/test_stats_numba.py +0 -45
- arviz/tests/base_tests/test_stats_utils.py +0 -384
- arviz/tests/base_tests/test_utils.py +0 -376
- arviz/tests/base_tests/test_utils_numba.py +0 -87
- arviz/tests/conftest.py +0 -46
- arviz/tests/external_tests/__init__.py +0 -1
- arviz/tests/external_tests/test_data_beanmachine.py +0 -78
- arviz/tests/external_tests/test_data_cmdstan.py +0 -398
- arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
- arviz/tests/external_tests/test_data_emcee.py +0 -166
- arviz/tests/external_tests/test_data_numpyro.py +0 -434
- arviz/tests/external_tests/test_data_pyjags.py +0 -119
- arviz/tests/external_tests/test_data_pyro.py +0 -260
- arviz/tests/external_tests/test_data_pystan.py +0 -307
- arviz/tests/helpers.py +0 -677
- arviz/utils.py +0 -773
- arviz/wrappers/__init__.py +0 -13
- arviz/wrappers/base.py +0 -236
- arviz/wrappers/wrap_pymc.py +0 -36
- arviz/wrappers/wrap_stan.py +0 -148
- arviz-0.23.3.dist-info/METADATA +0 -264
- arviz-0.23.3.dist-info/RECORD +0 -183
- arviz-0.23.3.dist-info/top_level.txt +0 -1
arviz/wrappers/__init__.py
DELETED
|
@@ -1,13 +0,0 @@
|
|
|
1
|
-
"""Sampling wrappers."""
|
|
2
|
-
|
|
3
|
-
from .base import SamplingWrapper
|
|
4
|
-
from .wrap_stan import PyStan2SamplingWrapper, PyStanSamplingWrapper, CmdStanPySamplingWrapper
|
|
5
|
-
from .wrap_pymc import PyMCSamplingWrapper
|
|
6
|
-
|
|
7
|
-
# Public API of ``arviz.wrappers``: the generic base class plus one wrapper per backend.
__all__ = [
    "CmdStanPySamplingWrapper", "PyMCSamplingWrapper",
    "PyStan2SamplingWrapper", "PyStanSamplingWrapper",
    "SamplingWrapper",
]
|
arviz/wrappers/base.py
DELETED
|
@@ -1,236 +0,0 @@
|
|
|
1
|
-
# pylint: disable=too-many-instance-attributes,too-many-arguments
|
|
2
|
-
"""Base class for sampling wrappers."""
|
|
3
|
-
from xarray import apply_ufunc
|
|
4
|
-
|
|
5
|
-
# from ..data import InferenceData
|
|
6
|
-
from ..stats import wrap_xarray_ufunc as _wrap_xarray_ufunc
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
class SamplingWrapper:
    """Class wrapping sampling routines for its usage via ArviZ.

    Using a common class, all inference backends can be supported in ArviZ. Hence, statistical
    functions requiring refitting like Leave Future Out or Simulation Based Calibration can be
    performed from ArviZ.

    For usage examples see user guide pages on :ref:`wrapper_guide`. See other
    SamplingWrapper classes at :ref:`wrappers api section <wrappers_api>`.

    Parameters
    ----------
    model
        The model object used for sampling.
    idata_orig : InferenceData, optional
        Original InferenceData object. It is stored as-is; no type check is performed.
    log_lik_fun : callable, optional
        For simple cases where the pointwise log likelihood is a Python function, this
        function will be used to calculate the log likelihood. Otherwise, the
        ``log_likelihood__i`` method must be overridden. Its call signature must be
        ``log_lik_fun(*args, **log_lik_kwargs)`` and it will be called using
        :func:`xarray:xarray.apply_ufunc` or :func:`wrap_xarray_ufunc` depending
        on the value of ``is_ufunc``.

        For more details on ``args`` or ``log_lik_kwargs`` see the notes and
        parameters ``posterior_vars`` and ``log_lik_kwargs``.
    is_ufunc : bool, default True
        If True, call ``log_lik_fun`` using :func:`xarray:xarray.apply_ufunc` otherwise
        use :func:`wrap_xarray_ufunc`.
    posterior_vars : list of str, optional
        List of variable names to unpack as ``args`` for ``log_lik_fun``. Each string in
        the list will be used to retrieve a DataArray from the Dataset in the posterior
        group and passed to ``log_lik_fun``. If None (default), no posterior variables
        are passed.
    sample_kwargs : dict, optional
        Sampling kwargs are stored as class attributes for their usage in the ``sample``
        method.
    idata_kwargs : dict, optional
        kwargs are stored as class attributes to be used in the ``get_inference_data`` method.
    log_lik_kwargs : dict, optional
        Keyword arguments passed to ``log_lik_fun``.
    apply_ufunc_kwargs : dict, optional
        Passed to :func:`xarray:xarray.apply_ufunc` or :func:`wrap_xarray_ufunc`.

    Raises
    ------
    TypeError
        If ``log_lik_fun`` is neither None nor callable.

    Warnings
    --------
    Sampling wrappers are an experimental feature in a very early stage. Please use them
    with caution.

    Notes
    -----
    When ``log_lik_fun`` is given, the default ``log_likelihood__i`` evaluates it as::

        ufunc_applier(
            log_lik_fun,
            *excluded_obs,
            *[idata__i.posterior[var_name] for var_name in posterior_vars],
            kwargs=log_lik_kwargs,
            **apply_ufunc_kwargs,
        )

    where ``ufunc_applier`` is :func:`xarray:xarray.apply_ufunc` when ``is_ufunc`` is True
    and :func:`wrap_xarray_ufunc` otherwise. ``excluded_obs`` (the second item returned by
    ``sel_observations``) must therefore be an iterable. Its elements come first in the
    positional arguments, followed by the requested posterior variables.
    """

    def __init__(
        self,
        model,
        idata_orig=None,
        log_lik_fun=None,
        is_ufunc=True,
        posterior_vars=None,
        sample_kwargs=None,
        idata_kwargs=None,
        log_lik_kwargs=None,
        apply_ufunc_kwargs=None,
    ):
        self.model = model
        self.idata_orig = idata_orig

        if log_lik_fun is not None and not callable(log_lik_fun):
            raise TypeError("log_like_fun must be a callable object or None")
        self.log_lik_fun = log_lik_fun
        self.is_ufunc = is_ufunc
        self.posterior_vars = posterior_vars

        # Normalize optional kwargs containers so methods can always use ``**`` on them.
        self.sample_kwargs = {} if sample_kwargs is None else sample_kwargs
        self.idata_kwargs = {} if idata_kwargs is None else idata_kwargs
        self.log_lik_kwargs = {} if log_lik_kwargs is None else log_lik_kwargs
        self.apply_ufunc_kwargs = {} if apply_ufunc_kwargs is None else apply_ufunc_kwargs

    def sel_observations(self, idx):
        """Select a subset of the observations in idata_orig.

        **Not implemented**: This method must be implemented by the SamplingWrapper subclasses.
        It is documented here to show its format and call signature.

        Parameters
        ----------
        idx
            Indexes to separate from the rest of the observed data.

        Returns
        -------
        modified_observed_data
            Observed data whose index is *not* ``idx``
        excluded_observed_data
            Observed data whose index is ``idx``

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError("sel_observations method must be implemented for each subclass")

    def sample(self, modified_observed_data):
        """Sample ``self.model`` on the ``modified_observed_data`` subset.

        **Not implemented**: This method must be implemented by the SamplingWrapper subclasses.
        It is documented here to show its format and call signature.

        Parameters
        ----------
        modified_observed_data
            Data to fit the model on.

        Returns
        -------
        fitted_model
            Result of the fit.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError("sample method must be implemented for each subclass")

    def get_inference_data(self, fitted_model):
        """Convert the ``fitted_model`` to an InferenceData object.

        **Not implemented**: This method must be implemented by the SamplingWrapper subclasses.
        It is documented here to show its format and call signature.

        Parameters
        ----------
        fitted_model
            Result of the current fit.

        Returns
        -------
        idata_current: InferenceData
            InferenceData object containing the samples in ``fitted_model``

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError("get_inference_data method must be implemented for each subclass")

    def log_likelihood__i(self, excluded_obs, idata__i):
        r"""Get the log likelihood samples :math:`\log p_{post(-i)}(y_i)`.

        Calculate the log likelihood of the data contained in excluded_obs using the
        model fitted with this data excluded, the results of which are stored in ``idata__i``.

        Parameters
        ----------
        excluded_obs
            Observations for which to calculate their log likelihood. The second item from
            the tuple returned by `sel_observations` is passed as this argument. It is
            unpacked with ``*``, so it must be an iterable whose elements are passed as the
            leading positional arguments of ``log_lik_fun``.
        idata__i: InferenceData
            Inference results of refitting the data excluding some observations. The
            result of `get_inference_data` is used as this argument.

        Returns
        -------
        log_likelihood: xarray.DataArray
            Log likelihood of ``excluded_obs`` evaluated at each of the posterior samples
            stored in ``idata__i``.

        Raises
        ------
        NotImplementedError
            If no ``log_lik_fun`` was given at initialization and the subclass did not
            override this method.
        """
        if self.log_lik_fun is None:
            raise NotImplementedError(
                "When `log_like_fun` is not set during class initialization "
                "log_likelihood__i method must be overwritten"
            )
        posterior = idata__i.posterior
        # ``posterior_vars`` defaults to None: treat it as "no posterior variables" instead
        # of failing with "'NoneType' object is not iterable".
        posterior_vars = () if self.posterior_vars is None else self.posterior_vars
        arys = (*excluded_obs, *[posterior[var_name] for var_name in posterior_vars])
        ufunc_applier = apply_ufunc if self.is_ufunc else _wrap_xarray_ufunc
        log_lik_idx = ufunc_applier(
            self.log_lik_fun,
            *arys,
            kwargs=self.log_lik_kwargs,
            **self.apply_ufunc_kwargs,
        )
        return log_lik_idx

    def _check_method_is_implemented(self, method, *args):
        """Probe whether ``method`` is implemented by calling it with placeholder ``args``.

        Returns False only when the call raises NotImplementedError. Any other ordinary
        exception means the call got past the "not implemented" stub, so the method counts
        as implemented. The placeholder arguments are expected to break real
        implementations.
        ``BaseException`` subclasses such as KeyboardInterrupt and SystemExit are not caught.
        """
        try:
            getattr(self, method)(*args)
        except NotImplementedError:
            return False
        except Exception:  # pylint: disable=broad-except
            # Failure on the placeholder arguments still proves an implementation exists.
            return True
        return True

    def check_implemented_methods(self, methods):
        """Check that all methods listed are implemented.

        Not all functions that require refitting need to have all the methods implemented in
        order to work properly. This function should be used before using the SamplingWrapper and
        its subclasses to get informative error messages.

        Each method is probed by actually calling it with placeholder arguments (``1`` for
        every positional argument). Implemented methods are therefore executed, for example
        ``sample(1)``.

        Parameters
        ----------
        methods: list
            Check all elements in methods are implemented.

        Returns
        -------
        list
            List with all non implemented methods, in the order given.

        Raises
        ------
        ValueError
            If any element of ``methods`` is not a supported method name.
        """
        # Materialize once: ``methods`` is traversed twice below.
        methods = list(methods)
        supported_methods_1arg = (
            "sel_observations",
            "sample",
            "get_inference_data",
        )
        supported_methods_2args = ("log_likelihood__i",)
        supported_methods = [*supported_methods_1arg, *supported_methods_2args]
        bad_methods = [method for method in methods if method not in supported_methods]
        if bad_methods:
            raise ValueError(
                f"Not all method(s) in {bad_methods} supported. "
                f"Supported methods in SamplingWrapper subclasses are:{supported_methods}"
            )

        not_implemented = []
        for method in methods:
            n_args = 1 if method in supported_methods_1arg else 2
            if not self._check_method_is_implemented(method, *([1] * n_args)):
                not_implemented.append(method)
        return not_implemented
|
arviz/wrappers/wrap_pymc.py
DELETED
|
@@ -1,36 +0,0 @@
|
|
|
1
|
-
# pylint: disable=arguments-differ
|
|
2
|
-
"""Base class for PyMC interface wrappers."""
|
|
3
|
-
from .base import SamplingWrapper
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
# pylint: disable=abstract-method
class PyMCSamplingWrapper(SamplingWrapper):
    """Sampling wrapper base class for PyMC 4.0 and newer.

    The generic wrapper contract is described in :class:`~arviz.SamplingWrapper`. The
    :ref:`pymc_refitting` notebook walks through a concrete ``PyMCSamplingWrapper``.

    Warnings
    --------
    Sampling wrappers are an experimental feature in a very early stage. Please use them
    with caution.
    """

    def sample(self, modified_observed_data):
        """Load ``modified_observed_data`` into ``self.model`` and sample it again.

        Inside the ``self.model`` context, ``pymc.set_data`` swaps in the new data. Then
        ``pymc.sample`` is called with ``self.sample_kwargs``, and its result is returned
        unchanged.
        """
        import pymc as pm  # pylint: disable=import-error

        with self.model:
            pm.set_data(modified_observed_data)
            return pm.sample(**self.sample_kwargs)

    def get_inference_data(self, fitted_model):
        """Return ``fitted_model`` as-is.

        The result of ``pymc.sample`` is used directly as the InferenceData. PyMC's sampler
        already produces InferenceData, so no conversion is performed.
        """
        return fitted_model
|
arviz/wrappers/wrap_stan.py
DELETED
|
@@ -1,148 +0,0 @@
|
|
|
1
|
-
# pylint: disable=arguments-differ
|
|
2
|
-
"""Base class for Stan interface wrappers."""
|
|
3
|
-
from typing import Union
|
|
4
|
-
|
|
5
|
-
from ..data import from_cmdstanpy, from_pystan
|
|
6
|
-
from .base import SamplingWrapper
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
# pylint: disable=abstract-method
class StanSamplingWrapper(SamplingWrapper):
    """Common base class for the Stan-backed sampling wrappers.

    The full wrapper contract is described in :class:`~arviz.SamplingWrapper`. The
    :ref:`pystan_refitting` notebook demonstrates ``PyStanSamplingWrapper``. Usage of the
    remaining wrappers is covered by the :ref:`wrapper_guide` user guide pages.

    Warnings
    --------
    Sampling wrappers are an experimental feature in a very early stage. Please use them
    with caution.

    See Also
    --------
    SamplingWrapper
    """

    def sel_observations(self, idx):
        """Split the observed data of ``idata_orig`` around ``idx``.

        **Not implemented**: every model needs its own version of this method. The stub only
        documents the expected signature and return values.

        Parameters
        ----------
        idx
            Indexes to separate from the rest of the observed data.

        Returns
        -------
        modified_observed_data : dict
            Data dictionary holding both the kept and the excluded observations, each stored
            under its own keys. It is what gets handed to ``self.sample``.
        excluded_observed_data : str
            Name of the variable that holds the pointwise log likelihood of the excluded
            observations. Stan evaluates that log likelihood *during* sampling, and PyStan
            cannot call C++ functions afterwards. ``log_likelihood__i`` therefore receives
            this name rather than the data values, and looks the result up in the
            InferenceData object.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError("sel_observations must be implemented on a model basis")

    def get_inference_data(self, fitted_model):  # pylint: disable=arguments-renamed
        """Turn the fit returned by ``self.sample`` into an InferenceData object.

        Fits whose class is named ``CmdStanMCMC`` are converted with ``from_cmdstanpy``.
        Every other fit is passed to ``from_pystan``. In both cases ``self.idata_kwargs`` is
        forwarded to the converter.
        """
        is_cmdstanpy_fit = fitted_model.__class__.__name__ == "CmdStanMCMC"
        converter = from_cmdstanpy if is_cmdstanpy_fit else from_pystan
        return converter(posterior=fitted_model, **self.idata_kwargs)

    def log_likelihood__i(self, excluded_obs, idata__i):
        """Look up the excluded observations' log likelihood in ``idata__i``.

        ``excluded_obs`` is the variable name produced by ``sel_observations``. The selected
        variable is returned from the ``log_likelihood`` group of ``idata__i``.
        """
        log_lik_group = idata__i.log_likelihood
        return log_lik_group[excluded_obs]
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
class PyStan2SamplingWrapper(StanSamplingWrapper):
    """Sampling wrapper base class for PyStan 2.x.

    The full wrapper contract is described in :class:`~arviz.SamplingWrapper`. The
    :ref:`pystan_refitting` notebook demonstrates ``PyStanSamplingWrapper``. Usage of the
    remaining wrappers is covered by the :ref:`wrapper_guide` user guide pages.

    Warnings
    --------
    Sampling wrappers are an experimental feature in a very early stage. Please use them
    with caution.

    See Also
    --------
    SamplingWrapper
    """

    def sample(self, modified_observed_data):
        """Refit the PyStan 2 model in ``self.model`` on ``modified_observed_data``.

        Calls ``self.model.sampling`` with ``data=modified_observed_data`` plus
        ``self.sample_kwargs``, and returns the resulting fit object.
        """
        return self.model.sampling(data=modified_observed_data, **self.sample_kwargs)
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
class PyStanSamplingWrapper(StanSamplingWrapper):
    """Sampling wrapper base class for PyStan 3.0 and newer.

    The full wrapper contract is described in :class:`~arviz.SamplingWrapper`. The
    :ref:`pystan_refitting` notebook demonstrates ``PyStanSamplingWrapper``.

    Warnings
    --------
    Sampling wrappers are an experimental feature in a very early stage. Please use them
    with caution.
    """

    def sample(self, modified_observed_data):
        """Rebuild the Stan program on ``modified_observed_data`` and sample it.

        ``self.model`` may be the Stan program source (a ``str``) or a built model that has a
        ``program_code`` attribute. After the call, ``self.model`` holds the freshly built
        model. That model's ``sample`` is called with ``self.sample_kwargs``, and the fit is
        returned.
        """
        import stan  # pylint: disable=import-error,import-outside-toplevel

        self.model: Union[str, stan.Model]
        program_code = self.model if isinstance(self.model, str) else self.model.program_code
        # Data is passed to stan.build here, so every refit rebuilds the model.
        self.model = stan.build(program_code, data=modified_observed_data)
        return self.model.sample(**self.sample_kwargs)
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
class CmdStanPySamplingWrapper(StanSamplingWrapper):
    """Sampling wrapper base class for CmdStanPy.

    The full wrapper contract is described in :class:`~arviz.SamplingWrapper`. The
    :ref:`cmdstanpy_refitting` notebook walks through a concrete
    ``CmdStanPySamplingWrapper``.

    Warnings
    --------
    Sampling wrappers are an experimental feature in a very early stage. Please use them
    with caution.
    """

    def __init__(self, data_file, **kwargs):
        """Set up the CmdStanPy wrapper.

        Parameters
        ----------
        data_file : str
            Path that ``write_stan_json`` writes the data to before each refit. Its
            contents are overwritten on every call to ``sample``.
        **kwargs
            Forwarded unchanged to the parent wrapper's ``__init__``. Because they are
            keyword-only here, ``model`` must be given as a keyword.
        """
        super().__init__(**kwargs)
        self.data_file = data_file

    def sample(self, modified_observed_data):
        """Write ``modified_observed_data`` to ``self.data_file`` and refit the model.

        ``self.model.sample`` receives ``self.sample_kwargs`` with ``data`` set to
        ``self.data_file``, overriding any ``data`` entry already present. The fit object
        is returned.
        """
        from cmdstanpy import write_stan_json

        write_stan_json(self.data_file, modified_observed_data)
        sampling_kwargs = dict(self.sample_kwargs, data=self.data_file)
        return self.model.sample(**sampling_kwargs)
|