liesel-gam 0.0.4__py3-none-any.whl → 0.0.6a4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
liesel_gam/plots.py ADDED
@@ -0,0 +1,946 @@
1
+ from collections.abc import Mapping, Sequence
2
+ from typing import Any, Literal
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import liesel.goose as gs
7
+ import liesel.model as lsl
8
+ import numpy as np
9
+ import pandas as pd
10
+ import plotnine as p9
11
+ from jax import Array
12
+ from jax.typing import ArrayLike
13
+
14
+ from .builder.registry import CategoryMapping
15
+ from .var import LinTerm, MRFTerm, RITerm, Term, TPTerm
16
+
17
+ KeyArray = Any
18
+
19
+
20
def summarise_by_samples(
    key: KeyArray, a: Any, name: str, n: int = 100
) -> pd.DataFrame:
    """
    Draw ``n`` rows from a ``(chains, iterations, observations)`` sample
    array at random (with replacement) and return them in long format.

    Columns of the returned frame:

    - index: index of the flattened array
    - sample: sample number
    - obs: observation number (enumerates response values)
    - chain: chain number
    """
    _, iterations, _ = a.shape

    # Stack chains on top of each other before drawing rows.
    stacked = np.concatenate(a, axis=0)
    nobs = stacked.shape[-1]
    row_idx = jax.random.choice(key, stacked.shape[0], shape=(n,), replace=True)

    df = pd.DataFrame(
        {
            name: stacked[row_idx, :].ravel(),
            "sample": np.repeat(np.arange(n), nobs),
            "index": np.repeat(row_idx, nobs),
            "obs": np.tile(np.arange(nobs), n),
        }
    )

    # Recover which chain a flattened row index belongs to.
    df["chain"] = df["index"] // iterations

    return df
48
+
49
+
50
def summarise_1d_smooth(
    term: Term,
    samples: dict[str, Array],
    newdata: gs.Position | None | Mapping[str, ArrayLike] = None,
    quantiles: Sequence[float] = (0.05, 0.5, 0.95),
    hdi_prob: float = 0.9,
    ngrid: int = 150,
):
    """
    Posterior summary of a univariate smooth term, one row per grid point.

    If ``newdata`` is ``None``, predictions are made on an equally spaced
    grid of ``ngrid`` points over the observed covariate range; otherwise
    the covariate values are taken from ``newdata``.
    """
    if newdata is not None:
        prediction_data: Mapping[str, ArrayLike] = newdata
        covariate_grid = np.asarray(newdata[term.basis.x.name])
    else:
        # TODO: Currently, this branch of the function assumes that term.basis.x is
        # a strong node.
        # That is not necessarily always the case.
        covariate_grid = np.linspace(
            term.basis.x.value.min(), term.basis.x.value.max(), ngrid
        )
        prediction_data = {term.basis.x.name: covariate_grid}

    prediction_data = {k: jnp.asarray(v) for k, v in prediction_data.items()}

    fitted_samples = term.predict(samples, newdata=prediction_data)
    summary = gs.SamplesSummary.from_array(
        fitted_samples, name=term.name, quantiles=quantiles, hdi_prob=hdi_prob
    )
    summary_df = summary.to_dataframe().reset_index()

    # Attach the covariate grid so callers can plot against it directly.
    summary_df[term.basis.x.name] = covariate_grid
    return summary_df
