arviz 0.23.3__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. arviz/__init__.py +52 -367
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.3.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.3.dist-info/METADATA +0 -264
  184. arviz-0.23.3.dist-info/RECORD +0 -183
  185. arviz-0.23.3.dist-info/top_level.txt +0 -1
arviz/wrappers/__init__.py DELETED
@@ -1,13 +0,0 @@
1
- """Sampling wrappers."""
2
-
3
- from .base import SamplingWrapper
4
- from .wrap_stan import PyStan2SamplingWrapper, PyStanSamplingWrapper, CmdStanPySamplingWrapper
5
- from .wrap_pymc import PyMCSamplingWrapper
6
-
7
# Public names re-exported by the ``wrappers`` subpackage (alphabetical order).
__all__ = [
    "CmdStanPySamplingWrapper", "PyMCSamplingWrapper", "PyStan2SamplingWrapper",
    "PyStanSamplingWrapper", "SamplingWrapper",
]
arviz/wrappers/base.py DELETED
@@ -1,236 +0,0 @@
1
- # pylint: disable=too-many-instance-attributes,too-many-arguments
2
- """Base class for sampling wrappers."""
3
- from xarray import apply_ufunc
4
-
5
- # from ..data import InferenceData
6
- from ..stats import wrap_xarray_ufunc as _wrap_xarray_ufunc
7
-
8
-
9
class SamplingWrapper:
    """Class wrapping sampling routines for their usage via ArviZ.

    Using a common class, all inference backends can be supported in ArviZ. Hence, statistical
    functions requiring refitting like Leave Future Out or Simulation Based Calibration can be
    performed from ArviZ.

    For usage examples see user guide pages on :ref:`wrapper_guide`. See other
    SamplingWrapper classes at :ref:`wrappers api section <wrappers_api>`.

    Parameters
    ----------
    model
        The model object used for sampling.
    idata_orig : InferenceData, optional
        Original InferenceData object. Stored as-is; its type is not validated.
    log_lik_fun : callable, optional
        For simple cases where the pointwise log likelihood is a Python function, this
        function will be used to calculate the log likelihood. Otherwise, the
        ``log_likelihood__i`` method must be overridden. Its call signature must be
        ``log_lik_fun(*args, **log_lik_kwargs)`` and it will be called using
        :func:`xarray:xarray.apply_ufunc` or :func:`wrap_xarray_ufunc` depending
        on the value of `is_ufunc`.

        For more details on ``args`` or ``log_lik_kwargs`` see the notes and
        parameters ``posterior_vars`` and ``log_lik_kwargs``.
    is_ufunc : bool, default True
        If True, call ``log_lik_fun`` using :func:`xarray:xarray.apply_ufunc` otherwise
        use :func:`wrap_xarray_ufunc`.
    posterior_vars : list of str, optional
        List of variable names to unpack as ``args`` for ``log_lik_fun``. Each string in
        the list will be used to retrieve a DataArray from the Dataset in the posterior
        group and passed to ``log_lik_fun``. If None, no posterior variables are passed.
    sample_kwargs : dict, optional
        Sampling kwargs are stored as class attributes for their usage in the ``sample``
        method.
    idata_kwargs : dict, optional
        kwargs are stored as class attributes to be used in the ``get_inference_data`` method.
    log_lik_kwargs : dict, optional
        Keyword arguments passed to ``log_lik_fun``.
    apply_ufunc_kwargs : dict, optional
        Passed to :func:`xarray:xarray.apply_ufunc` or :func:`wrap_xarray_ufunc`.

    Raises
    ------
    TypeError
        If ``log_lik_fun`` is neither None nor callable.

    Warnings
    --------
    Sampling wrappers are an experimental feature in a very early stage. Please use them
    with caution.

    Notes
    -----
    When ``log_lik_fun`` is set, ``log_likelihood__i`` calls it through the ufunc applier
    selected by ``is_ufunc``. The applier receives the following positional arguments, in
    this order: the items of ``excluded_obs``, then the posterior DataArrays named in
    ``posterior_vars``. It is also given ``kwargs=log_lik_kwargs`` plus
    ``**apply_ufunc_kwargs``.
    """

    def __init__(
        self,
        model,
        idata_orig=None,
        log_lik_fun=None,
        is_ufunc=True,
        posterior_vars=None,
        sample_kwargs=None,
        idata_kwargs=None,
        log_lik_kwargs=None,
        apply_ufunc_kwargs=None,
    ):
        self.model = model
        self.idata_orig = idata_orig

        if log_lik_fun is not None and not callable(log_lik_fun):
            raise TypeError("log_lik_fun must be a callable object or None")
        self.log_lik_fun = log_lik_fun
        self.is_ufunc = is_ufunc
        self.posterior_vars = posterior_vars

        # None means "no extra options"; the caller's dicts are stored by reference.
        self.sample_kwargs = {} if sample_kwargs is None else sample_kwargs
        self.idata_kwargs = {} if idata_kwargs is None else idata_kwargs
        self.log_lik_kwargs = {} if log_lik_kwargs is None else log_lik_kwargs
        self.apply_ufunc_kwargs = {} if apply_ufunc_kwargs is None else apply_ufunc_kwargs

    def sel_observations(self, idx):
        """Select a subset of the observations in idata_orig.

        **Not implemented**: This method must be implemented by the SamplingWrapper subclasses.
        It is documented here to show its format and call signature.

        Parameters
        ----------
        idx
            Indexes to separate from the rest of the observed data.

        Returns
        -------
        modified_observed_data
            Observed data whose index is *not* ``idx``
        excluded_observed_data
            Observed data whose index is ``idx``
        """
        raise NotImplementedError("sel_observations method must be implemented for each subclass")

    def sample(self, modified_observed_data):
        """Sample ``self.model`` on the ``modified_observed_data`` subset.

        **Not implemented**: This method must be implemented by the SamplingWrapper subclasses.
        It is documented here to show its format and call signature.

        Parameters
        ----------
        modified_observed_data
            Data to fit the model on.

        Returns
        -------
        fitted_model
            Result of the fit.
        """
        raise NotImplementedError("sample method must be implemented for each subclass")

    def get_inference_data(self, fitted_model):
        """Convert the ``fitted_model`` to an InferenceData object.

        **Not implemented**: This method must be implemented by the SamplingWrapper subclasses.
        It is documented here to show its format and call signature.

        Parameters
        ----------
        fitted_model
            Result of the current fit.

        Returns
        -------
        idata_current: InferenceData
            InferenceData object containing the samples in ``fitted_model``
        """
        raise NotImplementedError("get_inference_data method must be implemented for each subclass")

    def log_likelihood__i(self, excluded_obs, idata__i):
        r"""Get the log likelihood samples :math:`\log p_{post(-i)}(y_i)`.

        Calculate the log likelihood of the data contained in excluded_obs using the
        model fitted with this data excluded, the results of which are stored in ``idata__i``.

        Parameters
        ----------
        excluded_obs
            Observations for which to calculate their log likelihood. The second item from
            the tuple returned by `sel_observations` is passed as this argument. Its items
            are unpacked as the leading positional arguments of ``log_lik_fun``.
        idata__i: InferenceData
            Inference results of refitting the data excluding some observations. The
            result of `get_inference_data` is used as this argument.

        Returns
        -------
        log_likelihood: xr.Dataarray
            Log likelihood of ``excluded_obs`` evaluated at each of the posterior samples
            stored in ``idata__i``.

        Raises
        ------
        NotImplementedError
            If no ``log_lik_fun`` was given at initialization and this method was not
            overridden.
        """
        if self.log_lik_fun is None:
            raise NotImplementedError(
                "When `log_lik_fun` is not set during class initialization "
                "log_likelihood__i method must be overwritten"
            )
        posterior = idata__i.posterior
        # posterior_vars is optional; None means no posterior variables are forwarded
        # (iterating None directly would raise an unhelpful TypeError).
        posterior_vars = () if self.posterior_vars is None else self.posterior_vars
        arys = (*excluded_obs, *[posterior[var_name] for var_name in posterior_vars])
        ufunc_applier = apply_ufunc if self.is_ufunc else _wrap_xarray_ufunc
        log_lik_idx = ufunc_applier(
            self.log_lik_fun,
            *arys,
            kwargs=self.log_lik_kwargs,
            **self.apply_ufunc_kwargs,
        )
        return log_lik_idx

    def _check_method_is_implemented(self, method, *args):
        """Check whether ``method`` is implemented by calling it with dummy ``args``.

        Returns False only when the call raises NotImplementedError. Any other ``Exception``
        (e.g. a real implementation rejecting the dummy arguments) counts as implemented.
        Interpreter-level exceptions such as KeyboardInterrupt or SystemExit propagate.

        NOTE: the method really is invoked, so any side effects it has will happen.
        """
        try:
            getattr(self, method)(*args)
        except NotImplementedError:
            return False
        except Exception:  # pylint: disable=broad-except
            # A real implementation failing on the dummy arguments is still implemented.
            return True
        return True

    def check_implemented_methods(self, methods):
        """Check that all methods listed are implemented.

        Not all functions that require refitting need to have all the methods implemented in
        order to work properly. This function should be used before using the SamplingWrapper and
        its subclasses to get informative error messages.

        Parameters
        ----------
        methods: list
            Check all elements in methods are implemented.

        Returns
        -------
        List with all non implemented methods, in the order given in ``methods``.

        Raises
        ------
        ValueError
            If ``methods`` contains a name that is not a supported wrapper method.
        """
        supported_methods_1arg = (
            "sel_observations",
            "sample",
            "get_inference_data",
        )
        supported_methods_2args = ("log_likelihood__i",)
        supported_methods = [*supported_methods_1arg, *supported_methods_2args]
        bad_methods = [method for method in methods if method not in supported_methods]
        if bad_methods:
            raise ValueError(
                f"Not all method(s) in {bad_methods} supported. "
                f"Supported methods in SamplingWrapper subclasses are:{supported_methods}"
            )

        not_implemented = []
        for method in methods:
            # Every name is validated above, so anything not taking one argument takes two.
            dummy_args = (1,) if method in supported_methods_1arg else (1, 1)
            if not self._check_method_is_implemented(method, *dummy_args):
                not_implemented.append(method)
        return not_implemented
