liesel-gam 0.0.4__py3-none-any.whl → 0.0.6a4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- liesel_gam/__about__.py +1 -1
- liesel_gam/__init__.py +38 -1
- liesel_gam/builder/__init__.py +8 -0
- liesel_gam/builder/builder.py +2003 -0
- liesel_gam/builder/category_mapping.py +158 -0
- liesel_gam/builder/consolidate_bases.py +105 -0
- liesel_gam/builder/registry.py +561 -0
- liesel_gam/constraint.py +107 -0
- liesel_gam/dist.py +541 -1
- liesel_gam/kernel.py +18 -7
- liesel_gam/plots.py +946 -0
- liesel_gam/predictor.py +59 -20
- liesel_gam/var.py +1508 -126
- liesel_gam-0.0.6a4.dist-info/METADATA +559 -0
- liesel_gam-0.0.6a4.dist-info/RECORD +18 -0
- {liesel_gam-0.0.4.dist-info → liesel_gam-0.0.6a4.dist-info}/WHEEL +1 -1
- liesel_gam-0.0.4.dist-info/METADATA +0 -160
- liesel_gam-0.0.4.dist-info/RECORD +0 -11
- {liesel_gam-0.0.4.dist-info → liesel_gam-0.0.6a4.dist-info}/licenses/LICENSE +0 -0
liesel_gam/plots.py
ADDED
|
@@ -0,0 +1,946 @@
|
|
|
1
|
+
from collections.abc import Mapping, Sequence
|
|
2
|
+
from typing import Any, Literal
|
|
3
|
+
|
|
4
|
+
import jax
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
import liesel.goose as gs
|
|
7
|
+
import liesel.model as lsl
|
|
8
|
+
import numpy as np
|
|
9
|
+
import pandas as pd
|
|
10
|
+
import plotnine as p9
|
|
11
|
+
from jax import Array
|
|
12
|
+
from jax.typing import ArrayLike
|
|
13
|
+
|
|
14
|
+
from .builder.registry import CategoryMapping
|
|
15
|
+
from .var import LinTerm, MRFTerm, RITerm, Term, TPTerm
|
|
16
|
+
|
|
17
|
+
KeyArray = Any
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def summarise_by_samples(
    key: KeyArray, a: Any, name: str, n: int = 100
) -> pd.DataFrame:
    """
    Draw ``n`` random posterior draws (with replacement) from the array ``a``
    of shape ``(chains, iterations, observations)`` and arrange them in a
    long-format data frame with the columns:

    - ``name``: the drawn values
    - index: index of the flattened array
    - sample: sample number
    - obs: observation number (enumerates response values)
    - chain: chain number
    """
    iterations = a.shape[1]

    # Stack chains on top of each other: (chains*iterations, observations).
    flat = np.concatenate(a, axis=0)
    n_obs = flat.shape[-1]

    drawn = jax.random.choice(key, flat.shape[0], shape=(n,), replace=True)

    df = pd.DataFrame(
        {
            name: flat[drawn, :].ravel(),
            "sample": np.repeat(np.arange(n), n_obs),
            "index": np.repeat(drawn, n_obs),
            "obs": np.tile(np.arange(n_obs), n),
        }
    )

    # Integer division recovers the chain a flattened row belongs to.
    df["chain"] = df["index"] // iterations

    return df
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def summarise_1d_smooth(
    term: Term,
    samples: dict[str, Array],
    newdata: gs.Position | None | Mapping[str, ArrayLike] = None,
    quantiles: Sequence[float] = (0.05, 0.5, 0.95),
    hdi_prob: float = 0.9,
    ngrid: int = 150,
):
    """
    Summarise the posterior of a univariate smooth term on a covariate grid.

    When ``newdata`` is ``None``, an evenly spaced grid of ``ngrid`` points
    over the observed range of the term's covariate is used; otherwise the
    covariate values are taken from ``newdata``. Returns the summary data
    frame produced by ``gs.SamplesSummary`` with the covariate values
    attached as an extra column.
    """
    x_var = term.basis.x

    if newdata is None:
        # TODO: Currently, this branch of the function assumes that
        # term.basis.x is a strong node. That is not necessarily always
        # the case.
        grid = np.linspace(x_var.value.min(), x_var.value.max(), ngrid)
        prediction_data: Mapping[str, ArrayLike] = {x_var.name: grid}
    else:
        prediction_data = newdata
        grid = np.asarray(newdata[x_var.name])

    prediction_data = {key: jnp.asarray(val) for key, val in prediction_data.items()}

    draws = term.predict(samples, newdata=prediction_data)
    summary = gs.SamplesSummary.from_array(
        draws, name=term.name, quantiles=quantiles, hdi_prob=hdi_prob
    )
    out = summary.to_dataframe().reset_index()

    out[x_var.name] = grid
    return out
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def plot_1d_smooth(
    term: Term,
    samples: dict[str, Array],
    newdata: gs.Position | None | Mapping[str, ArrayLike] = None,
    ci_quantiles: tuple[float, float] | None = (0.05, 0.95),
    hdi_prob: float | None = None,
    show_n_samples: int | None = 50,
    seed: int | KeyArray = 1,
    ngrid: int = 150,
):
    """
    Plot the posterior of a univariate smooth term.

    Shows the posterior mean line, an optional quantile ribbon
    (``ci_quantiles``), optional HDI bounds (``hdi_prob``), and optionally
    ``show_n_samples`` individual posterior draws as thin grey lines.
    ``seed`` controls which draws are shown.
    """
    if newdata is None:
        # TODO: Currently, this branch of the function assumes that term.basis.x is
        # a strong node.
        # That is not necessarily always the case.
        # Fix: use ``ngrid`` (was a hard-coded 150), so the grid used for the
        # individual sample lines matches the grid used by summarise_1d_smooth.
        xgrid = np.linspace(term.basis.x.value.min(), term.basis.x.value.max(), ngrid)
        newdata_x: Mapping[str, ArrayLike] = {term.basis.x.name: xgrid}
    else:
        newdata_x = newdata
        xgrid = np.asarray(newdata[term.basis.x.name])

    newdata_x = {k: jnp.asarray(v) for k, v in newdata_x.items()}

    # Draws on the grid, used for the "spaghetti" sample lines below.
    term_samples = term.predict(samples, newdata=newdata_x)

    term_summary = summarise_1d_smooth(
        term=term,
        samples=samples,
        newdata=newdata,
        quantiles=(0.05, 0.95) if ci_quantiles is None else ci_quantiles,
        hdi_prob=0.9 if hdi_prob is None else hdi_prob,
        ngrid=ngrid,
    )

    p = p9.ggplot(term_summary) + p9.labs(
        title=f"Posterior summary of {term.name}",
        x=term.basis.x.name,
        y=term.name,
    )

    if ci_quantiles is not None:
        p = p + p9.geom_ribbon(
            p9.aes(
                term.basis.x.name,
                ymin=f"q_{str(ci_quantiles[0])}",
                ymax=f"q_{str(ci_quantiles[1])}",
            ),
            fill="#56B4E9",
            alpha=0.5,
            data=term_summary,
        )

    if hdi_prob is not None:
        # Lower and upper HDI bounds as dashed lines.
        p = p + p9.geom_line(
            p9.aes(term.basis.x.name, "hdi_low"),
            linetype="dashed",
            data=term_summary,
        )

        p = p + p9.geom_line(
            p9.aes(term.basis.x.name, "hdi_high"),
            linetype="dashed",
            data=term_summary,
        )

    if show_n_samples is not None and show_n_samples > 0:
        key = jax.random.key(seed) if isinstance(seed, int) else seed

        summary_samples_df = summarise_by_samples(
            key=key, a=term_samples, name=term.name, n=show_n_samples
        )

        # Attach the covariate grid; each drawn sample repeats the full grid.
        summary_samples_df[term.basis.x.name] = np.tile(
            np.squeeze(xgrid), show_n_samples
        )

        p = p + p9.geom_line(
            p9.aes(term.basis.x.name, term.name, group="sample"),
            color="grey",
            data=summary_samples_df,
            alpha=0.3,
        )

    # Posterior mean on top of everything else.
    p = p + p9.geom_line(
        p9.aes(term.basis.x.name, "mean"), data=term_summary, size=1.3, color="blue"
    )

    return p
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def grid_nd(inputs: dict[str, jax.typing.ArrayLike], ngrid: int) -> dict[str, Any]:
    """
    Build a flattened meshgrid covering the range of every input.

    For each key an evenly spaced grid of ``ngrid`` points between the
    observed minimum and maximum is created; the per-key grids are crossed
    with :func:`numpy.meshgrid` and flattened, so the returned arrays
    jointly enumerate the full Cartesian product of grid points.
    """
    axes = {
        name: np.linspace(jnp.min(values), jnp.max(values), ngrid)
        for name, values in inputs.items()
    }
    mesh = np.meshgrid(*axes.values())
    return {name: coords.flatten() for name, coords in zip(axes, mesh)}
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def input_grid_nd_smooth(
    term: TPTerm | Term | LinTerm, ngrid: int
) -> dict[str, jax.typing.ArrayLike]:
    """
    Derive a flattened evaluation grid from a term's observed inputs.

    For a :class:`TPTerm`, the grid is built from its observed marginal
    inputs. Otherwise the basis input must be a (transient) calculator whose
    input variables supply the observed values.

    Raises
    ------
    NotImplementedError
        If the basis input is neither a ``TransientCalc`` nor a ``Calc``.
    """
    if isinstance(term, TPTerm):
        observed = {name: var.value for name, var in term.input_obs.items()}
        return grid_nd(observed, ngrid)

    basis_input = term.basis.x
    if not isinstance(basis_input, lsl.TransientCalc | lsl.Calc):
        raise NotImplementedError(
            "Function not implemented for bases with inputs of "
            f"type {type(basis_input)}."
        )
    observed = {node.var.name: node.var.value for node in basis_input.all_input_nodes()}  # type: ignore
    return grid_nd(observed, ngrid)
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
# Names of the posterior-summary columns that the plotting functions can
# display via their ``which`` argument.
# using q_0.05 and q_0.95 explicitly here
# even though users could choose to return other quantiles like 0.1 and 0.9
# then they can supply q_0.1 and q_0.9, etc.
PlotVars = Literal[
    "mean", "sd", "var", "hdi_low", "hdi_high", "q_0.05", "q_0.5", "q_0.95"
]
|
|
203
|
+
|
|
204
|
+
|
|
205
|
+
def summarise_nd_smooth(
    term: Term | TPTerm,
    samples: Mapping[str, jax.Array],
    newdata: gs.Position | None | Mapping[str, ArrayLike] = None,
    ngrid: int = 20,
    which: PlotVars | Sequence[PlotVars] = "mean",
    quantiles: Sequence[float] = (0.05, 0.5, 0.95),
    hdi_prob: float = 0.9,
    newdata_meshgrid: bool = False,
):
    """
    Summarise the posterior of a multi-dimensional smooth term on a grid,
    returning a long-format data frame with the columns ``index``, one
    column per covariate, ``variable`` (one of ``which``) and ``value``.

    If ``newdata`` is ``None``, a grid is derived from the term's observed
    inputs; if ``newdata_meshgrid`` is true, the supplied per-covariate
    values are crossed into a full Cartesian grid first.
    """
    # Normalize ``which`` to a list so it can serve as melt value_vars and
    # categorical levels below.
    if isinstance(which, str):
        which = [which]

    if newdata is None:
        newdata_x: Mapping[str, ArrayLike] = input_grid_nd_smooth(term, ngrid=ngrid)
    elif newdata_meshgrid:
        # Cross the supplied 1d arrays into a full grid and flatten.
        full_grid_arrays = [v.flatten() for v in np.meshgrid(*newdata.values())]
        newdata_x = dict(zip(newdata.keys(), full_grid_arrays))
    else:
        newdata_x = newdata

    newdata_x = {k: jnp.asarray(v) for k, v in newdata_x.items()}

    term_samples = term.predict(samples, newdata=newdata_x)

    # Defensive defaults in case callers explicitly pass None despite the
    # non-optional annotations.
    ci_quantiles_ = (0.05, 0.95) if quantiles is None else quantiles
    hdi_prob_ = 0.9 if hdi_prob is None else hdi_prob
    term_summary = (
        gs.SamplesSummary.from_array(
            term_samples, name=term.name, quantiles=ci_quantiles_, hdi_prob=hdi_prob_
        )
        .to_dataframe()
        .reset_index()
    )

    # Attach the covariate grid values as plain columns.
    for k, v in newdata_x.items():
        term_summary[k] = np.asarray(v)

    # NOTE(review): a second reset_index follows the one above; presumably
    # it creates the ``index`` column used as an id_var below — confirm
    # against gs.SamplesSummary.to_dataframe()'s index layout.
    term_summary.reset_index(inplace=True)
    term_summary = term_summary.melt(
        id_vars=["index"] + list(newdata_x.keys()),
        value_vars=which,
        var_name="variable",
        value_name="value",
    )

    # Fix the facet/legend ordering to the order given in ``which``.
    term_summary["variable"] = pd.Categorical(
        term_summary["variable"], categories=which
    )
    return term_summary
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def plot_2d_smooth(
    term: TPTerm | Term,
    samples: Mapping[str, jax.Array],
    newdata: gs.Position | None | Mapping[str, ArrayLike] = None,
    ngrid: int = 20,
    which: PlotVars | Sequence[PlotVars] = "mean",
    quantiles: Sequence[float] = (0.05, 0.5, 0.95),
    hdi_prob: float = 0.9,
    newdata_meshgrid: bool = False,
):
    """
    Plot a bivariate smooth term as a tile (heatmap) plot, one facet per
    entry of ``which``.

    Raises
    ------
    ValueError
        If a :class:`TPTerm` does not have exactly two inputs.
    TypeError
        If a marginal input has an integer dtype.
    """
    if isinstance(term, TPTerm):
        names = list(term.input_obs)
        if len(names) != 2:
            raise ValueError(
                f"'plot_2d_smooth' can only handle smooths with two inputs, "
                f"got {len(names)} for smooth {term}: {names}"
            )

        for v in term.input_obs.values():
            # Fix: issubdtype expects a dtype, not an array; pass
            # ``v.value.dtype`` (consistent with the dtype checks elsewhere
            # in this module).
            if jnp.issubdtype(v.value.dtype, jnp.integer):
                raise TypeError(
                    "'plot_2d_smooth' expects continuous marginals, got "
                    f"type {v.value.dtype} for {v}"
                )
    else:
        names = [n.var.name for n in term.basis.x.all_input_nodes()]  # type: ignore

    term_summary = summarise_nd_smooth(
        term=term,
        samples=samples,
        newdata=newdata,
        ngrid=ngrid,
        which=which,
        quantiles=quantiles,
        hdi_prob=hdi_prob,
        newdata_meshgrid=newdata_meshgrid,
    )

    p = (
        p9.ggplot(term_summary)
        + p9.labs(title=f"Posterior summary of {term.name}")
        + p9.aes(*names, fill="value")
        + p9.facet_wrap("~variable", labeller="label_both")
    )

    p = p + p9.geom_tile()

    return p
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
def polys_to_df(polys: Mapping[str, ArrayLike]):
    """
    Stack polygon vertex coordinates into one long data frame.

    Columns: ``V0`` .. ``V{d-1}`` (vertex coordinates, with ``d`` taken
    from the first polygon), ``vertex`` (1-based vertex number within a
    polygon), ``id`` (polygon position) and ``label`` (polygon key).
    """
    ndim = np.shape(next(iter(polys.values())))[-1]
    coord_cols = [f"V{d}" for d in range(ndim)]

    frames = []
    for poly_id, (label, coords) in enumerate(polys.items()):
        frame = pd.DataFrame(coords, columns=coord_cols)
        frame["vertex"] = frame.index + 1
        frame["id"] = poly_id
        frame["label"] = label
        frames.append(frame)

    return pd.concat(frames, ignore_index=True)