81
+
82
+
83
def plot_1d_smooth(
    term: Term,
    samples: dict[str, Array],
    newdata: gs.Position | None | Mapping[str, ArrayLike] = None,
    ci_quantiles: tuple[float, float] | None = (0.05, 0.95),
    hdi_prob: float | None = None,
    show_n_samples: int | None = 50,
    seed: int | KeyArray = 1,
    ngrid: int = 150,
):
    """
    Plot the posterior summary of a univariate smooth term.

    Shows the posterior mean as a solid line, optionally a credible-interval
    ribbon (``ci_quantiles``), HDI bounds as dashed lines (``hdi_prob``), and
    a spaghetti overlay of ``show_n_samples`` individual posterior draws.

    Parameters
    ----------
    term
        The smooth term to plot.
    samples
        Posterior samples, keyed by parameter name.
    newdata
        Covariate values to predict at; when ``None``, an equally spaced
        grid of ``ngrid`` points over the observed covariate range is used.
    ci_quantiles
        Lower and upper ribbon quantiles; ``None`` disables the ribbon.
    hdi_prob
        Probability mass for the dashed HDI lines; ``None`` disables them.
    show_n_samples
        Number of posterior draws to overlay; ``None`` or 0 disables it.
    seed
        Integer seed or PRNG key used to select the overlaid draws.
    ngrid
        Number of grid points when ``newdata`` is ``None``.
    """
    if newdata is None:
        # TODO: Currently, this branch of the function assumes that term.basis.x is
        # a strong node.
        # That is not necessarily always the case.
        # Bug fix: this previously used a hard-coded 150 instead of ngrid,
        # which misaligned the overlay grid with the ngrid-sized grid built
        # by summarise_1d_smooth below whenever ngrid != 150.
        xgrid = np.linspace(term.basis.x.value.min(), term.basis.x.value.max(), ngrid)
        newdata_x: Mapping[str, ArrayLike] = {term.basis.x.name: xgrid}
    else:
        newdata_x = newdata
        xgrid = np.asarray(newdata[term.basis.x.name])

    newdata_x = {k: jnp.asarray(v) for k, v in newdata_x.items()}

    term_samples = term.predict(samples, newdata=newdata_x)

    term_summary = summarise_1d_smooth(
        term=term,
        samples=samples,
        newdata=newdata,
        quantiles=(0.05, 0.95) if ci_quantiles is None else ci_quantiles,
        hdi_prob=0.9 if hdi_prob is None else hdi_prob,
        ngrid=ngrid,
    )

    p = p9.ggplot(term_summary) + p9.labs(
        title=f"Posterior summary of {term.name}",
        x=term.basis.x.name,
        y=term.name,
    )

    if ci_quantiles is not None:
        p = p + p9.geom_ribbon(
            p9.aes(
                term.basis.x.name,
                ymin=f"q_{str(ci_quantiles[0])}",
                ymax=f"q_{str(ci_quantiles[1])}",
            ),
            fill="#56B4E9",
            alpha=0.5,
            data=term_summary,
        )

    if hdi_prob is not None:
        # HDI bounds are drawn as a pair of dashed lines.
        p = p + p9.geom_line(
            p9.aes(term.basis.x.name, "hdi_low"),
            linetype="dashed",
            data=term_summary,
        )

        p = p + p9.geom_line(
            p9.aes(term.basis.x.name, "hdi_high"),
            linetype="dashed",
            data=term_summary,
        )

    if show_n_samples is not None and show_n_samples > 0:
        key = jax.random.key(seed) if isinstance(seed, int) else seed

        summary_samples_df = summarise_by_samples(
            key=key, a=term_samples, name=term.name, n=show_n_samples
        )

        # Each drawn sample covers the full grid, so the grid repeats once
        # per sample.
        summary_samples_df[term.basis.x.name] = np.tile(
            np.squeeze(xgrid), show_n_samples
        )

        p = p + p9.geom_line(
            p9.aes(term.basis.x.name, term.name, group="sample"),
            color="grey",
            data=summary_samples_df,
            alpha=0.3,
        )

    p = p + p9.geom_line(
        p9.aes(term.basis.x.name, "mean"), data=term_summary, size=1.3, color="blue"
    )

    return p
170
+
171
+
172
def grid_nd(inputs: dict[str, jax.typing.ArrayLike], ngrid: int) -> dict[str, Any]:
    """
    Build a full factorial grid over the ranges of the given inputs.

    For each input, an equally spaced axis of ``ngrid`` points between its
    minimum and maximum is created; the axes are crossed with
    ``np.meshgrid`` and returned flattened, one array per input name.
    """
    axes = []
    for values in inputs.values():
        axes.append(np.linspace(jnp.min(values), jnp.max(values), ngrid))
    mesh = np.meshgrid(*axes)
    return {name: arr.flatten() for name, arr in zip(inputs.keys(), mesh)}
179
+
180
+
181
def input_grid_nd_smooth(
    term: TPTerm | Term, ngrid: int
) -> dict[str, jax.typing.ArrayLike]:
    """
    Construct a full factorial covariate grid from a term's observed inputs.

    Tensor-product terms use their recorded input observations; other terms
    must have a calculator basis input whose input nodes can be inspected.
    """
    if isinstance(term, TPTerm):
        observed = {name: var.value for name, var in term.input_obs.items()}
        return grid_nd(observed, ngrid)

    if not isinstance(term.basis.x, lsl.TransientCalc | lsl.Calc):
        raise NotImplementedError(
            "Function not implemented for bases with inputs of "
            f"type {type(term.basis.x)}."
        )
    observed = {n.var.name: n.var.value for n in term.basis.x.all_input_nodes()}  # type: ignore
    return grid_nd(observed, ngrid)
