cloudposterior 0.6.0a1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,3 @@
1
+ from cloudposterior.api import cleanup_volumes, cloud, map, sample
2
+
3
+ __all__ = ["cleanup_volumes", "cloud", "map", "sample"]
@@ -0,0 +1,184 @@
1
+ """arviz 0.x / 1.x compatibility shims.
2
+
3
+ PyMC 6 hard-requires ``arviz>=1.1``, a DataTree-based rewrite that:
4
+
5
+ - makes ``InferenceData.groups()`` a ``DataTree.groups`` *property* returning
6
+ slash-prefixed paths (``/posterior``) plus a ``/`` root,
7
+ - removes ``arviz.convert_to_inference_data``,
8
+ - changes ``arviz.ess(..., method="tail")`` to require a ``prob`` argument,
9
+ - returns ``xarray.DataTree`` from samplers / ``from_netcdf`` instead of
10
+ ``arviz.InferenceData``.
11
+
12
+ PyMC 5 ships arviz 0.x with the old API. These helpers work on both majors so
13
+ cloudposterior (and the remote worker) stay version-agnostic.
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+
19
def group_names(idata) -> list[str]:
    """Return the group names of ``idata`` without slashes or the root entry.

    ``idata.groups`` may be either a method (called to obtain the names) or a
    plain iterable attribute. Each name has leading/trailing ``/`` removed, and
    names that end up empty (such as a bare ``/`` root) are dropped.
    """
    listing = idata.groups
    if callable(listing):
        listing = listing()
    # Materialize first, then strip, so a lazy ``groups`` iterable is fully
    # consumed before any name is processed.
    stripped = (path.strip("/") for path in list(listing))
    return [label for label in stripped if label]
29
+
30
+
31
def get_group(idata, name):
    """Look up group ``name`` on ``idata`` by attribute, then by item access.

    Returns the attribute value when it exists and is not ``None``; otherwise
    falls back to ``idata[name]``. Any exception from either lookup is
    swallowed, and ``None`` is returned when both fail.
    """
    try:
        found = getattr(idata, name)
    except Exception:
        found = None
    if found is not None:
        return found

    # Attribute access failed or yielded None: try mapping-style access.
    try:
        found = idata[name]
    except Exception:
        return None
    return found
43
+
44
+
45
def group_attrs(idata, name=None):
    """Return the ``attrs`` of group ``name``, or of ``idata`` itself if ``name`` is None.

    Yields ``None`` when the selected object has no ``attrs`` attribute
    (including the case where the group lookup itself returns ``None``).
    """
    holder = idata if name is None else get_group(idata, name)
    return getattr(holder, "attrs", None)
50
+
51
+
52
def add_group(idata, name: str, group) -> None:
    """Attach ``group`` to ``idata`` under ``name``, mutating ``idata`` in place.

    This is how a remotely-computed group (e.g. ``log_likelihood``) gets merged
    into the caller's local idata, mirroring PyMC's
    ``extend_inferencedata=True`` in-place behavior. ``group`` can be an xarray
    Dataset or any object offering ``to_dataset()``; the latter is converted
    when possible and otherwise used unchanged.

    Three strategies are tried in order, each only if the previous one is
    unavailable or raises: ``idata.add_groups({name: ds})``, then
    ``idata[name] = ds``, then ``setattr(idata, name, ds)``. Only an exception
    from the final ``setattr`` propagates.
    """
    import xarray as xr

    # Normalize to a Dataset where the object knows how to convert itself.
    dataset = group
    converter = None if isinstance(group, xr.Dataset) else getattr(group, "to_dataset", None)
    if callable(converter):
        try:
            dataset = converter()
        except Exception:
            dataset = group

    # arviz 0.x InferenceData path: add_groups takes a {name: dataset} mapping.
    bulk_add = getattr(idata, "add_groups", None)
    if callable(bulk_add):
        try:
            bulk_add({name: dataset})
        except Exception:
            pass
        else:
            return

    # arviz 1.x DataTree path (and generic fallback): item assignment.
    try:
        idata[name] = dataset
    except Exception:
        pass
    else:
        return

    # Last resort: plain attribute assignment.
    setattr(idata, name, dataset)
88
+
89
+
90
def load_all(idata) -> None:
    """Call ``load()`` on every group of ``idata``, ignoring any failures.

    Intended as a best-effort eager load so that a backing temporary NetCDF
    file can be removed afterwards. Groups without a callable ``load`` are
    skipped, and exceptions raised by ``load()`` are swallowed.
    """
    for label in group_names(idata):
        eager_load = getattr(get_group(idata, label), "load", None)
        if not callable(eager_load):
            continue
        try:
            eager_load()
        except Exception:
            # Best effort only: one failing group must not stop the others.
            pass
99
+
100
+
101
def to_inference_data(trace):
    """Return ``trace`` as an arviz object regardless of the installed arviz major.

    When arviz exposes ``InferenceData`` and ``trace`` already is one, it is
    returned untouched. Otherwise, if arviz provides
    ``convert_to_inference_data`` (the 0.x API), the conversion is attempted.
    If that function is absent (arviz 1.x, where samplers already return a
    DataTree) or the conversion raises, ``trace`` is returned as-is.
    """
    import arviz as az

    legacy_cls = getattr(az, "InferenceData", None)
    if legacy_cls is not None and isinstance(trace, legacy_cls):
        return trace

    converter = getattr(az, "convert_to_inference_data", None)
    if converter is None:
        return trace
    try:
        return converter(trace)
    except Exception:
        # Conversion is optional; fall back to the raw sampler result.
        return trace
119
+
120
+
121
def ess_tail(arr) -> float:
    """Compute tail-ESS of ``arr`` as a float, tolerating arviz API differences.

    First tries ``az.ess(arr, method="tail")``. If that raises ``TypeError``
    (arviz 1.x changed this call), it retries with ``prob=(0.025, 0.975)`` and
    then ``prob=0.05``, ignoring any failure of those attempts. If none
    succeeds, the default ``az.ess(arr)`` (bulk ESS) is returned as a last resort.
    """
    import arviz as az

    try:
        return float(az.ess(arr, method="tail"))
    except TypeError:
        # Candidate ``prob`` values, tried in order until one is accepted.
        prob_candidates = ((0.025, 0.975), 0.05)
        for prob in prob_candidates:
            try:
                tail = az.ess(arr, method="tail", prob=prob)
                return float(tail)
            except Exception:
                pass
        return float(az.ess(arr))  # last resort: bulk ESS
134
+
135
+
136
def sanitize_inference_data(idata):
    """Rewrite ``idata`` in place so its attrs and data variables can go to NetCDF.

    Top-level attrs and the attrs of every group are scanned; any value that is
    not one of str/bytes/int/float/list/tuple (plus numpy arrays and numpy
    numbers when numpy imports) is replaced by its JSON encoding, falling back
    to ``str(value)`` if encoding fails. This covers, for example, the
    dict-valued ``sample_stats`` attr stored by nutpie. Object-dtype data
    variables in each group are additionally cast to float64 where possible.

    Values that already have an accepted type are left untouched, so calling
    this again on a clean object changes nothing. Returns the same ``idata``.
    """
    import json

    writable = (str, bytes, int, float, list, tuple)
    try:
        import numpy as np

        writable = writable + (np.ndarray, np.number)
    except Exception:
        # numpy unavailable: keep the builtin-only whitelist.
        pass

    def _jsonify_attrs(attrs):
        """Replace non-whitelisted attr values with a string encoding, in place."""
        if not isinstance(attrs, dict):
            return
        # Iterate over a snapshot because entries are reassigned in the loop.
        for key, value in list(attrs.items()):
            if isinstance(value, writable):
                continue
            try:
                attrs[key] = json.dumps(value, default=str)
            except Exception:
                attrs[key] = str(value)

    def _floatify_object_vars(node):
        """Cast object-dtype data variables of ``node`` to float64, in place.

        PyMC's SMC sample_stats (beta, accept_rate, log_marginal_likelihood)
        arrive as object arrays of mixed Python float/int, which NetCDF cannot
        write. Variables whose cast raises ValueError/TypeError/KeyError are
        left unchanged.
        """
        variables = getattr(node, "data_vars", None)
        if variables is None:
            return
        for var_name in list(variables):
            try:
                if node[var_name].dtype != object:
                    continue
                node[var_name] = node[var_name].astype("float64")
            except (ValueError, TypeError, KeyError):
                pass

    _jsonify_attrs(group_attrs(idata, None))
    for label in group_names(idata):
        _jsonify_attrs(group_attrs(idata, label))
        _floatify_object_vars(get_group(idata, label))
    return idata