arviz/wrappers/wrap_pymc.py DELETED
@@ -1,36 +0,0 @@
1
- # pylint: disable=arguments-differ
2
- """Base class for PyMC interface wrappers."""
3
- from .base import SamplingWrapper
4
-
5
-
6
# pylint: disable=abstract-method
class PyMCSamplingWrapper(SamplingWrapper):
    """Sampling wrapper for PyMC models (PyMC 4.0 and later).

    Refer to :class:`~arviz.SamplingWrapper` for the full description of the wrapper
    interface. The :ref:`pymc_refitting` notebook shows ``PyMCSamplingWrapper`` in use.

    Warnings
    --------
    Sampling wrappers are an experimental feature in a very early stage. Please use them
    with caution.
    """

    def sample(self, modified_observed_data):
        """Swap in ``modified_observed_data`` and draw new samples from ``self.model``.

        Both ``pymc.set_data`` and ``pymc.sample`` run inside the ``self.model`` context.
        ``self.sample_kwargs`` is forwarded to ``pymc.sample``, whose result is returned.
        """
        import pymc  # pylint: disable=import-error,import-outside-toplevel

        model = self.model
        with model:
            pymc.set_data(modified_observed_data)
            result = pymc.sample(**self.sample_kwargs)
        return result

    def get_inference_data(self, fitted_model):
        """Hand ``fitted_model`` back untouched.

        PyMC sampling already returns an InferenceData object, so no conversion is needed.
        """
        return fitted_model