195
+
196
+
197
# Names of summary columns that can be selected for plotting/melting.
# using q_0.05 and q_0.95 explicitly here
# even though users could choose to return other quantiles like 0.1 and 0.9
# then they can supply q_0.1 and q_0.9, etc.
PlotVars = Literal[
    "mean", "sd", "var", "hdi_low", "hdi_high", "q_0.05", "q_0.5", "q_0.95"
]
203
+
204
+
205
def summarise_nd_smooth(
    term: Term | TPTerm,
    samples: Mapping[str, jax.Array],
    newdata: gs.Position | None | Mapping[str, ArrayLike] = None,
    ngrid: int = 20,
    which: PlotVars | Sequence[PlotVars] = "mean",
    quantiles: Sequence[float] = (0.05, 0.5, 0.95),
    hdi_prob: float = 0.9,
    newdata_meshgrid: bool = False,
):
    """
    Posterior summary of a multi-input smooth term in long format.

    Predicts the term on a covariate grid, summarises the posterior samples
    and melts the selected summary statistics (``which``) into a long frame
    with columns ``index``, one column per covariate, ``variable`` and
    ``value``.

    If ``newdata`` is ``None``, a full factorial grid with ``ngrid`` points
    per input is built from the observed data. If ``newdata_meshgrid`` is
    True, the supplied per-covariate arrays are crossed via ``np.meshgrid``
    first; otherwise ``newdata`` is used as-is.
    """
    if isinstance(which, str):
        which = [which]

    if newdata is None:
        newdata_x: Mapping[str, ArrayLike] = input_grid_nd_smooth(term, ngrid=ngrid)
    elif newdata_meshgrid:
        # Cross the user-supplied marginal grids into a full grid.
        full_grid_arrays = [v.flatten() for v in np.meshgrid(*newdata.values())]
        newdata_x = dict(zip(newdata.keys(), full_grid_arrays))
    else:
        newdata_x = newdata

    newdata_x = {k: jnp.asarray(v) for k, v in newdata_x.items()}

    term_samples = term.predict(samples, newdata=newdata_x)

    # NOTE(review): the fallbacks below guard against None even though the
    # signature defaults are not None; the quantile fallback (0.05, 0.95)
    # also differs from the signature default (0.05, 0.5, 0.95) — confirm
    # this asymmetry is intended.
    ci_quantiles_ = (0.05, 0.95) if quantiles is None else quantiles
    hdi_prob_ = 0.9 if hdi_prob is None else hdi_prob
    term_summary = (
        gs.SamplesSummary.from_array(
            term_samples, name=term.name, quantiles=ci_quantiles_, hdi_prob=hdi_prob_
        )
        .to_dataframe()
        .reset_index()
    )

    # Attach the grid coordinates so each summary row carries its covariates.
    for k, v in newdata_x.items():
        term_summary[k] = np.asarray(v)

    term_summary.reset_index(inplace=True)
    term_summary = term_summary.melt(
        id_vars=["index"] + list(newdata_x.keys()),
        value_vars=which,
        var_name="variable",
        value_name="value",
    )

    # Categorical keeps facet order identical to the order given in `which`.
    term_summary["variable"] = pd.Categorical(
        term_summary["variable"], categories=which
    )
    return term_summary
255
+
256
+
257
def plot_2d_smooth(
    term: TPTerm | Term,
    samples: Mapping[str, jax.Array],
    newdata: gs.Position | None | Mapping[str, ArrayLike] = None,
    ngrid: int = 20,
    which: PlotVars | Sequence[PlotVars] = "mean",
    quantiles: Sequence[float] = (0.05, 0.5, 0.95),
    hdi_prob: float = 0.9,
    newdata_meshgrid: bool = False,
):
    """
    Plot a bivariate smooth term as a tile (heatmap) plot, faceted by the
    requested summary statistics.
    """
    if isinstance(term, TPTerm):
        names = list(term.input_obs)
        if len(names) != 2:
            raise ValueError(
                f"'plot_2d_smooth' can only handle smooths with two inputs, "
                f"got {len(names)} for smooth {term}: {names}"
            )

        # Integer-valued marginals cannot be drawn on a continuous grid.
        for marginal in term.input_obs.values():
            if jnp.issubdtype(marginal.value, jnp.integer):
                raise TypeError(
                    "'plot_2d_smooth' expects continuous marginals, got "
                    f"type {marginal.value.dtype} for {marginal}"
                )
    else:
        names = [node.var.name for node in term.basis.x.all_input_nodes()]  # type: ignore

    term_summary = summarise_nd_smooth(
        term=term,
        samples=samples,
        newdata=newdata,
        ngrid=ngrid,
        which=which,
        quantiles=quantiles,
        hdi_prob=hdi_prob,
        newdata_meshgrid=newdata_meshgrid,
    )

    return (
        p9.ggplot(term_summary)
        + p9.labs(title=f"Posterior summary of {term.name}")
        + p9.aes(*names, fill="value")
        + p9.facet_wrap("~variable", labeller="label_both")
        + p9.geom_tile()
    )
305
+
306
+
307
def polys_to_df(polys: Mapping[str, ArrayLike]):
    """
    Convert a mapping of polygon label -> vertex coordinate array into one
    long data frame with coordinate columns ``V0``..``Vk`` plus ``vertex``
    (1-based within each polygon), ``id`` and ``label``.
    """
    poly_labels = list(polys)
    poly_coords = list(polys.values())
    # All polygons are assumed to share the coordinate dimension of the
    # first one.
    coord_dim = np.shape(poly_coords[0])[-1]
    col_names = [f"V{d}" for d in range(coord_dim)]

    frames = []
    for poly_id, label in enumerate(poly_labels):
        frame = pd.DataFrame(poly_coords[poly_id], columns=col_names).assign(
            vertex=lambda df: df.index + 1, id=poly_id, label=label
        )
        frames.append(frame)

    return pd.concat(frames, ignore_index=True)