|
|
321
|
+
|
|
322
|
+
|
|
323
|
+
def plot_polys(
    region: str,
    which: str | Sequence[str],
    df: pd.DataFrame,
    polys: Mapping[str, ArrayLike],
    show_unobserved: bool = True,
    observed_color: str = "none",
    unobserved_color: str = "red",
) -> p9.ggplot:
    """
    Plot polygon regions filled by summary values from ``df``.

    ``df`` must contain the column ``region`` (matched against the polygon
    labels in ``polys``) and the columns named in ``which``. Regions marked
    unobserved (via an ``observed`` column) are outlined in
    ``unobserved_color`` and optionally hidden.
    """
    if isinstance(which, str):
        which = [which]

    poly_df = polys_to_df(polys)

    # Fix: work on a copy so the helper columns added below do not mutate
    # the caller's data frame.
    df = df.copy()
    df["label"] = df[region].astype(str)

    if "observed" not in df.columns:
        df["observed"] = True

    # Nothing unobserved to show, so skip the unobserved layer entirely.
    if df["observed"].all():
        show_unobserved = False

    plot_df = poly_df.merge(df, on="label")

    plot_df = plot_df.melt(
        id_vars=["label", "V0", "V1", "observed"],
        value_vars=which,
        var_name="variable",
        value_name="value",
    )

    plot_df["variable"] = pd.Categorical(plot_df["variable"], categories=which)

    p = (
        p9.ggplot(plot_df)
        + p9.aes("V0", "V1", group="label", fill="value")
        + p9.aes(color="observed")
        + p9.facet_wrap("~variable", labeller="label_both")
        + p9.scale_color_manual({True: observed_color, False: unobserved_color})
        + p9.guides(color=p9.guide_legend(override_aes={"fill": None}))
    )
    if show_unobserved:
        p = p + p9.geom_polygon()
    else:
        p = p + p9.geom_polygon(data=plot_df.query("observed == True"))
        p = p + p9.geom_polygon(data=plot_df.query("observed == False"), fill="none")

    return p
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
def _convert_to_integers(
    grid: np.typing.NDArray,
    labels: Sequence[str] | CategoryMapping | None,
    term: RITerm | MRFTerm | lsl.Var,
) -> np.typing.NDArray[np.int_]:
    """
    Convert cluster/region values to integer codes.

    Uses ``labels`` when it is a :class:`CategoryMapping`; otherwise tries
    the term's own mapping. If neither is available, ``grid`` must already
    consist of integer codes.

    Raises
    ------
    TypeError
        If no mapping is available and ``grid`` is not of integer dtype.
    """
    if isinstance(labels, CategoryMapping):
        grid = labels.to_integers(grid)
    else:
        try:
            grid = term.mapping.to_integers(grid)  # type: ignore
        except (ValueError, AttributeError):
            if not np.issubdtype(grid.dtype, np.integer):
                # ``from None``: the missing mapping is the expected
                # fallback path, so don't chain the lookup failure into a
                # confusing "during handling of the above exception" trace.
                raise TypeError(
                    f"There's no mapping available on the term {term}. "
                    "In this case, its values in 'newdata' must be specified "
                    f"as integer codes. Got data type {grid.dtype}"
                ) from None

    return grid