arviz/wrappers/wrap_stan.py DELETED
@@ -1,148 +0,0 @@
1
- # pylint: disable=arguments-differ
2
- """Base class for Stan interface wrappers."""
3
- from typing import Union
4
-
5
- from ..data import from_cmdstanpy, from_pystan
6
- from .base import SamplingWrapper
7
-
8
-
9
# pylint: disable=abstract-method
class StanSamplingWrapper(SamplingWrapper):
    """Common base class for the Stan sampling wrappers.

    Refer to :class:`~arviz.SamplingWrapper` for the full description of the wrapper
    interface. The :ref:`pystan_refitting` notebook shows ``PyStanSamplingWrapper`` in
    use; the user guide pages on :ref:`wrapper_guide` cover the other wrappers.

    Warnings
    --------
    Sampling wrappers are an experimental feature in a very early stage. Please use them
    with caution.

    See Also
    --------
    SamplingWrapper
    """

    def sel_observations(self, idx):
        """Split the observations in idata_orig into kept and excluded parts.

        **Not implemented**: every model needs its own implementation. This stub only
        documents the expected signature and return values.

        Parameters
        ----------
        idx
            Indexes to separate from the rest of the observed data.

        Returns
        -------
        modified_observed_data : dict
            Data dictionary with the included and excluded observations placed under their
            respective keys. It is handed to ``sample`` as ``modified_observed_data``.
        excluded_observed_data : str
            Name of the variable holding the pointwise log likelihood of the excluded
            data. Stan computes the log likelihood *during* sampling and its C++ functions
            cannot be called from Python. So instead of the values at which to evaluate the
            likelihood, ``log_likelihood__i`` expects this name and uses it to pick the
            matching data out of the InferenceData object.
        """
        raise NotImplementedError("sel_observations must be implemented on a model basis")

    def get_inference_data(self, fitted_model):  # pylint: disable=arguments-renamed
        """Turn the fit object produced by ``self.sample`` into InferenceData.

        A fit whose class is named ``CmdStanMCMC`` is converted with ``from_cmdstanpy``.
        Any other fit goes to ``from_pystan``. ``self.idata_kwargs`` is forwarded either way.
        """
        is_cmdstanpy_fit = fitted_model.__class__.__name__ == "CmdStanMCMC"
        converter = from_cmdstanpy if is_cmdstanpy_fit else from_pystan
        return converter(posterior=fitted_model, **self.idata_kwargs)

    def log_likelihood__i(self, excluded_obs, idata__i):
        """Look up the excluded observations' log likelihood inside ``idata__i``.

        ``excluded_obs`` is used as the key into the ``log_likelihood`` group.
        """
        log_lik_group = idata__i.log_likelihood
        return log_lik_group[excluded_obs]