321
+
322
+
323
def plot_polys(
    region: str,
    which: str | Sequence[str],
    df: pd.DataFrame,
    polys: Mapping[str, ArrayLike],
    show_unobserved: bool = True,
    observed_color: str = "none",
    unobserved_color: str = "red",
) -> p9.ggplot:
    """
    Plot per-region summary values (``which`` columns of ``df``) as filled
    polygons, faceted by the selected summary statistic.

    ``df`` must contain the ``region`` column (matched against the keys of
    ``polys`` after string conversion) and optionally an ``observed``
    boolean column; when every region is observed, the unobserved styling
    is skipped entirely.
    """
    if isinstance(which, str):
        which = [which]

    # Work on a copy: previously this function mutated the caller's frame
    # in place by adding "label"/"observed" columns.
    df = df.copy()

    poly_df = polys_to_df(polys)

    df["label"] = df[region].astype(str)

    if "observed" not in df.columns:
        df["observed"] = True

    if df["observed"].all():
        show_unobserved = False

    plot_df = poly_df.merge(df, on="label")

    plot_df = plot_df.melt(
        id_vars=["label", "V0", "V1", "observed"],
        value_vars=which,
        var_name="variable",
        value_name="value",
    )

    # Categorical keeps facet order identical to the order given in `which`.
    plot_df["variable"] = pd.Categorical(plot_df["variable"], categories=which)

    p = (
        p9.ggplot(plot_df)
        + p9.aes("V0", "V1", group="label", fill="value")
        + p9.aes(color="observed")
        + p9.facet_wrap("~variable", labeller="label_both")
        + p9.scale_color_manual({True: observed_color, False: unobserved_color})
        + p9.guides(color=p9.guide_legend(override_aes={"fill": None}))
    )
    if show_unobserved:
        p = p + p9.geom_polygon()
    else:
        # Observed regions filled; unobserved drawn as outlines only.
        p = p + p9.geom_polygon(data=plot_df.query("observed == True"))
        p = p + p9.geom_polygon(data=plot_df.query("observed == False"), fill="none")

    return p
372
+
373
+
374
def _convert_to_integers(
    grid: np.typing.NDArray,
    labels: Sequence[str] | CategoryMapping | None,
    term: RITerm | MRFTerm | lsl.Var,
) -> np.typing.NDArray[np.int_]:
    """
    Map user-supplied grid values to integer category codes.

    Uses ``labels`` if it is a :class:`CategoryMapping`; otherwise falls
    back to the term's own mapping. When no mapping is available, the grid
    must already consist of integer codes.

    NOTE(review): a plain ``Sequence[str]`` passed as ``labels`` is ignored
    here and the term's own mapping is used instead — confirm this is
    intended.
    """
    if isinstance(labels, CategoryMapping):
        return labels.to_integers(grid)

    try:
        grid = term.mapping.to_integers(grid)  # type: ignore
    except (ValueError, AttributeError) as err:
        if not np.issubdtype(grid.dtype, np.integer):
            # Chain the original error so the failed mapping lookup stays
            # visible in the traceback (previously it was swallowed).
            raise TypeError(
                f"There's no mapping available on the term {term}. "
                "In this case, its values in 'newdata' must be specified "
                f"as integer codes. Got data type {grid.dtype}"
            ) from err

    return grid
393
+
394
+
395
def summarise_cluster(
    term: RITerm | MRFTerm | Term,
    samples: Mapping[str, jax.Array],
    newdata: gs.Position
    | None
    | Mapping[str, ArrayLike | Sequence[int] | Sequence[str]] = None,
    labels: CategoryMapping | Sequence[str] | None = None,
    quantiles: Sequence[float] = (0.05, 0.5, 0.95),
    hdi_prob: float = 0.9,
) -> pd.DataFrame:
    """
    Posterior summary of a clustered term, one row per cluster level.

    The returned frame carries a column named after the term's cluster
    variable (categorical labels or codes) plus an ``observed`` boolean
    flag that marks levels present in the observed data.

    If ``newdata`` is ``None``, all known levels are used: every code of the
    term's ``CategoryMapping`` when available (so unobserved levels are
    included), otherwise only the unique observed codes.
    """
    # Fall back to the term's own mapping when no labels were supplied.
    if labels is None:
        try:
            labels = term.mapping  # type: ignore
        except (AttributeError, ValueError):
            labels = None

    if newdata is None and isinstance(labels, CategoryMapping):
        # All codes known to the mapping, including unobserved levels.
        grid = np.asarray(list(labels.integers_to_labels_map))
        unique_x = np.unique(term.basis.x.value)
        newdata_x: Mapping[str, ArrayLike] = {term.basis.x.name: grid}
        observed = [x in unique_x for x in grid]
    elif newdata is None:
        # No mapping: only observed levels are available.
        grid = np.unique(term.basis.x.value)
        newdata_x = {term.basis.x.name: grid}
        observed = [True for _ in grid]
    else:
        unique_x = np.unique(term.basis.x.value)
        grid = np.asarray(newdata[term.basis.x.name])
        grid = _convert_to_integers(grid, labels, term)

        observed = [x in unique_x for x in grid]
        newdata_x = {term.basis.x.name: grid}

    newdata_x = {k: jnp.asarray(v) for k, v in newdata_x.items()}
    predictions = term.predict(samples=samples, newdata=newdata_x)
    predictions_summary = (
        gs.SamplesSummary.from_array(
            predictions,
            quantiles=quantiles,
            hdi_prob=0.9 if hdi_prob is None else hdi_prob,
        )
        .to_dataframe()
        .reset_index()
    )

    # Attach human-readable cluster labels where possible.
    if isinstance(labels, CategoryMapping):
        codes = newdata_x[term.basis.x.name]
        labels_str = list(labels.integers_to_labels(codes))
        categories = list(labels.labels_to_integers_map)
        predictions_summary[term.basis.x.name] = pd.Categorical(
            labels_str, categories=categories
        )
    elif labels is not None:
        # Plain sequence of labels: assumed to align with the summary rows.
        labels_str = list(labels)
        categories = sorted(set(labels_str))
        predictions_summary[term.basis.x.name] = pd.Categorical(
            labels_str, categories=categories
        )
    else:
        # NOTE(review): this falls back to the *observed* values of the
        # cluster variable, which may differ in length from the summary
        # rows when newdata was supplied — confirm against callers.
        predictions_summary[term.basis.x.name] = pd.Categorical(
            np.asarray(term.basis.x.value)
        )

    predictions_summary["observed"] = observed

    return predictions_summary
