arviz 0.23.1__py3-none-any.whl → 1.0.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. arviz/__init__.py +52 -357
  2. arviz-1.0.0rc0.dist-info/METADATA +182 -0
  3. arviz-1.0.0rc0.dist-info/RECORD +5 -0
  4. {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/WHEEL +1 -2
  5. {arviz-0.23.1.dist-info → arviz-1.0.0rc0.dist-info}/licenses/LICENSE +0 -1
  6. arviz/data/__init__.py +0 -55
  7. arviz/data/base.py +0 -596
  8. arviz/data/converters.py +0 -203
  9. arviz/data/datasets.py +0 -161
  10. arviz/data/example_data/code/radon/radon.json +0 -326
  11. arviz/data/example_data/data/centered_eight.nc +0 -0
  12. arviz/data/example_data/data/non_centered_eight.nc +0 -0
  13. arviz/data/example_data/data_local.json +0 -12
  14. arviz/data/example_data/data_remote.json +0 -58
  15. arviz/data/inference_data.py +0 -2386
  16. arviz/data/io_beanmachine.py +0 -112
  17. arviz/data/io_cmdstan.py +0 -1036
  18. arviz/data/io_cmdstanpy.py +0 -1233
  19. arviz/data/io_datatree.py +0 -23
  20. arviz/data/io_dict.py +0 -462
  21. arviz/data/io_emcee.py +0 -317
  22. arviz/data/io_json.py +0 -54
  23. arviz/data/io_netcdf.py +0 -68
  24. arviz/data/io_numpyro.py +0 -497
  25. arviz/data/io_pyjags.py +0 -378
  26. arviz/data/io_pyro.py +0 -333
  27. arviz/data/io_pystan.py +0 -1095
  28. arviz/data/io_zarr.py +0 -46
  29. arviz/data/utils.py +0 -139
  30. arviz/labels.py +0 -210
  31. arviz/plots/__init__.py +0 -61
  32. arviz/plots/autocorrplot.py +0 -171
  33. arviz/plots/backends/__init__.py +0 -223
  34. arviz/plots/backends/bokeh/__init__.py +0 -166
  35. arviz/plots/backends/bokeh/autocorrplot.py +0 -101
  36. arviz/plots/backends/bokeh/bfplot.py +0 -23
  37. arviz/plots/backends/bokeh/bpvplot.py +0 -193
  38. arviz/plots/backends/bokeh/compareplot.py +0 -167
  39. arviz/plots/backends/bokeh/densityplot.py +0 -239
  40. arviz/plots/backends/bokeh/distcomparisonplot.py +0 -23
  41. arviz/plots/backends/bokeh/distplot.py +0 -183
  42. arviz/plots/backends/bokeh/dotplot.py +0 -113
  43. arviz/plots/backends/bokeh/ecdfplot.py +0 -73
  44. arviz/plots/backends/bokeh/elpdplot.py +0 -203
  45. arviz/plots/backends/bokeh/energyplot.py +0 -155
  46. arviz/plots/backends/bokeh/essplot.py +0 -176
  47. arviz/plots/backends/bokeh/forestplot.py +0 -772
  48. arviz/plots/backends/bokeh/hdiplot.py +0 -54
  49. arviz/plots/backends/bokeh/kdeplot.py +0 -268
  50. arviz/plots/backends/bokeh/khatplot.py +0 -163
  51. arviz/plots/backends/bokeh/lmplot.py +0 -185
  52. arviz/plots/backends/bokeh/loopitplot.py +0 -211
  53. arviz/plots/backends/bokeh/mcseplot.py +0 -184
  54. arviz/plots/backends/bokeh/pairplot.py +0 -328
  55. arviz/plots/backends/bokeh/parallelplot.py +0 -81
  56. arviz/plots/backends/bokeh/posteriorplot.py +0 -324
  57. arviz/plots/backends/bokeh/ppcplot.py +0 -379
  58. arviz/plots/backends/bokeh/rankplot.py +0 -149
  59. arviz/plots/backends/bokeh/separationplot.py +0 -107
  60. arviz/plots/backends/bokeh/traceplot.py +0 -436
  61. arviz/plots/backends/bokeh/violinplot.py +0 -164
  62. arviz/plots/backends/matplotlib/__init__.py +0 -124
  63. arviz/plots/backends/matplotlib/autocorrplot.py +0 -72
  64. arviz/plots/backends/matplotlib/bfplot.py +0 -78
  65. arviz/plots/backends/matplotlib/bpvplot.py +0 -177
  66. arviz/plots/backends/matplotlib/compareplot.py +0 -135
  67. arviz/plots/backends/matplotlib/densityplot.py +0 -194
  68. arviz/plots/backends/matplotlib/distcomparisonplot.py +0 -119
  69. arviz/plots/backends/matplotlib/distplot.py +0 -178
  70. arviz/plots/backends/matplotlib/dotplot.py +0 -116
  71. arviz/plots/backends/matplotlib/ecdfplot.py +0 -70
  72. arviz/plots/backends/matplotlib/elpdplot.py +0 -189
  73. arviz/plots/backends/matplotlib/energyplot.py +0 -113
  74. arviz/plots/backends/matplotlib/essplot.py +0 -180
  75. arviz/plots/backends/matplotlib/forestplot.py +0 -656
  76. arviz/plots/backends/matplotlib/hdiplot.py +0 -48
  77. arviz/plots/backends/matplotlib/kdeplot.py +0 -177
  78. arviz/plots/backends/matplotlib/khatplot.py +0 -241
  79. arviz/plots/backends/matplotlib/lmplot.py +0 -149
  80. arviz/plots/backends/matplotlib/loopitplot.py +0 -144
  81. arviz/plots/backends/matplotlib/mcseplot.py +0 -161
  82. arviz/plots/backends/matplotlib/pairplot.py +0 -355
  83. arviz/plots/backends/matplotlib/parallelplot.py +0 -58
  84. arviz/plots/backends/matplotlib/posteriorplot.py +0 -348
  85. arviz/plots/backends/matplotlib/ppcplot.py +0 -478
  86. arviz/plots/backends/matplotlib/rankplot.py +0 -119
  87. arviz/plots/backends/matplotlib/separationplot.py +0 -97
  88. arviz/plots/backends/matplotlib/traceplot.py +0 -526
  89. arviz/plots/backends/matplotlib/tsplot.py +0 -121
  90. arviz/plots/backends/matplotlib/violinplot.py +0 -148
  91. arviz/plots/bfplot.py +0 -128
  92. arviz/plots/bpvplot.py +0 -308
  93. arviz/plots/compareplot.py +0 -177
  94. arviz/plots/densityplot.py +0 -284
  95. arviz/plots/distcomparisonplot.py +0 -197
  96. arviz/plots/distplot.py +0 -233
  97. arviz/plots/dotplot.py +0 -233
  98. arviz/plots/ecdfplot.py +0 -372
  99. arviz/plots/elpdplot.py +0 -174
  100. arviz/plots/energyplot.py +0 -147
  101. arviz/plots/essplot.py +0 -319
  102. arviz/plots/forestplot.py +0 -304
  103. arviz/plots/hdiplot.py +0 -211
  104. arviz/plots/kdeplot.py +0 -357
  105. arviz/plots/khatplot.py +0 -236
  106. arviz/plots/lmplot.py +0 -380
  107. arviz/plots/loopitplot.py +0 -224
  108. arviz/plots/mcseplot.py +0 -194
  109. arviz/plots/pairplot.py +0 -281
  110. arviz/plots/parallelplot.py +0 -204
  111. arviz/plots/plot_utils.py +0 -599
  112. arviz/plots/posteriorplot.py +0 -298
  113. arviz/plots/ppcplot.py +0 -369
  114. arviz/plots/rankplot.py +0 -232
  115. arviz/plots/separationplot.py +0 -167
  116. arviz/plots/styles/arviz-bluish.mplstyle +0 -1
  117. arviz/plots/styles/arviz-brownish.mplstyle +0 -1
  118. arviz/plots/styles/arviz-colors.mplstyle +0 -2
  119. arviz/plots/styles/arviz-cyanish.mplstyle +0 -1
  120. arviz/plots/styles/arviz-darkgrid.mplstyle +0 -40
  121. arviz/plots/styles/arviz-doc.mplstyle +0 -88
  122. arviz/plots/styles/arviz-docgrid.mplstyle +0 -88
  123. arviz/plots/styles/arviz-grayscale.mplstyle +0 -41
  124. arviz/plots/styles/arviz-greenish.mplstyle +0 -1
  125. arviz/plots/styles/arviz-orangish.mplstyle +0 -1
  126. arviz/plots/styles/arviz-plasmish.mplstyle +0 -1
  127. arviz/plots/styles/arviz-purplish.mplstyle +0 -1
  128. arviz/plots/styles/arviz-redish.mplstyle +0 -1
  129. arviz/plots/styles/arviz-royish.mplstyle +0 -1
  130. arviz/plots/styles/arviz-viridish.mplstyle +0 -1
  131. arviz/plots/styles/arviz-white.mplstyle +0 -40
  132. arviz/plots/styles/arviz-whitegrid.mplstyle +0 -40
  133. arviz/plots/traceplot.py +0 -273
  134. arviz/plots/tsplot.py +0 -440
  135. arviz/plots/violinplot.py +0 -192
  136. arviz/preview.py +0 -58
  137. arviz/py.typed +0 -0
  138. arviz/rcparams.py +0 -606
  139. arviz/sel_utils.py +0 -223
  140. arviz/static/css/style.css +0 -340
  141. arviz/static/html/icons-svg-inline.html +0 -15
  142. arviz/stats/__init__.py +0 -37
  143. arviz/stats/density_utils.py +0 -1013
  144. arviz/stats/diagnostics.py +0 -1013
  145. arviz/stats/ecdf_utils.py +0 -324
  146. arviz/stats/stats.py +0 -2422
  147. arviz/stats/stats_refitting.py +0 -119
  148. arviz/stats/stats_utils.py +0 -609
  149. arviz/tests/__init__.py +0 -1
  150. arviz/tests/base_tests/__init__.py +0 -1
  151. arviz/tests/base_tests/test_data.py +0 -1679
  152. arviz/tests/base_tests/test_data_zarr.py +0 -143
  153. arviz/tests/base_tests/test_diagnostics.py +0 -511
  154. arviz/tests/base_tests/test_diagnostics_numba.py +0 -87
  155. arviz/tests/base_tests/test_helpers.py +0 -18
  156. arviz/tests/base_tests/test_labels.py +0 -69
  157. arviz/tests/base_tests/test_plot_utils.py +0 -342
  158. arviz/tests/base_tests/test_plots_bokeh.py +0 -1288
  159. arviz/tests/base_tests/test_plots_matplotlib.py +0 -2197
  160. arviz/tests/base_tests/test_rcparams.py +0 -317
  161. arviz/tests/base_tests/test_stats.py +0 -925
  162. arviz/tests/base_tests/test_stats_ecdf_utils.py +0 -166
  163. arviz/tests/base_tests/test_stats_numba.py +0 -45
  164. arviz/tests/base_tests/test_stats_utils.py +0 -384
  165. arviz/tests/base_tests/test_utils.py +0 -376
  166. arviz/tests/base_tests/test_utils_numba.py +0 -87
  167. arviz/tests/conftest.py +0 -46
  168. arviz/tests/external_tests/__init__.py +0 -1
  169. arviz/tests/external_tests/test_data_beanmachine.py +0 -78
  170. arviz/tests/external_tests/test_data_cmdstan.py +0 -398
  171. arviz/tests/external_tests/test_data_cmdstanpy.py +0 -496
  172. arviz/tests/external_tests/test_data_emcee.py +0 -166
  173. arviz/tests/external_tests/test_data_numpyro.py +0 -434
  174. arviz/tests/external_tests/test_data_pyjags.py +0 -119
  175. arviz/tests/external_tests/test_data_pyro.py +0 -260
  176. arviz/tests/external_tests/test_data_pystan.py +0 -307
  177. arviz/tests/helpers.py +0 -677
  178. arviz/utils.py +0 -773
  179. arviz/wrappers/__init__.py +0 -13
  180. arviz/wrappers/base.py +0 -236
  181. arviz/wrappers/wrap_pymc.py +0 -36
  182. arviz/wrappers/wrap_stan.py +0 -148
  183. arviz-0.23.1.dist-info/METADATA +0 -263
  184. arviz-0.23.1.dist-info/RECORD +0 -183
  185. arviz-0.23.1.dist-info/top_level.txt +0 -1
arviz/tests/base_tests/test_data_zarr.py (deleted)
@@ -1,143 +0,0 @@
1
- # pylint: disable=redefined-outer-name
2
- import os
3
- from collections.abc import MutableMapping
4
- from tempfile import TemporaryDirectory
5
- from typing import Mapping
6
-
7
- import numpy as np
8
- import pytest
9
-
10
- from ... import InferenceData, from_dict
11
- from ... import to_zarr, from_zarr
12
-
13
- from ..helpers import ( # pylint: disable=unused-import
14
- chains,
15
- check_multiple_attrs,
16
- draws,
17
- eight_schools_params,
18
- importorskip,
19
- )
20
-
21
- zarr = importorskip("zarr") # pylint: disable=invalid-name
22
-
23
-
24
- class TestDataZarr:
25
- @pytest.fixture(scope="class")
26
- def data(self, draws, chains):
27
- class Data:
28
- # fake 8-school output
29
- shapes: Mapping[str, list] = {"mu": [], "tau": [], "eta": [8], "theta": [8]}
30
- obj = {key: np.random.randn(chains, draws, *shape) for key, shape in shapes.items()}
31
-
32
- return Data
33
-
34
- def get_inference_data(self, data, eight_schools_params, fill_attrs):
35
- return from_dict(
36
- posterior=data.obj,
37
- posterior_predictive=data.obj,
38
- sample_stats=data.obj,
39
- prior=data.obj,
40
- prior_predictive=data.obj,
41
- sample_stats_prior=data.obj,
42
- observed_data=eight_schools_params,
43
- coords={"school": np.arange(8)},
44
- dims={"theta": ["school"], "eta": ["school"]},
45
- attrs={"test": 1} if fill_attrs else None,
46
- )
47
-
48
- @pytest.mark.parametrize("store", [0, 1, 2])
49
- @pytest.mark.parametrize("fill_attrs", [True, False])
50
- def test_io_method(self, data, eight_schools_params, store, fill_attrs):
51
- # create InferenceData and check it has been properly created
52
- inference_data = self.get_inference_data( # pylint: disable=W0612
53
- data, eight_schools_params, fill_attrs
54
- )
55
- test_dict = {
56
- "posterior": ["eta", "theta", "mu", "tau"],
57
- "posterior_predictive": ["eta", "theta", "mu", "tau"],
58
- "sample_stats": ["eta", "theta", "mu", "tau"],
59
- "prior": ["eta", "theta", "mu", "tau"],
60
- "prior_predictive": ["eta", "theta", "mu", "tau"],
61
- "sample_stats_prior": ["eta", "theta", "mu", "tau"],
62
- "observed_data": ["J", "y", "sigma"],
63
- }
64
- fails = check_multiple_attrs(test_dict, inference_data)
65
- assert not fails
66
-
67
- if fill_attrs:
68
- assert inference_data.attrs["test"] == 1
69
- else:
70
- assert "test" not in inference_data.attrs
71
-
72
- # check filename does not exist and use to_zarr method
73
- with TemporaryDirectory(prefix="arviz_tests_") as tmp_dir:
74
- filepath = os.path.join(tmp_dir, "zarr")
75
-
76
- # InferenceData method
77
- if store == 0:
78
- # Tempdir
79
- store = inference_data.to_zarr(store=None)
80
- assert isinstance(store, MutableMapping)
81
- elif store == 1:
82
- inference_data.to_zarr(store=filepath)
83
- # assert file has been saved correctly
84
- assert os.path.exists(filepath)
85
- assert os.path.getsize(filepath) > 0
86
- elif store == 2:
87
- store = zarr.storage.DirectoryStore(filepath)
88
- inference_data.to_zarr(store=store)
89
- # assert file has been saved correctly
90
- assert os.path.exists(filepath)
91
- assert os.path.getsize(filepath) > 0
92
-
93
- if isinstance(store, MutableMapping):
94
- inference_data2 = InferenceData.from_zarr(store)
95
- else:
96
- inference_data2 = InferenceData.from_zarr(filepath)
97
-
98
- # Everything in dict still available in inference_data2 ?
99
- fails = check_multiple_attrs(test_dict, inference_data2)
100
- assert not fails
101
-
102
- if fill_attrs:
103
- assert inference_data2.attrs["test"] == 1
104
- else:
105
- assert "test" not in inference_data2.attrs
106
-
107
- def test_io_function(self, data, eight_schools_params):
108
- # create InferenceData and check it has been properly created
109
- inference_data = self.get_inference_data( # pylint: disable=W0612
110
- data,
111
- eight_schools_params,
112
- fill_attrs=True,
113
- )
114
- test_dict = {
115
- "posterior": ["eta", "theta", "mu", "tau"],
116
- "posterior_predictive": ["eta", "theta", "mu", "tau"],
117
- "sample_stats": ["eta", "theta", "mu", "tau"],
118
- "prior": ["eta", "theta", "mu", "tau"],
119
- "prior_predictive": ["eta", "theta", "mu", "tau"],
120
- "sample_stats_prior": ["eta", "theta", "mu", "tau"],
121
- "observed_data": ["J", "y", "sigma"],
122
- }
123
- fails = check_multiple_attrs(test_dict, inference_data)
124
- assert not fails
125
-
126
- assert inference_data.attrs["test"] == 1
127
-
128
- # check filename does not exist and use to_zarr method
129
- with TemporaryDirectory(prefix="arviz_tests_") as tmp_dir:
130
- filepath = os.path.join(tmp_dir, "zarr")
131
-
132
- to_zarr(inference_data, store=filepath)
133
- # assert file has been saved correctly
134
- assert os.path.exists(filepath)
135
- assert os.path.getsize(filepath) > 0
136
-
137
- inference_data2 = from_zarr(filepath)
138
-
139
- # Everything in dict still available in inference_data2 ?
140
- fails = check_multiple_attrs(test_dict, inference_data2)
141
- assert not fails
142
-
143
- assert inference_data2.attrs["test"] == 1
arviz/tests/base_tests/test_diagnostics.py (deleted)
@@ -1,511 +0,0 @@
1
- """Test Diagnostic methods"""
2
-
3
- # pylint: disable=redefined-outer-name, no-member, too-many-public-methods
4
- import os
5
-
6
- import numpy as np
7
- import packaging
8
- import pandas as pd
9
- import pytest
10
- import scipy
11
- from numpy.testing import assert_almost_equal
12
-
13
- from ...data import from_cmdstan, load_arviz_data
14
- from ...rcparams import rcParams
15
- from ...sel_utils import xarray_var_iter
16
- from ...stats import bfmi, ess, mcse, rhat
17
- from ...stats.diagnostics import (
18
- _ess,
19
- _ess_quantile,
20
- _mc_error,
21
- _mcse_quantile,
22
- _multichain_statistics,
23
- _rhat,
24
- _rhat_rank,
25
- _split_chains,
26
- _z_scale,
27
- ks_summary,
28
- )
29
-
30
- # For tests only, recommended value should be closer to 1.01-1.05
31
- # See discussion in https://github.com/stan-dev/rstan/pull/618
32
- GOOD_RHAT = 1.1
33
-
34
- rcParams["data.load"] = "eager"
35
-
36
-
37
- @pytest.fixture(scope="session")
38
- def data():
39
- centered_eight = load_arviz_data("centered_eight")
40
- return centered_eight.posterior
41
-
42
-
43
- class TestDiagnostics:
44
- def test_bfmi(self):
45
- energy = np.array([1, 2, 3, 4])
46
- assert_almost_equal(bfmi(energy), 0.6)
47
-
48
- def test_bfmi_dataset(self):
49
- data = load_arviz_data("centered_eight")
50
- assert bfmi(data).all()
51
-
52
- def test_bfmi_dataset_bad(self):
53
- data = load_arviz_data("centered_eight")
54
- del data.sample_stats["energy"]
55
- with pytest.raises(TypeError):
56
- bfmi(data)
57
-
58
- def test_bfmi_correctly_transposed(self):
59
- data = load_arviz_data("centered_eight")
60
- vals1 = bfmi(data)
61
- data.sample_stats["energy"] = data.sample_stats["energy"].T
62
- vals2 = bfmi(data)
63
- assert_almost_equal(vals1, vals2)
64
-
65
- def test_deterministic(self):
66
- """
67
- Test algorithm against posterior (R) convergence functions.
68
-
69
- posterior: https://github.com/stan-dev/posterior
70
- R code:
71
- ```
72
- library("posterior")
73
- data2 <- read.csv("blocker.2.csv", comment.char = "#")
74
- data1 <- read.csv("blocker.1.csv", comment.char = "#")
75
- output <- matrix(ncol=17, nrow=length(names(data1))-4)
76
- j = 0
77
- for (i in 1:length(names(data1))) {
78
- name = names(data1)[i]
79
- ary = matrix(c(data1[,name], data2[,name]), 1000, 2)
80
- if (!endsWith(name, "__"))
81
- j <- j + 1
82
- output[j,] <- c(
83
- posterior::rhat(ary),
84
- posterior::rhat_basic(ary, FALSE),
85
- posterior::ess_bulk(ary),
86
- posterior::ess_tail(ary),
87
- posterior::ess_mean(ary),
88
- posterior::ess_sd(ary),
89
- posterior::ess_median(ary),
90
- posterior::ess_basic(ary, FALSE),
91
- posterior::ess_quantile(ary, 0.01),
92
- posterior::ess_quantile(ary, 0.1),
93
- posterior::ess_quantile(ary, 0.3),
94
- posterior::mcse_mean(ary),
95
- posterior::mcse_sd(ary),
96
- posterior::mcse_median(ary),
97
- posterior::mcse_quantile(ary, prob=0.01),
98
- posterior::mcse_quantile(ary, prob=0.1),
99
- posterior::mcse_quantile(ary, prob=0.3))
100
- }
101
- df = data.frame(output, row.names = names(data1)[5:ncol(data1)])
102
- colnames(df) <- c("rhat_rank",
103
- "rhat_raw",
104
- "ess_bulk",
105
- "ess_tail",
106
- "ess_mean",
107
- "ess_sd",
108
- "ess_median",
109
- "ess_raw",
110
- "ess_quantile01",
111
- "ess_quantile10",
112
- "ess_quantile30",
113
- "mcse_mean",
114
- "mcse_sd",
115
- "mcse_median",
116
- "mcse_quantile01",
117
- "mcse_quantile10",
118
- "mcse_quantile30")
119
- write.csv(df, "reference_posterior.csv")
120
- ```
121
- Reference file:
122
-
123
- Created: 2024-12-20
124
- System: Ubuntu 24.04.1 LTS
125
- R version 4.4.2 (2024-10-31)
126
- posterior version from https://github.com/stan-dev/posterior/pull/388
127
- (after release 1.6.0 but before the fixes in the PR were released).
128
- """
129
- # download input files
130
- here = os.path.dirname(os.path.abspath(__file__))
131
- data_directory = os.path.join(here, "..", "saved_models")
132
- path = os.path.join(data_directory, "stan_diagnostics", "blocker.[0-9].csv")
133
- posterior = from_cmdstan(path)
134
- reference_path = os.path.join(data_directory, "stan_diagnostics", "reference_posterior.csv")
135
- reference = (
136
- pd.read_csv(reference_path, index_col=0, float_precision="high")
137
- .sort_index(axis=1)
138
- .sort_index(axis=0)
139
- )
140
- # test arviz functions
141
- funcs = {
142
- "rhat_rank": lambda x: rhat(x, method="rank"),
143
- "rhat_raw": lambda x: rhat(x, method="identity"),
144
- "ess_bulk": lambda x: ess(x, method="bulk"),
145
- "ess_tail": lambda x: ess(x, method="tail"),
146
- "ess_mean": lambda x: ess(x, method="mean"),
147
- "ess_sd": lambda x: ess(x, method="sd"),
148
- "ess_median": lambda x: ess(x, method="median"),
149
- "ess_raw": lambda x: ess(x, method="identity"),
150
- "ess_quantile01": lambda x: ess(x, method="quantile", prob=0.01),
151
- "ess_quantile10": lambda x: ess(x, method="quantile", prob=0.1),
152
- "ess_quantile30": lambda x: ess(x, method="quantile", prob=0.3),
153
- "mcse_mean": lambda x: mcse(x, method="mean"),
154
- "mcse_sd": lambda x: mcse(x, method="sd"),
155
- "mcse_median": lambda x: mcse(x, method="median"),
156
- "mcse_quantile01": lambda x: mcse(x, method="quantile", prob=0.01),
157
- "mcse_quantile10": lambda x: mcse(x, method="quantile", prob=0.1),
158
- "mcse_quantile30": lambda x: mcse(x, method="quantile", prob=0.3),
159
- }
160
- results = {}
161
- for key, coord_dict, _, vals in xarray_var_iter(posterior.posterior, combined=True):
162
- if coord_dict:
163
- key = f"{key}.{list(coord_dict.values())[0] + 1}"
164
- results[key] = {func_name: func(vals) for func_name, func in funcs.items()}
165
- arviz_data = pd.DataFrame.from_dict(results).T.sort_index(axis=1).sort_index(axis=0)
166
-
167
- # check column names
168
- assert set(arviz_data.columns) == set(reference.columns)
169
-
170
- # check parameter names
171
- assert set(arviz_data.index) == set(reference.index)
172
-
173
- # show print with pytests '-s' tag
174
- np.set_printoptions(16)
175
- print(abs(reference - arviz_data).max())
176
-
177
- # test absolute accuracy
178
- assert (abs(reference - arviz_data).values < 1e-8).all(None)
179
-
180
- @pytest.mark.parametrize("method", ("rank", "split", "folded", "z_scale", "identity"))
181
- @pytest.mark.parametrize("var_names", (None, "mu", ["mu", "tau"]))
182
- def test_rhat(self, data, var_names, method):
183
- """Confirm R-hat statistic is close to 1 for a large
184
- number of samples. Also checks the correct shape"""
185
- rhat_data = rhat(data, var_names=var_names, method=method)
186
- for r_hat in rhat_data.data_vars.values():
187
- assert ((1 / GOOD_RHAT < r_hat.values) | (r_hat.values < GOOD_RHAT)).all()
188
-
189
- # In None case check that all varnames from rhat_data match input data
190
- if var_names is None:
191
- assert list(rhat_data.data_vars) == list(data.data_vars)
192
-
193
- @pytest.mark.parametrize("method", ("rank", "split", "folded", "z_scale", "identity"))
194
- def test_rhat_nan(self, method):
195
- """Confirm R-hat statistic returns nan."""
196
- data = np.random.randn(4, 100)
197
- data[0, 0] = np.nan # pylint: disable=unsupported-assignment-operation
198
- rhat_data = rhat(data, method=method)
199
- assert np.isnan(rhat_data)
200
- if method == "rank":
201
- assert np.isnan(_rhat(rhat_data))
202
-
203
- @pytest.mark.parametrize("method", ("rank", "split", "folded", "z_scale", "identity"))
204
- @pytest.mark.parametrize("chain", (None, 1, 2))
205
- @pytest.mark.parametrize("draw", (1, 2, 3, 4))
206
- def test_rhat_shape(self, method, chain, draw):
207
- """Confirm R-hat statistic returns nan."""
208
- data = np.random.randn(draw) if chain is None else np.random.randn(chain, draw)
209
- if (chain in (None, 1)) or (draw < 4):
210
- rhat_data = rhat(data, method=method)
211
- assert np.isnan(rhat_data)
212
- else:
213
- rhat_data = rhat(data, method=method)
214
- assert not np.isnan(rhat_data)
215
-
216
- def test_rhat_bad(self):
217
- """Confirm rank normalized Split R-hat statistic is
218
- far from 1 for a small number of samples."""
219
- r_hat = rhat(np.vstack([20 + np.random.randn(1, 100), np.random.randn(1, 100)]))
220
- assert 1 / GOOD_RHAT > r_hat or GOOD_RHAT < r_hat
221
-
222
- def test_rhat_bad_method(self):
223
- with pytest.raises(TypeError):
224
- rhat(np.random.randn(2, 300), method="wrong_method")
225
-
226
- def test_rhat_ndarray(self):
227
- with pytest.raises(TypeError):
228
- rhat(np.random.randn(2, 300, 10))
229
-
230
- @pytest.mark.parametrize(
231
- "method",
232
- (
233
- "bulk",
234
- "tail",
235
- "quantile",
236
- "local",
237
- "mean",
238
- "sd",
239
- "median",
240
- "mad",
241
- "z_scale",
242
- "folded",
243
- "identity",
244
- ),
245
- )
246
- @pytest.mark.parametrize("relative", (True, False))
247
- def test_effective_sample_size_array(self, data, method, relative):
248
- n_low = 100 / 400 if relative else 100
249
- n_high = 800 / 400 if relative else 800
250
- if method in ("quantile", "tail"):
251
- ess_hat = ess(data, method=method, prob=0.34, relative=relative)
252
- if method == "tail":
253
- assert ess_hat > n_low
254
- assert ess_hat < n_high
255
- ess_hat = ess(np.random.randn(4, 100), method=method, relative=relative)
256
- assert ess_hat > n_low
257
- assert ess_hat < n_high
258
- ess_hat = ess(
259
- np.random.randn(4, 100), method=method, prob=(0.2, 0.8), relative=relative
260
- )
261
- elif method == "local":
262
- ess_hat = ess(
263
- np.random.randn(4, 100), method=method, prob=(0.2, 0.3), relative=relative
264
- )
265
- else:
266
- ess_hat = ess(np.random.randn(4, 100), method=method, relative=relative)
267
- assert ess_hat > n_low
268
- assert ess_hat < n_high
269
-
270
- @pytest.mark.parametrize(
271
- "method",
272
- (
273
- "bulk",
274
- "tail",
275
- "quantile",
276
- "local",
277
- "mean",
278
- "sd",
279
- "median",
280
- "mad",
281
- "z_scale",
282
- "folded",
283
- "identity",
284
- ),
285
- )
286
- @pytest.mark.parametrize("relative", (True, False))
287
- @pytest.mark.parametrize("chain", (None, 1, 2))
288
- @pytest.mark.parametrize("draw", (1, 2, 3, 4))
289
- @pytest.mark.parametrize("use_nan", (True, False))
290
- def test_effective_sample_size_nan(self, method, relative, chain, draw, use_nan):
291
- data = np.random.randn(draw) if chain is None else np.random.randn(chain, draw)
292
- if use_nan:
293
- data[0] = np.nan
294
- if method in ("quantile", "tail"):
295
- ess_value = ess(data, method=method, prob=0.34, relative=relative)
296
- elif method == "local":
297
- ess_value = ess(data, method=method, prob=(0.2, 0.3), relative=relative)
298
- else:
299
- ess_value = ess(data, method=method, relative=relative)
300
- if (draw < 4) or use_nan:
301
- assert np.isnan(ess_value)
302
- else:
303
- assert not np.isnan(ess_value)
304
- # test following only once tests are run
305
- if (method == "bulk") and (not relative) and (chain is None) and (draw == 4):
306
- if use_nan:
307
- assert np.isnan(_ess(data))
308
- else:
309
- assert not np.isnan(_ess(data))
310
-
311
- @pytest.mark.parametrize("relative", (True, False))
312
- def test_effective_sample_size_missing_prob(self, relative):
313
- with pytest.raises(TypeError):
314
- ess(np.random.randn(4, 100), method="quantile", relative=relative)
315
- with pytest.raises(TypeError):
316
- _ess_quantile(np.random.randn(4, 100), prob=None, relative=relative)
317
- with pytest.raises(TypeError):
318
- ess(np.random.randn(4, 100), method="local", relative=relative)
319
-
320
- @pytest.mark.parametrize("relative", (True, False))
321
- def test_effective_sample_size_too_many_probs(self, relative):
322
- with pytest.raises(ValueError):
323
- ess(np.random.randn(4, 100), method="local", prob=[0.1, 0.2, 0.9], relative=relative)
324
-
325
- def test_effective_sample_size_constant(self):
326
- assert ess(np.ones((4, 100))) == 400
327
-
328
- def test_effective_sample_size_bad_method(self):
329
- with pytest.raises(TypeError):
330
- ess(np.random.randn(4, 100), method="wrong_method")
331
-
332
- def test_effective_sample_size_ndarray(self):
333
- with pytest.raises(TypeError):
334
- ess(np.random.randn(2, 300, 10))
335
-
336
- @pytest.mark.parametrize(
337
- "method",
338
- (
339
- "bulk",
340
- "tail",
341
- "quantile",
342
- "local",
343
- "mean",
344
- "sd",
345
- "median",
346
- "mad",
347
- "z_scale",
348
- "folded",
349
- "identity",
350
- ),
351
- )
352
- @pytest.mark.parametrize("relative", (True, False))
353
- @pytest.mark.parametrize("var_names", (None, "mu", ["mu", "tau"]))
354
- def test_effective_sample_size_dataset(self, data, method, var_names, relative):
355
- n_low = 100 / (data.chain.size * data.draw.size) if relative else 100
356
- if method in ("quantile", "tail"):
357
- ess_hat = ess(data, var_names=var_names, method=method, prob=0.34, relative=relative)
358
- elif method == "local":
359
- ess_hat = ess(
360
- data, var_names=var_names, method=method, prob=(0.2, 0.3), relative=relative
361
- )
362
- else:
363
- ess_hat = ess(data, var_names=var_names, method=method, relative=relative)
364
- assert np.all(ess_hat.mu.values > n_low) # This might break if the data is regenerated
365
-
366
- @pytest.mark.parametrize("mcse_method", ("mean", "sd", "median", "quantile"))
367
- def test_mcse_array(self, mcse_method):
368
- if mcse_method == "quantile":
369
- mcse_hat = mcse(np.random.randn(4, 100), method=mcse_method, prob=0.34)
370
- else:
371
- mcse_hat = mcse(np.random.randn(4, 100), method=mcse_method)
372
- assert mcse_hat
373
-
374
- def test_mcse_ndarray(self):
375
- with pytest.raises(TypeError):
376
- mcse(np.random.randn(2, 300, 10))
377
-
378
- @pytest.mark.parametrize("mcse_method", ("mean", "sd", "median", "quantile"))
379
- @pytest.mark.parametrize("var_names", (None, "mu", ["mu", "tau"]))
380
- def test_mcse_dataset(self, data, mcse_method, var_names):
381
- if mcse_method == "quantile":
382
- mcse_hat = mcse(data, var_names=var_names, method=mcse_method, prob=0.34)
383
- else:
384
- mcse_hat = mcse(data, var_names=var_names, method=mcse_method)
385
- assert mcse_hat # This might break if the data is regenerated
386
-
387
- @pytest.mark.parametrize("mcse_method", ("mean", "sd", "median", "quantile"))
388
- @pytest.mark.parametrize("chain", (None, 1, 2))
389
- @pytest.mark.parametrize("draw", (1, 2, 3, 4))
390
- @pytest.mark.parametrize("use_nan", (True, False))
391
- def test_mcse_nan(self, mcse_method, chain, draw, use_nan):
392
- data = np.random.randn(draw) if chain is None else np.random.randn(chain, draw)
393
- if use_nan:
394
- data[0] = np.nan
395
- if mcse_method == "quantile":
396
- mcse_hat = mcse(data, method=mcse_method, prob=0.34)
397
- else:
398
- mcse_hat = mcse(data, method=mcse_method)
399
- if draw < 4 or use_nan:
400
- assert np.isnan(mcse_hat)
401
- else:
402
- assert not np.isnan(mcse_hat)
403
-
404
- @pytest.mark.parametrize("method", ("wrong_method", "quantile"))
405
- def test_mcse_bad_method(self, data, method):
406
- with pytest.raises(TypeError):
407
- mcse(data, method=method, prob=None)
408
-
409
- @pytest.mark.parametrize("draws", (3, 4, 100))
410
- @pytest.mark.parametrize("chains", (None, 1, 2))
411
- def test_multichain_summary_array(self, draws, chains):
412
- """Test multichain statistics against individual functions."""
413
- if chains is None:
414
- ary = np.random.randn(draws)
415
- else:
416
- ary = np.random.randn(chains, draws)
417
-
418
- mcse_mean_hat = mcse(ary, method="mean")
419
- mcse_sd_hat = mcse(ary, method="sd")
420
- ess_bulk_hat = ess(ary, method="bulk")
421
- ess_tail_hat = ess(ary, method="tail")
422
- rhat_hat = _rhat_rank(ary)
423
- (
424
- mcse_mean_hat_,
425
- mcse_sd_hat_,
426
- ess_bulk_hat_,
427
- ess_tail_hat_,
428
- rhat_hat_,
429
- ) = _multichain_statistics(ary)
430
- if draws == 3:
431
- assert np.isnan(
432
- (
433
- mcse_mean_hat,
434
- mcse_sd_hat,
435
- ess_bulk_hat,
436
- ess_tail_hat,
437
- rhat_hat,
438
- )
439
- ).all()
440
- assert np.isnan(
441
- (
442
- mcse_mean_hat_,
443
- mcse_sd_hat_,
444
- ess_bulk_hat_,
445
- ess_tail_hat_,
446
- rhat_hat_,
447
- )
448
- ).all()
449
- else:
450
- assert_almost_equal(mcse_mean_hat, mcse_mean_hat_)
451
- assert_almost_equal(mcse_sd_hat, mcse_sd_hat_)
452
- assert_almost_equal(ess_bulk_hat, ess_bulk_hat_)
453
- assert_almost_equal(ess_tail_hat, ess_tail_hat_)
454
- if chains in (None, 1):
455
- assert np.isnan(rhat_hat)
456
- assert np.isnan(rhat_hat_)
457
- else:
458
- assert round(rhat_hat, 3) == round(rhat_hat_, 3)
459
-
460
- def test_ks_summary(self):
461
- """Instead of psislw data, this test uses fake data."""
462
- pareto_tail_indices = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2])
463
- with pytest.warns(UserWarning):
464
- summary = ks_summary(pareto_tail_indices)
465
- assert summary is not None
466
- pareto_tail_indices2 = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2, 0.6])
467
- with pytest.warns(UserWarning):
468
- summary2 = ks_summary(pareto_tail_indices2)
469
- assert summary2 is not None
470
-
471
- @pytest.mark.parametrize("size", [100, 101])
472
- @pytest.mark.parametrize("batches", [1, 2, 3, 5, 7])
473
- @pytest.mark.parametrize("ndim", [1, 2, 3])
474
- @pytest.mark.parametrize("circular", [False, True])
475
- def test_mc_error(self, size, batches, ndim, circular):
476
- x = np.random.randn(size, ndim).squeeze() # pylint: disable=no-member
477
- assert _mc_error(x, batches=batches, circular=circular) is not None
478
-
479
- @pytest.mark.parametrize("size", [100, 101])
480
- @pytest.mark.parametrize("ndim", [1, 2, 3])
481
- def test_mc_error_nan(self, size, ndim):
482
- x = np.random.randn(size, ndim).squeeze() # pylint: disable=no-member
483
- x[0] = np.nan
484
- if ndim != 1:
485
- assert np.isnan(_mc_error(x)).all()
486
- else:
487
- assert np.isnan(_mc_error(x))
488
-
489
- @pytest.mark.parametrize("func", ("_mcse_quantile", "_z_scale"))
490
- def test_nan_behaviour(self, func):
491
- data = np.random.randn(100, 4)
492
- data[0, 0] = np.nan # pylint: disable=unsupported-assignment-operation
493
- if func == "_mcse_quantile":
494
- assert np.isnan(_mcse_quantile(data, 0.5)).all(None)
495
- elif packaging.version.parse(scipy.__version__) < packaging.version.parse("1.10.0.dev0"):
496
- assert not np.isnan(_z_scale(data)).all(None)
497
- assert not np.isnan(_z_scale(data)).any(None)
498
- else:
499
- assert np.isnan(_z_scale(data)).sum() == 1
500
-
501
- @pytest.mark.parametrize("chains", (None, 1, 2, 3))
502
- @pytest.mark.parametrize("draws", (2, 3, 100, 101))
503
- def test_split_chain_dims(self, chains, draws):
504
- if chains is None:
505
- data = np.random.randn(draws)
506
- else:
507
- data = np.random.randn(chains, draws)
508
- split_data = _split_chains(data)
509
- if chains is None:
510
- chains = 1
511
- assert split_data.shape == (chains * 2, draws // 2)