|
|
393
|
+
|
|
394
|
+
|
|
395
|
+
def summarise_cluster(
    term: RITerm | MRFTerm | Term,
    samples: Mapping[str, jax.Array],
    newdata: gs.Position
    | None
    | Mapping[str, ArrayLike | Sequence[int] | Sequence[str]] = None,
    labels: CategoryMapping | Sequence[str] | None = None,
    quantiles: Sequence[float] = (0.05, 0.5, 0.95),
    hdi_prob: float = 0.9,
) -> pd.DataFrame:
    """
    Summarise the posterior of a clustered term (random intercept / MRF)
    per cluster level.

    Returns a data frame with one row per cluster level, summary columns
    from ``gs.SamplesSummary``, a categorical column named after the
    term's covariate, and a boolean ``observed`` column marking levels
    that occur in the observed data.
    """
    # Fall back to the term's own category mapping if no labels are given.
    if labels is None:
        try:
            labels = term.mapping  # type: ignore
        except (AttributeError, ValueError):
            labels = None

    if newdata is None and isinstance(labels, CategoryMapping):
        # All known integer codes from the mapping, including unobserved ones.
        grid = np.asarray(list(labels.integers_to_labels_map))
        unique_x = np.unique(term.basis.x.value)
        newdata_x: Mapping[str, ArrayLike] = {term.basis.x.name: grid}
        observed = [x in unique_x for x in grid]
    elif newdata is None:
        # No mapping available: only the observed codes can be summarised.
        grid = np.unique(term.basis.x.value)
        newdata_x = {term.basis.x.name: grid}
        observed = [True for _ in grid]
    else:
        unique_x = np.unique(term.basis.x.value)
        grid = np.asarray(newdata[term.basis.x.name])
        grid = _convert_to_integers(grid, labels, term)

        observed = [x in unique_x for x in grid]
        newdata_x = {term.basis.x.name: grid}

    newdata_x = {k: jnp.asarray(v) for k, v in newdata_x.items()}
    predictions = term.predict(samples=samples, newdata=newdata_x)
    predictions_summary = (
        gs.SamplesSummary.from_array(
            predictions,
            quantiles=quantiles,
            hdi_prob=0.9 if hdi_prob is None else hdi_prob,
        )
        .to_dataframe()
        .reset_index()
    )

    if isinstance(labels, CategoryMapping):
        # Translate integer codes back to string labels, keeping the
        # mapping's label order as the categorical level order.
        codes = newdata_x[term.basis.x.name]
        labels_str = list(labels.integers_to_labels(codes))
        categories = list(labels.labels_to_integers_map)
        predictions_summary[term.basis.x.name] = pd.Categorical(
            labels_str, categories=categories
        )
    elif labels is not None:
        # Plain label sequence; assumed to align one-to-one with the rows.
        labels_str = list(labels)
        categories = sorted(set(labels_str))
        predictions_summary[term.basis.x.name] = pd.Categorical(
            labels_str, categories=categories
        )
    else:
        # NOTE(review): this fallback uses the full observed covariate
        # values, not the grid — presumably only reached when newdata is
        # None and lengths coincide; confirm for the newdata-given case.
        predictions_summary[term.basis.x.name] = pd.Categorical(
            np.asarray(term.basis.x.value)
        )

    predictions_summary["observed"] = observed

    return predictions_summary