461
+
462
+
463
def summarise_regions(
    term: RITerm | MRFTerm | Term,
    samples: Mapping[str, jax.Array],
    newdata: gs.Position | None | Mapping[str, ArrayLike] = None,
    which: PlotVars | Sequence[PlotVars] = "mean",
    polys: Mapping[str, ArrayLike] | None = None,
    labels: CategoryMapping | Sequence[str] | None = None,
    quantiles: Sequence[float] = (0.05, 0.5, 0.95),
    hdi_prob: float = 0.9,
) -> pd.DataFrame:
    """
    Posterior summary of a spatial term joined with region polygons.

    Builds a per-region cluster summary via :func:`summarise_cluster`,
    merges it with the polygon vertices (one row per vertex) and melts the
    selected summary statistics into ``variable``/``value`` columns for
    plotting.

    Raises
    ------
    ValueError
        If no polygons are available (neither on the term nor via
        ``polys``), or if a polygon label has no matching row in the
        cluster summary.
    """
    polygons = None
    if polys is not None:
        polygons = polys
    else:
        try:
            # type: ignore, since the case of term not having the attribute
            # polygons is handled by the try/except
            polygons = term.polygons  # type: ignore
        except AttributeError:
            pass

    if not polygons:
        raise ValueError(
            "When passing a term without polygons, polygons must be supplied manually "
            "through the argument 'polys'"
        )

    df = summarise_cluster(
        term=term,
        samples=samples,
        newdata=newdata,
        labels=labels,
        quantiles=quantiles,
        hdi_prob=hdi_prob,
    )
    region = term.basis.x.name
    if isinstance(which, str):
        which = [which]

    # Every polygon must have a matching summary row, otherwise the merge
    # below would silently drop regions.
    unique_labels_in_df = df[term.basis.x.name].unique().tolist()
    assert polygons is not None
    for region_label in polygons:
        if region_label not in unique_labels_in_df:
            raise ValueError(
                f"Label '{region_label}' found in polys, but not in cluster summary. "
                f"Known labels: {unique_labels_in_df}"
            )

    poly_df = polys_to_df(polygons)

    # Merge key: string form of the region identifier.
    df["label"] = df[region].astype(str)

    if "observed" not in df.columns:
        df["observed"] = True

    plot_df = poly_df.merge(df, on="label")

    plot_df = plot_df.melt(
        id_vars=["label", "V0", "V1", "observed"],
        value_vars=which,
        var_name="variable",
        value_name="value",
    )

    # Categorical keeps facet order identical to the order given in `which`.
    plot_df["variable"] = pd.Categorical(plot_df["variable"], categories=which)

    return plot_df
530
+
531
+
532
def plot_regions(
    term: RITerm | MRFTerm | Term,
    samples: Mapping[str, jax.Array],
    newdata: gs.Position | None | Mapping[str, ArrayLike] = None,
    which: PlotVars | Sequence[PlotVars] = "mean",
    polys: Mapping[str, ArrayLike] | None = None,
    labels: CategoryMapping | None = None,
    quantiles: Sequence[float] = (0.05, 0.5, 0.95),
    hdi_prob: float = 0.9,
    show_unobserved: bool = True,
    observed_color: str = "none",
    unobserved_color: str = "red",
) -> p9.ggplot:
    """
    Plot the posterior summary of a spatial term as filled region polygons.

    Regions without observations are either drawn like the others, or —
    with ``show_unobserved=False`` — rendered as outlines only, colored in
    ``unobserved_color``.
    """
    plot_df = summarise_regions(
        term=term,
        samples=samples,
        newdata=newdata,
        which=which,
        polys=polys,
        labels=labels,
        quantiles=quantiles,
        hdi_prob=hdi_prob,
    )

    p = (
        p9.ggplot(plot_df)
        + p9.aes("V0", "V1", group="label", fill="value")
        + p9.aes(color="observed")
        + p9.facet_wrap("~variable", labeller="label_both")
        + p9.scale_color_manual({True: observed_color, False: unobserved_color})
        + p9.guides(color=p9.guide_legend(override_aes={"fill": None}))
    )

    if show_unobserved:
        p += p9.geom_polygon()
    else:
        # Observed regions filled; unobserved drawn as outlines only.
        p += p9.geom_polygon(data=plot_df.query("observed == True"))
        p += p9.geom_polygon(data=plot_df.query("observed == False"), fill="none")

    p += p9.labs(title=f"Plot of {term.name}")
    return p
