cloudposterior 0.6.0a1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cloudposterior/__init__.py +3 -0
- cloudposterior/_idata.py +184 -0
- cloudposterior/api.py +1532 -0
- cloudposterior/backends/__init__.py +93 -0
- cloudposterior/backends/modal_backend.py +792 -0
- cloudposterior/cache.py +152 -0
- cloudposterior/config.py +146 -0
- cloudposterior/dashboard.py +689 -0
- cloudposterior/display.py +492 -0
- cloudposterior/naming.py +95 -0
- cloudposterior/notify.py +170 -0
- cloudposterior/progress.py +234 -0
- cloudposterior/remote/__init__.py +0 -0
- cloudposterior/remote/worker.py +833 -0
- cloudposterior/serialize.py +142 -0
- cloudposterior/wordhash.py +28 -0
- cloudposterior-0.6.0a1.dist-info/METADATA +403 -0
- cloudposterior-0.6.0a1.dist-info/RECORD +20 -0
- cloudposterior-0.6.0a1.dist-info/WHEEL +4 -0
- cloudposterior-0.6.0a1.dist-info/licenses/LICENSE +21 -0
cloudposterior/_idata.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
"""arviz 0.x / 1.x compatibility shims.
|
|
2
|
+
|
|
3
|
+
PyMC 6 hard-requires ``arviz>=1.1``, a DataTree-based rewrite that:
|
|
4
|
+
|
|
5
|
+
- makes ``InferenceData.groups()`` a ``DataTree.groups`` *property* returning
|
|
6
|
+
slash-prefixed paths (``/posterior``) plus a ``/`` root,
|
|
7
|
+
- removes ``arviz.convert_to_inference_data``,
|
|
8
|
+
- changes ``arviz.ess(..., method="tail")`` to require a ``prob`` argument,
|
|
9
|
+
- returns ``xarray.DataTree`` from samplers / ``from_netcdf`` instead of
|
|
10
|
+
``arviz.InferenceData``.
|
|
11
|
+
|
|
12
|
+
PyMC 5 ships arviz 0.x with the old API. These helpers work on both majors so
|
|
13
|
+
cloudposterior (and the remote worker) stay version-agnostic.
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def group_names(idata) -> list[str]:
    """Return group names without a leading slash and without the DataTree root.

    arviz 0.x exposes ``groups()`` as a method returning bare names, while
    arviz 1.x (``xarray.DataTree``) exposes ``groups`` as a property whose
    entries are slash-prefixed paths including the root ``"/"``. Both shapes
    are accepted; the result is a plain list of names such as ``"posterior"``.
    """
    listing = idata.groups
    if callable(listing):  # arviz 0.x: InferenceData.groups() is a method
        listing = listing()
    cleaned = (raw.strip("/") for raw in listing)
    # The DataTree root "/" collapses to "" after stripping; drop it.
    return [label for label in cleaned if label]
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def get_group(idata, name):
    """Look up group ``name`` on ``idata``; return ``None`` if it cannot be found.

    Attribute access (``idata.posterior``) is tried first. If it raises or
    yields ``None``, item access (``idata["posterior"]``) is tried next. Any
    exception raised by item access is swallowed and reported as ``None``.
    """
    try:
        found = getattr(idata, name)
    except Exception:
        found = None
    if found is not None:
        return found
    try:
        return idata[name]
    except Exception:
        return None
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def group_attrs(idata, name=None):
    """Return the ``attrs`` of group ``name``, or of ``idata`` itself when ``name`` is None.

    Yields ``None`` when the group is missing or the target has no ``attrs``.
    """
    target = idata if name is None else get_group(idata, name)
    return getattr(target, "attrs", None)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def add_group(idata, name: str, group) -> None:
    """Attach ``group`` to ``idata`` under ``name``, mutating ``idata`` in place.

    Used to merge a remotely computed group (e.g. ``log_likelihood``) into the
    caller's local idata so cloudposterior matches PyMC's
    ``extend_inferencedata=True`` in-place semantics.

    ``group`` may be an ``xarray.Dataset`` or a DataTree node. Any
    non-Dataset exposing a callable ``to_dataset`` is converted first; if that
    conversion raises, the object is passed on unchanged.

    Strategies are tried in order, stopping at the first that does not raise:
      1. ``idata.add_groups({name: ds})`` (arviz 0.x ``InferenceData``);
      2. ``idata[name] = ds`` (arviz 1.x DataTree child creation);
      3. ``setattr(idata, name, ds)`` as a last resort (its errors propagate).
    """
    import xarray as xr

    payload = group
    if not isinstance(group, xr.Dataset):
        converter = getattr(group, "to_dataset", None)
        if callable(converter):
            try:
                payload = converter()
            except Exception:
                payload = group

    add_groups = getattr(idata, "add_groups", None)
    if callable(add_groups):
        try:
            add_groups({name: payload})
        except Exception:
            pass  # any failure falls through to item assignment
        else:
            return

    try:
        idata[name] = payload
    except Exception:
        pass  # no (working) item assignment; fall back to setattr
    else:
        return

    setattr(idata, name, payload)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def load_all(idata) -> None:
    """Eagerly load every group's data, best effort.

    Intended to detach ``idata`` from a temporary NetCDF file so the file can
    be deleted afterwards. Groups without a callable ``load`` are skipped, and
    any exception raised while loading a group is ignored.
    """
    loaders = (getattr(get_group(idata, g), "load", None) for g in group_names(idata))
    for load in loaders:
        if not callable(load):
            continue
        try:
            load()
        except Exception:
            pass
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def to_inference_data(trace):
    """Normalize a sampler result into an arviz object for either arviz major.

    - An instance of ``az.InferenceData`` (when that class exists) is returned
      unchanged.
    - If ``az.convert_to_inference_data`` exists (arviz 0.x), other inputs are
      converted with it; should the conversion raise, the input is returned
      as-is.
    - If it does not exist (arviz 1.x, where nutpie / ``pm.sample`` already
      return a DataTree), the input is returned as-is.
    """
    import arviz as az

    idata_cls = getattr(az, "InferenceData", None)
    if idata_cls is not None and isinstance(trace, idata_cls):
        return trace

    converter = getattr(az, "convert_to_inference_data", None)
    if converter is None:
        return trace
    try:
        return converter(trace)
    except Exception:
        return trace
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def ess_tail(arr) -> float:
    """Tail effective sample size of ``arr``, consistent across arviz majors.

    arviz 0.x accepts ``ess(arr, method="tail")`` and uses the 5% and 95%
    quantiles by default. arviz 1.x requires an explicit ``prob`` for
    ``method="tail"``. When the plain call raises ``TypeError``, retry with
    ``prob`` pinned to those same quantiles, first in tuple form and then as a
    scalar. This way both majors report the same statistic.

    Previously the first retry used ``(0.025, 0.975)``, which gave a different
    number than arviz 0.x.

    If every tail variant fails, bulk ESS is returned as a last resort.

    Returns:
        The ESS as a Python ``float``.
    """
    import arviz as az

    try:
        return float(az.ess(arr, method="tail"))
    except TypeError:
        pass

    # Tuple form first; the scalar form is read as (prob, 1 - prob) by arviz 0.x.
    for prob in ((0.05, 0.95), 0.05):
        try:
            return float(az.ess(arr, method="tail", prob=prob))
        except Exception:
            continue
    return float(az.ess(arr))  # last resort: bulk ESS
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def sanitize_inference_data(idata):
    """Make ``idata`` writable to NetCDF, mutating it in place and returning it.

    * Attrs cleanup, applied to the top-level attrs dict and each group's
      attrs dict:
      - Values of type str, bytes, int, float, list or tuple are kept as-is.
        numpy arrays and numpy scalar numbers are also kept when numpy
        imports.
      - Any other value is replaced by its JSON encoding (``default=str``),
        or by ``str(value)`` if JSON encoding raises.
      - An example is the dict-valued ``sample_stats`` attr nutpie stores,
        which xarray's NetCDF writer rejects.
    * Object-dtype data variables in each group are cast to float64 whenever
      the cast succeeds. PyMC's SMC ``sample_stats`` (``beta``,
      ``accept_rate``, ``log_marginal_likelihood``) arrive as object arrays of
      mixed Python int/float that NetCDF cannot write. Those values are plain
      numbers, so float64 is lossless.

    Running it again on an already-clean object changes nothing.
    """
    import json

    try:
        import numpy as np

        allowed = (str, bytes, int, float, list, tuple, np.ndarray, np.number)
    except Exception:
        allowed = (str, bytes, int, float, list, tuple)

    def _stringify_bad_attrs(attrs):
        # Replace attr values the NetCDF writer would reject with strings.
        if not isinstance(attrs, dict):
            return
        offenders = [(k, v) for k, v in attrs.items() if not isinstance(v, allowed)]
        for key, value in offenders:
            try:
                encoded = json.dumps(value, default=str)
            except Exception:
                encoded = str(value)
            attrs[key] = encoded

    def _float_object_vars(group):
        # Cast object-dtype data variables to float64 where possible.
        data_vars = getattr(group, "data_vars", None)
        if data_vars is None:
            return
        for var in list(data_vars):
            try:
                column = group[var]
                if column.dtype == object:
                    group[var] = column.astype("float64")
            except (ValueError, TypeError, KeyError):
                pass  # non-numeric payload (or vanished var): leave untouched

    _stringify_bad_attrs(group_attrs(idata, None))
    for name in group_names(idata):
        _stringify_bad_attrs(group_attrs(idata, name))
        _float_object_vars(get_group(idata, name))
    return idata
|