|
|
461
|
+
|
|
462
|
+
|
|
463
|
+
def summarise_regions(
    term: RITerm | MRFTerm | Term,
    samples: Mapping[str, jax.Array],
    newdata: gs.Position | None | Mapping[str, ArrayLike] = None,
    which: PlotVars | Sequence[PlotVars] = "mean",
    polys: Mapping[str, ArrayLike] | None = None,
    labels: CategoryMapping | Sequence[str] | None = None,
    quantiles: Sequence[float] = (0.05, 0.5, 0.95),
    hdi_prob: float = 0.9,
) -> pd.DataFrame:
    """
    Combine a per-cluster posterior summary with polygon coordinates into a
    long-format data frame suitable for region (map) plots.

    Polygons are taken from ``polys`` or, failing that, from the term's own
    ``polygons`` attribute.

    Raises
    ------
    ValueError
        If no polygons are available, or if a polygon label has no matching
        row in the cluster summary.
    """
    polygons = None
    if polys is not None:
        polygons = polys
    else:
        try:
            # using type ignore here, since the case of term not having the attribute
            # polygons is handled by the try/except
            polygons = term.polygons  # type: ignore
        except AttributeError:
            pass

    if not polygons:
        raise ValueError(
            "When passing a term without polygons, polygons must be supplied manually "
            "through the argument 'polys'"
        )

    df = summarise_cluster(
        term=term,
        samples=samples,
        newdata=newdata,
        labels=labels,
        quantiles=quantiles,
        hdi_prob=hdi_prob,
    )
    region = term.basis.x.name
    if isinstance(which, str):
        which = [which]

    # Every polygon label must have a matching summary row, otherwise the
    # merge below would silently drop it.
    unique_labels_in_df = df[term.basis.x.name].unique().tolist()
    assert polygons is not None
    for region_label in polygons:
        if region_label not in unique_labels_in_df:
            raise ValueError(
                f"Label '{region_label}' found in polys, but not in cluster summary. "
                f"Known labels: {unique_labels_in_df}"
            )

    poly_df = polys_to_df(polygons)

    # String key used to join summary rows onto polygon vertices.
    df["label"] = df[region].astype(str)

    if "observed" not in df.columns:
        df["observed"] = True

    plot_df = poly_df.merge(df, on="label")

    plot_df = plot_df.melt(
        id_vars=["label", "V0", "V1", "observed"],
        value_vars=which,
        var_name="variable",
        value_name="value",
    )

    # Fix facet ordering to the order given in ``which``.
    plot_df["variable"] = pd.Categorical(plot_df["variable"], categories=which)

    return plot_df