571
+
572
+
573
def plot_forest(
    term: RITerm | MRFTerm | LinTerm,
    samples: Mapping[str, jax.Array],
    newdata: gs.Position | None | Mapping[str, ArrayLike] = None,
    labels: CategoryMapping | None = None,
    ymin: str = "hdi_low",
    ymax: str = "hdi_high",
    ci_quantiles: tuple[float, float] = (0.05, 0.95),
    hdi_prob: float = 0.9,
    show_unobserved: bool = True,
    highlight_unobserved: bool = True,
    unobserved_color: str = "red",
    indices: Sequence[int] | None = None,
) -> p9.ggplot:
    """
    Dispatch to the appropriate forest plot for the given term type.

    Clustered terms (random intercepts, Markov random fields) are handled
    by :func:`plot_forest_clustered`, linear terms by
    :func:`plot_forest_lin`; anything else raises a :class:`TypeError`.
    """
    if isinstance(term, RITerm | MRFTerm):
        return plot_forest_clustered(
            term=term,
            samples=samples,
            newdata=newdata,
            labels=labels,
            ymin=ymin,
            ymax=ymax,
            ci_quantiles=ci_quantiles,
            hdi_prob=hdi_prob,
            show_unobserved=show_unobserved,
            highlight_unobserved=highlight_unobserved,
            unobserved_color=unobserved_color,
            indices=indices,
        )

    if isinstance(term, LinTerm):
        return plot_forest_lin(
            term=term,
            samples=samples,
            ymin=ymin,
            ymax=ymax,
            ci_quantiles=ci_quantiles,
            hdi_prob=hdi_prob,
            indices=indices,
        )

    raise TypeError(f"term has unsupported type {type(term)}.")
614
+
615
+
616
def summarise_lin(
    term: LinTerm,
    samples: Mapping[str, jax.Array],
    quantiles: Sequence[float] = (0.05, 0.5, 0.95),
    hdi_prob: float = 0.9,
    indices: Sequence[int] | None = None,
) -> pd.DataFrame:
    """
    Posterior summary of linear-term coefficients, one row per coefficient.

    ``indices`` optionally restricts the summary to a subset of the
    coefficients (by position).
    """
    coef_samples = samples[term.coef.name]
    colnames = term.column_names
    if indices is not None:
        coef_samples = coef_samples[..., indices]
        colnames = [colnames[i] for i in indices]

    summary = gs.SamplesSummary.from_array(
        coef_samples, quantiles=quantiles, hdi_prob=hdi_prob
    )
    df = summary.to_dataframe().reset_index()

    df["x"] = colnames
    df.drop(["variable", "var_fqn", "var_index"], axis=1, inplace=True)
    # Move the coefficient name to the first column.
    df.insert(0, "x", df.pop("x"))
    return df
642
+
643
+
644
def plot_forest_lin(
    term: LinTerm,
    samples: Mapping[str, jax.Array],
    ymin: str = "hdi_low",
    ymax: str = "hdi_high",
    ci_quantiles: tuple[float, float] = (0.05, 0.95),
    hdi_prob: float = 0.9,
    indices: Sequence[int] | None = None,
) -> p9.ggplot:
    """
    Forest plot of linear-term coefficients: posterior means as points with
    interval bars (``ymin``/``ymax`` columns), flipped horizontally.
    """
    summary_df = summarise_lin(
        term=term,
        samples=samples,
        quantiles=ci_quantiles,
        hdi_prob=hdi_prob,
        indices=indices,
    )

    # Align the interval columns' dtype with the mean column for plotting.
    mean_dtype = summary_df["mean"].dtype
    summary_df[ymin] = summary_df[ymin].astype(mean_dtype)
    summary_df[ymax] = summary_df[ymax].astype(mean_dtype)

    return (
        p9.ggplot(summary_df)
        + p9.aes("x", "mean")
        + p9.geom_hline(yintercept=0, color="grey")
        + p9.geom_linerange(p9.aes(ymin=ymin, ymax=ymax), color="grey")
        + p9.geom_point()
        + p9.coord_flip()
        + p9.labs(x="x")
        + p9.labs(title=f"Posterior summary of {term.name}")
    )