-
65
-
66
class PyStan2SamplingWrapper(StanSamplingWrapper):
    """Sampling wrapper for PyStan 2.x models.

    Refer to :class:`~arviz.SamplingWrapper` for the full description of the wrapper
    interface. The :ref:`pystan_refitting` notebook shows ``PyStanSamplingWrapper`` in
    use; the user guide pages on :ref:`wrapper_guide` cover the other wrappers.

    Warnings
    --------
    Sampling wrappers are an experimental feature in a very early stage. Please use them
    with caution.

    See Also
    --------
    SamplingWrapper
    """

    def sample(self, modified_observed_data):
        """Run ``self.model.sampling`` on ``modified_observed_data``.

        ``self.sample_kwargs`` is forwarded after ``data``, and the fit object is returned.
        """
        return self.model.sampling(data=modified_observed_data, **self.sample_kwargs)
88
-
89
-
90
class PyStanSamplingWrapper(StanSamplingWrapper):
    """Sampling wrapper for PyStan 3.0+ models.

    Refer to :class:`~arviz.SamplingWrapper` for the full description of the wrapper
    interface. The :ref:`pystan_refitting` notebook shows ``PyStanSamplingWrapper`` in use.

    Warnings
    --------
    Sampling wrappers are an experimental feature in a very early stage. Please use them
    with caution.
    """

    def sample(self, modified_observed_data):
        """Rebuild the Stan program with ``modified_observed_data`` and sample it.

        ``self.model`` may be the Stan program source (str) or a previously built model.
        In the latter case its ``program_code`` is reused. ``self.model`` is then replaced
        by the newly built model, which is sampled with ``self.sample_kwargs``.
        """
        import stan  # pylint: disable=import-error,import-outside-toplevel

        self.model: Union[str, stan.Model]
        source = self.model if isinstance(self.model, str) else self.model.program_code
        self.model = stan.build(source, data=modified_observed_data)
        return self.model.sample(**self.sample_kwargs)
115
-
116
-
117
class CmdStanPySamplingWrapper(StanSamplingWrapper):
    """Sampling wrapper for CmdStanPy models.

    Refer to :class:`~arviz.SamplingWrapper` for the full description of the wrapper
    interface. The :ref:`cmdstanpy_refitting` notebook shows ``CmdStanPySamplingWrapper``
    in use.

    Warnings
    --------
    Sampling wrappers are an experimental feature in a very early stage. Please use them
    with caution.
    """

    def __init__(self, data_file, **kwargs):
        """Set up the wrapper and remember where refit data is written.

        Parameters
        ----------
        data_file : str
            Path of the JSON file the data is written to before every refit.
            Its contents will be overwritten.
        **kwargs
            Forwarded unchanged to ``SamplingWrapper.__init__``.
        """
        super().__init__(**kwargs)
        self.data_file = data_file

    def sample(self, modified_observed_data):
        """Write ``modified_observed_data`` to ``self.data_file`` and resample the model.

        The call uses ``self.sample_kwargs``, with ``data`` always pointing at
        ``self.data_file``.
        """
        from cmdstanpy import write_stan_json  # pylint: disable=import-error,import-outside-toplevel

        write_stan_json(self.data_file, modified_observed_data)
        call_kwargs = dict(self.sample_kwargs)
        # The freshly written file always takes precedence over any user-supplied "data".
        call_kwargs["data"] = self.data_file
        return self.model.sample(**call_kwargs)