|
|
530
|
+
|
|
531
|
+
|
|
532
|
+
def plot_regions(
    term: RITerm | MRFTerm | Term,
    samples: Mapping[str, jax.Array],
    newdata: gs.Position | None | Mapping[str, ArrayLike] = None,
    which: PlotVars | Sequence[PlotVars] = "mean",
    polys: Mapping[str, ArrayLike] | None = None,
    labels: CategoryMapping | None = None,
    quantiles: Sequence[float] = (0.05, 0.5, 0.95),
    hdi_prob: float = 0.9,
    show_unobserved: bool = True,
    observed_color: str = "none",
    unobserved_color: str = "red",
) -> p9.ggplot:
    """
    Plot the posterior summary of a clustered term as a region (map) plot,
    with one facet per entry of ``which`` and unobserved regions outlined
    in ``unobserved_color``.
    """
    summary_df = summarise_regions(
        term=term,
        samples=samples,
        newdata=newdata,
        which=which,
        polys=polys,
        labels=labels,
        quantiles=quantiles,
        hdi_prob=hdi_prob,
    )

    p = p9.ggplot(summary_df)
    p = p + p9.aes("V0", "V1", group="label", fill="value")
    p = p + p9.aes(color="observed")
    p = p + p9.facet_wrap("~variable", labeller="label_both")
    p = p + p9.scale_color_manual({True: observed_color, False: unobserved_color})
    p = p + p9.guides(color=p9.guide_legend(override_aes={"fill": None}))

    if show_unobserved:
        p = p + p9.geom_polygon()
    else:
        observed_rows = summary_df.query("observed == True")
        unobserved_rows = summary_df.query("observed == False")
        p = p + p9.geom_polygon(data=observed_rows)
        p = p + p9.geom_polygon(data=unobserved_rows, fill="none")

    p += p9.labs(title=f"Plot of {term.name}")
    return p