677
+
678
+
679
def plot_forest_clustered(
    term: RITerm | MRFTerm | Term,
    samples: Mapping[str, jax.Array],
    newdata: gs.Position | None | Mapping[str, ArrayLike] = None,
    labels: CategoryMapping | None = None,
    ymin: str = "hdi_low",
    ymax: str = "hdi_high",
    ci_quantiles: tuple[float, float] = (0.05, 0.95),
    hdi_prob: float = 0.9,
    show_unobserved: bool = True,
    highlight_unobserved: bool = True,
    unobserved_color: str = "red",
    indices: Sequence[int] | None = None,
) -> p9.ggplot:
    """
    Forest plot of a clustered term: posterior means as points with
    interval bars, one row per cluster level, flipped horizontally.

    Unobserved levels can be dropped (``show_unobserved=False``) or marked
    with an extra point in ``unobserved_color``
    (``highlight_unobserved=True``). ``indices`` restricts the plot to a
    positional subset of the summary rows.
    """
    # Fall back to the term's own mapping when no labels were supplied.
    # NOTE(review): this catches only AttributeError, while summarise_cluster
    # catches (AttributeError, ValueError) for the same lookup — confirm
    # whether the narrower handling here is intended.
    if labels is None:
        try:
            labels = term.mapping  # type: ignore
        except AttributeError:
            labels = None

    df = summarise_cluster(
        term=term,
        samples=samples,
        newdata=newdata,
        labels=labels,
        quantiles=ci_quantiles,
        hdi_prob=hdi_prob,
    )
    cluster = term.basis.x.name

    if labels is None:
        xlab = cluster + " (indices)"
    else:
        xlab = cluster + " (labels)"

    # Align the interval columns' dtype with the mean column for plotting.
    df[ymin] = df[ymin].astype(df["mean"].dtype)
    df[ymax] = df[ymax].astype(df["mean"].dtype)

    if indices is not None:
        df = df.iloc[indices, :]

    if not show_unobserved:
        df = df.query("observed == True")

    p = (
        p9.ggplot(df)
        + p9.aes(cluster, "mean")
        + p9.geom_hline(yintercept=0, color="grey")
        + p9.geom_linerange(p9.aes(ymin=ymin, ymax=ymax), color="grey")
        + p9.geom_point()
        + p9.coord_flip()
        + p9.labs(x=xlab)
    )

    # Overlay an extra marker on unobserved levels.
    if highlight_unobserved:
        df_uo = df.query("observed == False")
        p = p + p9.geom_point(
            p9.aes(cluster, "mean"),
            color=unobserved_color,
            shape="x",
            data=df_uo,
        )

    p += p9.labs(title=f"Posterior summary of {term.name}")

    return p
745
+
746
+
747
def summarise_1d_smooth_clustered(
    clustered_term: lsl.Var,
    samples: Mapping[str, jax.Array],
    ngrid: int = 20,
    newdata: gs.Position
    | None
    | Mapping[str, ArrayLike | Sequence[int] | Sequence[str]] = None,
    which: PlotVars | Sequence[PlotVars] = "mean",
    ci_quantiles: Sequence[float] = (0.05, 0.5, 0.95),
    hdi_prob: float = 0.9,
    labels: CategoryMapping | None | Sequence[str] = None,
    newdata_meshgrid: bool = False,
):
    """
    Posterior summary of a clustered smooth (one smooth curve per cluster
    level), evaluated on the cross product of a covariate grid and the
    cluster levels.

    The clustered term's value node must expose the smooth under ``"x"``
    and the cluster term under ``"cluster"``. The returned frame carries
    the grid coordinates, cluster labels (when a mapping is available) and
    an ``observed`` flag per cluster level.

    NOTE(review): the ``which`` parameter is normalised below but never
    used afterwards in this function — confirm whether it should filter
    the summary like in summarise_nd_smooth.
    """
    if isinstance(which, str):
        which = [which]

    term = clustered_term.value_node["x"]
    cluster = clustered_term.value_node["cluster"]

    assert isinstance(term, Term | lsl.Var)
    assert isinstance(cluster, RITerm | MRFTerm)

    # Fall back to the cluster term's own mapping when no labels were given.
    if labels is None:
        try:
            labels = cluster.mapping  # type: ignore
        except (AttributeError, ValueError):
            labels = None

    # `x` is the covariate the smooth varies over.
    if isinstance(term, Term):
        x = term.basis.x
    else:
        x = term

    if newdata is None:
        # An automatic linspace grid only makes sense for continuous x.
        if not jnp.issubdtype(x.value.dtype, jnp.floating):
            raise TypeError(
                "Automatic grid creation is valid only for continuous x, got "
                f"dtype {jnp.dtype(x.value)} for {x}."
            )

    if newdata is None and isinstance(labels, CategoryMapping):
        # All cluster codes known to the mapping, including unobserved ones.
        cgrid = np.asarray(list(labels.integers_to_labels_map))  # integer codes
        unique_clusters = np.unique(cluster.basis.x.value)  # unique codes

        if isinstance(x, lsl.Node) or x.strong:
            xgrid: Mapping[str, ArrayLike] = {
                x.name: jnp.linspace(x.value.min(), x.value.max(), ngrid)
            }
        else:
            assert isinstance(term, Term | LinTerm), (
                f"Wrong type for term: {type(term)}"
            )
            # Spread ngrid across the basis inputs so the total grid size
            # stays roughly ngrid.
            # NOTE(review): np.pow is the NumPy >= 2.0 alias of np.power —
            # confirm the package's NumPy lower bound, or use np.power.
            ncols = jnp.shape(term.basis.value)[-1]
            xgrid = input_grid_nd_smooth(term, ngrid=int(np.pow(ngrid, 1 / ncols)))

        grid: Mapping[str, ArrayLike | Sequence[int] | Sequence[str]] = dict(xgrid) | {
            cluster.basis.x.name: cgrid
        }

        # code : bool
        # NOTE: the comprehension variable deliberately named `x` here
        # shadows the covariate `x` within the comprehension only.
        observed = {x: x in unique_clusters for x in cgrid}
    elif newdata is None:
        # No mapping: only observed cluster codes are available.
        cgrid = np.unique(cluster.basis.x.value)
        if isinstance(x, lsl.Node) or x.strong:
            xgrid = {x.name: jnp.linspace(x.value.min(), x.value.max(), ngrid)}
        else:
            assert isinstance(term, Term | LinTerm), (
                f"Wrong type for term: {type(term)}"
            )
            ncols = jnp.shape(term.basis.value)[-1]
            xgrid = input_grid_nd_smooth(term, ngrid=int(np.pow(ngrid, 1 / ncols)))

        grid = xgrid | {cluster.basis.x.name: cgrid}

        # code : bool
        observed = {x: True for x in cgrid}
    else:
        # newdata handling happens in the second branch cascade below;
        # nothing to prepare here.
        pass

    if newdata is not None and newdata_meshgrid:
        cgrid = np.asarray(newdata[cluster.basis.x.name])
        cgrid = _convert_to_integers(cgrid, labels, cluster)

        # Cross the covariate values with the cluster codes.
        grid = {x.name: newdata[x.name], cluster.basis.x.name: cgrid}
        full_grid_arrays = [v.flatten() for v in np.meshgrid(*grid.values())]
        newdata_x: dict[str, ArrayLike | Sequence[int] | Sequence[str]] = dict(
            zip(grid.keys(), full_grid_arrays)
        )

        if isinstance(labels, CategoryMapping):
            observed = {x: x in cluster.basis.x.value for x in cgrid}
        else:
            observed = {x: True for x in cgrid}
    elif newdata is not None:
        cgrid = np.asarray(newdata[cluster.basis.x.name])
        cgrid = _convert_to_integers(cgrid, labels, cluster)
        newdata_x = {x.name: newdata[x.name], cluster.basis.x.name: cgrid}
        # code : bool
        if isinstance(labels, CategoryMapping):
            observed = {x: x in cluster.basis.x.value for x in cgrid}
        else:
            observed = {x: True for x in cgrid}
    else:  # then we use the grid created from observed data
        # `grid` was assigned in one of the `newdata is None` branches above.
        full_grid_arrays = [v.flatten() for v in np.meshgrid(*grid.values())]
        newdata_x = dict(zip(grid.keys(), full_grid_arrays))

    newdata_x = {k: jnp.asarray(v) for k, v in newdata_x.items()}

    term_samples = clustered_term.predict(samples, newdata=newdata_x)
    term_summary = (
        gs.SamplesSummary.from_array(
            term_samples,
            name=clustered_term.name,
            quantiles=ci_quantiles,
            hdi_prob=hdi_prob,
        )
        .to_dataframe()
        .reset_index()
    )

    # Attach the grid coordinates so each summary row carries its covariates.
    for k, v in newdata_x.items():
        term_summary[k] = np.asarray(v)

    if labels is not None:
        if isinstance(labels, CategoryMapping):
            labels_long = labels.to_labels(newdata_x[cluster.basis.x.name])
            categories = list(labels.labels_to_integers_map)
            term_summary[cluster.basis.x.name] = pd.Categorical(
                labels_long, categories=categories
            )
        else:
            # NOTE(review): a plain label sequence is assigned as-is and
            # must align with the expanded grid rows — confirm.
            term_summary[cluster.basis.x.name] = labels

    term_summary["observed"] = [
        observed[x] for x in np.asarray(newdata_x[cluster.basis.x.name])
    ]

    term_summary.reset_index(inplace=True)

    return term_summary
887
+
888
+
889
def plot_1d_smooth_clustered(
    clustered_term: lsl.Var,
    samples: Mapping[str, jax.Array],
    ngrid: int = 20,
    newdata: gs.Position | None | Mapping[str, ArrayLike] = None,
    labels: CategoryMapping | None = None,
    color_scale: str = "viridis",
    newdata_meshgrid: bool = False,
):
    """
    Plot the posterior mean of a clustered univariate smooth as one colored
    line per cluster level, faceted by summary variable.
    """
    term = clustered_term.value_node["x"]
    cluster = clustered_term.value_node["cluster"]

    assert isinstance(term, Term | lsl.Var)
    assert isinstance(cluster, RITerm | MRFTerm)

    # Fall back to the cluster term's own mapping when no labels were given.
    if labels is None:
        try:
            labels = cluster.mapping  # type: ignore
        except AttributeError:
            labels = None

    term_summary = summarise_1d_smooth_clustered(
        clustered_term=clustered_term,
        samples=samples,
        ngrid=ngrid,
        ci_quantiles=(0.05, 0.5, 0.95),
        hdi_prob=0.9,
        labels=labels,
        newdata=newdata,
        newdata_meshgrid=newdata_meshgrid,
    )

    suffix = " (indices)" if labels is None else " (labels)"
    clab = cluster.basis.x.name + suffix

    x = term.basis.x if isinstance(term, Term) else term

    return (
        p9.ggplot(term_summary)
        + p9.aes(x.name, "mean", group=cluster.basis.x.name)
        + p9.aes(color=cluster.basis.x.name)
        + p9.labs(
            title=f"Posterior summary of {clustered_term.name}", x=x.name, color=clab
        )
        + p9.facet_wrap("~variable", labeller="label_both")
        + p9.scale_color_cmap_d(color_scale)
        + p9.geom_line()
    )