|
|
571
|
+
|
|
572
|
+
|
|
573
|
+
def plot_forest(
    term: RITerm | MRFTerm | LinTerm,
    samples: Mapping[str, jax.Array],
    newdata: gs.Position | None | Mapping[str, ArrayLike] = None,
    labels: CategoryMapping | None = None,
    ymin: str = "hdi_low",
    ymax: str = "hdi_high",
    ci_quantiles: tuple[float, float] = (0.05, 0.95),
    hdi_prob: float = 0.9,
    show_unobserved: bool = True,
    highlight_unobserved: bool = True,
    unobserved_color: str = "red",
    indices: Sequence[int] | None = None,
) -> p9.ggplot:
    """
    Forest plot for a clustered or linear term.

    Dispatches to :func:`plot_forest_clustered` for ``RITerm``/``MRFTerm``
    and to :func:`plot_forest_lin` for ``LinTerm``.

    Raises
    ------
    TypeError
        If ``term`` is of any other type.
    """
    if not isinstance(term, RITerm | MRFTerm | LinTerm):
        raise TypeError(f"term has unsupported type {type(term)}.")

    if isinstance(term, RITerm | MRFTerm):
        return plot_forest_clustered(
            term=term,
            samples=samples,
            newdata=newdata,
            labels=labels,
            ymin=ymin,
            ymax=ymax,
            ci_quantiles=ci_quantiles,
            hdi_prob=hdi_prob,
            show_unobserved=show_unobserved,
            highlight_unobserved=highlight_unobserved,
            unobserved_color=unobserved_color,
            indices=indices,
        )

    return plot_forest_lin(
        term=term,
        samples=samples,
        ymin=ymin,
        ymax=ymax,
        ci_quantiles=ci_quantiles,
        hdi_prob=hdi_prob,
        indices=indices,
    )
|
|
614
|
+
|
|
615
|
+
|
|
616
|
+
def summarise_lin(
    term: LinTerm,
    samples: Mapping[str, jax.Array],
    quantiles: Sequence[float] = (0.05, 0.5, 0.95),
    hdi_prob: float = 0.9,
    indices: Sequence[int] | None = None,
) -> pd.DataFrame:
    """
    Summarise the posterior of a linear term's coefficients.

    Returns one row per (selected) coefficient with the column names of
    the design matrix in a leading ``x`` column. ``indices`` optionally
    restricts the summary to a subset of coefficients.
    """
    all_draws = samples[term.coef.name]
    if indices is None:
        coef_draws = all_draws
        names = term.column_names
    else:
        coef_draws = all_draws[..., indices]
        names = [term.column_names[i] for i in indices]

    summary = gs.SamplesSummary.from_array(
        coef_draws, quantiles=quantiles, hdi_prob=hdi_prob
    )
    df = summary.to_dataframe().reset_index()

    # Drop the generic identifier columns and put the coefficient names
    # first.
    df = df.drop(columns=["variable", "var_fqn", "var_index"])
    df.insert(0, "x", names)
    return df
|
|
642
|
+
|
|
643
|
+
|
|
644
|
+
def plot_forest_lin(
    term: LinTerm,
    samples: Mapping[str, jax.Array],
    ymin: str = "hdi_low",
    ymax: str = "hdi_high",
    ci_quantiles: tuple[float, float] = (0.05, 0.95),
    hdi_prob: float = 0.9,
    indices: Sequence[int] | None = None,
) -> p9.ggplot:
    """
    Forest plot of a linear term's coefficients: posterior means as points
    with ``ymin``/``ymax`` interval bars, flipped so coefficients are
    listed vertically.
    """
    df = summarise_lin(
        term=term,
        samples=samples,
        quantiles=ci_quantiles,
        hdi_prob=hdi_prob,
        indices=indices,
    )

    # Keep the interval bounds on the same dtype as the mean so plotnine
    # treats them on a common numeric scale.
    mean_dtype = df["mean"].dtype
    df[ymin] = df[ymin].astype(mean_dtype)
    df[ymax] = df[ymax].astype(mean_dtype)

    p = p9.ggplot(df) + p9.aes("x", "mean")
    p = p + p9.geom_hline(yintercept=0, color="grey")
    p = p + p9.geom_linerange(p9.aes(ymin=ymin, ymax=ymax), color="grey")
    p = p + p9.geom_point()
    p = p + p9.coord_flip()
    p = p + p9.labs(x="x")
    p += p9.labs(title=f"Posterior summary of {term.name}")

    return p
|
|
677
|
+
|
|
678
|
+
|
|
679
|
+
def plot_forest_clustered(
    term: RITerm | MRFTerm | Term,
    samples: Mapping[str, jax.Array],
    newdata: gs.Position | None | Mapping[str, ArrayLike] = None,
    labels: CategoryMapping | None = None,
    ymin: str = "hdi_low",
    ymax: str = "hdi_high",
    ci_quantiles: tuple[float, float] = (0.05, 0.95),
    hdi_prob: float = 0.9,
    show_unobserved: bool = True,
    highlight_unobserved: bool = True,
    unobserved_color: str = "red",
    indices: Sequence[int] | None = None,
) -> p9.ggplot:
    """
    Forest plot of a clustered term's per-level effects: posterior means
    with ``ymin``/``ymax`` interval bars, one row per cluster level.
    Unobserved levels can be hidden or highlighted with ``unobserved_color``.
    """
    if labels is None:
        try:
            labels = term.mapping  # type: ignore
        # Consistency fix: catch ValueError as well, matching the mapping
        # lookups in summarise_cluster and summarise_1d_smooth_clustered.
        except (AttributeError, ValueError):
            labels = None

    df = summarise_cluster(
        term=term,
        samples=samples,
        newdata=newdata,
        labels=labels,
        quantiles=ci_quantiles,
        hdi_prob=hdi_prob,
    )
    cluster = term.basis.x.name

    if labels is None:
        xlab = cluster + " (indices)"
    else:
        xlab = cluster + " (labels)"

    # Keep interval bounds on the same dtype as the mean for a common
    # numeric scale.
    df[ymin] = df[ymin].astype(df["mean"].dtype)
    df[ymax] = df[ymax].astype(df["mean"].dtype)

    if indices is not None:
        df = df.iloc[indices, :]

    if not show_unobserved:
        df = df.query("observed == True")

    p = (
        p9.ggplot(df)
        + p9.aes(cluster, "mean")
        + p9.geom_hline(yintercept=0, color="grey")
        + p9.geom_linerange(p9.aes(ymin=ymin, ymax=ymax), color="grey")
        + p9.geom_point()
        + p9.coord_flip()
        + p9.labs(x=xlab)
    )

    if highlight_unobserved:
        # Overplot unobserved levels with a colored "x" marker.
        df_uo = df.query("observed == False")
        p = p + p9.geom_point(
            p9.aes(cluster, "mean"),
            color=unobserved_color,
            shape="x",
            data=df_uo,
        )

    p += p9.labs(title=f"Posterior summary of {term.name}")

    return p
|
|
745
|
+
|
|
746
|
+
|
|
747
|
+
def summarise_1d_smooth_clustered(
    clustered_term: lsl.Var,
    samples: Mapping[str, jax.Array],
    ngrid: int = 20,
    newdata: gs.Position
    | None
    | Mapping[str, ArrayLike | Sequence[int] | Sequence[str]] = None,
    which: PlotVars | Sequence[PlotVars] = "mean",
    ci_quantiles: Sequence[float] = (0.05, 0.5, 0.95),
    hdi_prob: float = 0.9,
    labels: CategoryMapping | None | Sequence[str] = None,
    newdata_meshgrid: bool = False,
):
    """
    Summarise the posterior of a smooth term that varies by cluster.

    ``clustered_term`` is expected to wrap a smooth/covariate input ``"x"``
    and a cluster term ``"cluster"`` in its value node. A joint grid over
    the covariate and the cluster levels is built (or taken from
    ``newdata``), predictions are summarised, and an ``observed`` column
    marks cluster levels present in the data.
    """
    if isinstance(which, str):
        which = [which]

    # The two components of the clustered term, by value-node key.
    term = clustered_term.value_node["x"]
    cluster = clustered_term.value_node["cluster"]

    assert isinstance(term, Term | lsl.Var)
    assert isinstance(cluster, RITerm | MRFTerm)

    # Fall back to the cluster term's own category mapping.
    if labels is None:
        try:
            labels = cluster.mapping  # type: ignore
        except (AttributeError, ValueError):
            labels = None

    # The covariate variable: the basis input for a Term, or the plain Var.
    if isinstance(term, Term):
        x = term.basis.x
    else:
        x = term

    if newdata is None:
        # A grid over min/max is only meaningful for continuous covariates.
        if not jnp.issubdtype(x.value.dtype, jnp.floating):
            raise TypeError(
                "Automatic grid creation is valid only for continuous x, got "
                f"dtype {jnp.dtype(x.value)} for {x}."
            )

    if newdata is None and isinstance(labels, CategoryMapping):
        cgrid = np.asarray(list(labels.integers_to_labels_map))  # integer codes
        unique_clusters = np.unique(cluster.basis.x.value)  # unique codes

        if isinstance(x, lsl.Node) or x.strong:
            xgrid: Mapping[str, ArrayLike] = {
                x.name: jnp.linspace(x.value.min(), x.value.max(), ngrid)
            }
        else:
            assert isinstance(term, Term | LinTerm), (
                f"Wrong type for term: {type(term)}"
            )
            # Multi-input basis: shrink the per-axis grid so the total grid
            # size stays near ``ngrid``.
            # NOTE(review): np.pow is the numpy>=2.0 alias of np.power —
            # confirm the package's numpy requirement.
            ncols = jnp.shape(term.basis.value)[-1]
            xgrid = input_grid_nd_smooth(term, ngrid=int(np.pow(ngrid, 1 / ncols)))

        grid: Mapping[str, ArrayLike | Sequence[int] | Sequence[str]] = dict(xgrid) | {
            cluster.basis.x.name: cgrid
        }

        # code : bool
        observed = {x: x in unique_clusters for x in cgrid}
    elif newdata is None:
        # No mapping: only observed cluster codes are available.
        cgrid = np.unique(cluster.basis.x.value)
        if isinstance(x, lsl.Node) or x.strong:
            xgrid = {x.name: jnp.linspace(x.value.min(), x.value.max(), ngrid)}
        else:
            assert isinstance(term, Term | LinTerm), (
                f"Wrong type for term: {type(term)}"
            )
            ncols = jnp.shape(term.basis.value)[-1]
            xgrid = input_grid_nd_smooth(term, ngrid=int(np.pow(ngrid, 1 / ncols)))

        grid = xgrid | {cluster.basis.x.name: cgrid}

        # code : bool
        observed = {x: True for x in cgrid}
    else:
        pass

    if newdata is not None and newdata_meshgrid:
        cgrid = np.asarray(newdata[cluster.basis.x.name])
        cgrid = _convert_to_integers(cgrid, labels, cluster)

        # Cross the covariate values with the cluster codes.
        grid = {x.name: newdata[x.name], cluster.basis.x.name: cgrid}
        full_grid_arrays = [v.flatten() for v in np.meshgrid(*grid.values())]
        newdata_x: dict[str, ArrayLike | Sequence[int] | Sequence[str]] = dict(
            zip(grid.keys(), full_grid_arrays)
        )

        if isinstance(labels, CategoryMapping):
            observed = {x: x in cluster.basis.x.value for x in cgrid}
        else:
            observed = {x: True for x in cgrid}
    elif newdata is not None:
        # Use the supplied values as-is (no meshgrid crossing).
        cgrid = np.asarray(newdata[cluster.basis.x.name])
        cgrid = _convert_to_integers(cgrid, labels, cluster)
        newdata_x = {x.name: newdata[x.name], cluster.basis.x.name: cgrid}
        # code : bool
        if isinstance(labels, CategoryMapping):
            observed = {x: x in cluster.basis.x.value for x in cgrid}
        else:
            observed = {x: True for x in cgrid}
    else:  # then we use the grid created from observed data
        full_grid_arrays = [v.flatten() for v in np.meshgrid(*grid.values())]
        newdata_x = dict(zip(grid.keys(), full_grid_arrays))

    newdata_x = {k: jnp.asarray(v) for k, v in newdata_x.items()}

    term_samples = clustered_term.predict(samples, newdata=newdata_x)
    term_summary = (
        gs.SamplesSummary.from_array(
            term_samples,
            name=clustered_term.name,
            quantiles=ci_quantiles,
            hdi_prob=hdi_prob,
        )
        .to_dataframe()
        .reset_index()
    )

    # Attach the grid values as plain columns.
    for k, v in newdata_x.items():
        term_summary[k] = np.asarray(v)

    if labels is not None:
        if isinstance(labels, CategoryMapping):
            # Replace integer cluster codes by their string labels, keeping
            # the mapping's label order as the categorical level order.
            labels_long = labels.to_labels(newdata_x[cluster.basis.x.name])
            categories = list(labels.labels_to_integers_map)
            term_summary[cluster.basis.x.name] = pd.Categorical(
                labels_long, categories=categories
            )
        else:
            # NOTE(review): a plain label sequence is assigned directly and
            # must therefore match the number of summary rows — confirm.
            term_summary[cluster.basis.x.name] = labels

    term_summary["observed"] = [
        observed[x] for x in np.asarray(newdata_x[cluster.basis.x.name])
    ]

    term_summary.reset_index(inplace=True)

    return term_summary
|
|
887
|
+
|
|
888
|
+
|
|
889
|
+
def plot_1d_smooth_clustered(
    clustered_term: lsl.Var,
    samples: Mapping[str, jax.Array],
    ngrid: int = 20,
    newdata: gs.Position | None | Mapping[str, ArrayLike] = None,
    labels: CategoryMapping | None = None,
    color_scale: str = "viridis",
    newdata_meshgrid: bool = False,
):
    """
    Plot the posterior mean of a cluster-varying smooth term, one colored
    line per cluster level, faceted by summary variable.
    """
    ci_quantiles = (0.05, 0.5, 0.95)
    hdi_prob = 0.9

    term = clustered_term.value_node["x"]
    cluster = clustered_term.value_node["cluster"]

    assert isinstance(term, Term | lsl.Var)
    assert isinstance(cluster, RITerm | MRFTerm)

    if labels is None:
        try:
            labels = cluster.mapping  # type: ignore
        # Consistency fix: catch ValueError as well, matching the mapping
        # lookup in summarise_1d_smooth_clustered.
        except (AttributeError, ValueError):
            labels = None

    term_summary = summarise_1d_smooth_clustered(
        clustered_term=clustered_term,
        samples=samples,
        ngrid=ngrid,
        ci_quantiles=ci_quantiles,
        hdi_prob=hdi_prob,
        labels=labels,
        newdata=newdata,
        newdata_meshgrid=newdata_meshgrid,
    )

    if labels is None:
        clab = cluster.basis.x.name + " (indices)"
    else:
        clab = cluster.basis.x.name + " (labels)"

    # The covariate variable: the basis input for a Term, or the plain Var.
    if isinstance(term, Term):
        x = term.basis.x
    else:
        x = term

    p = (
        p9.ggplot(term_summary)
        + p9.aes(x.name, "mean", group=cluster.basis.x.name)
        + p9.aes(color=cluster.basis.x.name)
        + p9.labs(
            title=f"Posterior summary of {clustered_term.name}", x=x.name, color=clab
        )
        + p9.facet_wrap("~variable", labeller="label_both")
        + p9.scale_color_cmap_d(color_scale)
        + p9.geom_line()
    )

    return p
|