bella-companion 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bella-companion might be problematic. Click here for more details.

Files changed (26) hide show
  1. bella_companion/cli.py +13 -4
  2. bella_companion/simulations/__init__.py +2 -1
  3. bella_companion/simulations/figures/__init__.py +21 -0
  4. bella_companion/simulations/figures/epi_multitype_results.py +81 -0
  5. bella_companion/simulations/figures/epi_skyline_results.py +46 -0
  6. bella_companion/simulations/figures/explain/__init__.py +6 -0
  7. bella_companion/simulations/figures/explain/pdp.py +101 -0
  8. bella_companion/simulations/figures/explain/shap.py +56 -0
  9. bella_companion/simulations/figures/fbd_2traits_results.py +83 -0
  10. bella_companion/simulations/figures/fbd_no_traits_results.py +58 -0
  11. bella_companion/simulations/figures/scenarios.py +38 -32
  12. bella_companion/simulations/generate_figures.py +24 -0
  13. bella_companion/simulations/scenarios/fbd_2traits.py +2 -2
  14. bella_companion/utils/__init__.py +20 -1
  15. bella_companion/utils/beast.py +1 -1
  16. bella_companion/utils/explain.py +45 -0
  17. bella_companion/utils/plots.py +98 -0
  18. {bella_companion-0.0.3.dist-info → bella_companion-0.0.5.dist-info}/METADATA +4 -3
  19. {bella_companion-0.0.3.dist-info → bella_companion-0.0.5.dist-info}/RECORD +21 -16
  20. bella_companion/simulations/figures/epi_explainations.py +0 -109
  21. bella_companion/simulations/figures/epi_predictions.py +0 -58
  22. bella_companion/simulations/figures/fbd_explainations.py +0 -99
  23. bella_companion/simulations/figures/fbd_predictions.py +0 -66
  24. bella_companion/simulations/figures/utils.py +0 -250
  25. {bella_companion-0.0.3.dist-info → bella_companion-0.0.5.dist-info}/WHEEL +0 -0
  26. {bella_companion-0.0.3.dist-info → bella_companion-0.0.5.dist-info}/entry_points.txt +0 -0
@@ -9,7 +9,7 @@ import numpy as np
9
9
  import polars as pl
10
10
  from joblib import Parallel, delayed
11
11
  from lumiere import read_log_file, read_weights
12
- from lumiere.backend.typing import Weights
12
+ from lumiere.typing import Weights
13
13
  from tqdm import tqdm
14
14
 
15
15
  from bella_companion.utils.slurm import get_job_metadata
@@ -0,0 +1,45 @@
1
+ import numpy as np
2
+ from lumiere import get_partial_dependence_values, get_shap_features_importance
3
+ from lumiere.typing import ActivationFunction, Weights
4
+ from numpy.typing import ArrayLike
5
+
6
+
7
def get_median_partial_dependence_values(
    weights: list[Weights],  # shape: (n_weight_samples, ...)
    features_grid: list[list[float]],
    hidden_activation: ActivationFunction,
    output_activation: ActivationFunction,
) -> list[list[float]]:  # shape: (n_features, n_grid_points)
    """Return per-feature partial-dependence curves, median-aggregated over weight samples.

    ``get_partial_dependence_values`` is evaluated once for every entry of
    ``weights``; for each feature index the resulting curves are stacked and
    reduced with an element-wise median.

    Args:
        weights: One set of network weights per sample.
        features_grid: One grid of values per feature; its length fixes the
            number of curves returned.
        hidden_activation: Forwarded unchanged to ``get_partial_dependence_values``.
        output_activation: Forwarded unchanged to ``get_partial_dependence_values``.

    Returns:
        One list of median values per feature, in the order of ``features_grid``.
    """
    per_sample_curves = []
    for sample_weights in weights:
        per_sample_curves.append(
            get_partial_dependence_values(
                weights=sample_weights,
                features_grid=features_grid,
                hidden_activation=hidden_activation,
                output_activation=output_activation,
            )
        )

    median_curves = []
    for feature_idx in range(len(features_grid)):
        # Stack this feature's curve from every sample, then reduce across samples.
        stacked = [curves[feature_idx] for curves in per_sample_curves]
        median_curves.append(np.median(stacked, axis=0).tolist())
    return median_curves
26
+
27
+
28
def get_median_shap_features_importance(
    weights: list[Weights],
    inputs: ArrayLike,
    hidden_activation: ActivationFunction,
    output_activation: ActivationFunction,
) -> list[float]:  # length: n_features
    """Return the SHAP feature importances, median-aggregated over weight samples.

    ``get_shap_features_importance`` is evaluated once for every entry of
    ``weights`` on the same ``inputs``; the per-sample results are stacked
    into an array and reduced with a median along the sample axis.

    Args:
        weights: One set of network weights per sample.
        inputs: Input data forwarded unchanged to ``get_shap_features_importance``.
        hidden_activation: Forwarded unchanged to ``get_shap_features_importance``.
        output_activation: Forwarded unchanged to ``get_shap_features_importance``.

    Returns:
        The median importance values as a plain Python list.
    """
    per_sample_importance = []
    for sample_weights in weights:
        per_sample_importance.append(
            get_shap_features_importance(
                weights=sample_weights,
                inputs=inputs,
                hidden_activation=hidden_activation,
                output_activation=output_activation,
            )
        )
    importance_matrix = np.array(per_sample_importance)
    median_importance = np.median(importance_matrix, axis=0)
    return median_importance.tolist()
@@ -0,0 +1,98 @@
1
+ from pathlib import Path
2
+
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import polars as pl
6
+ import seaborn as sns
7
+
8
+
9
def _set_time_bin_xticks(n: int, reverse: bool = False):
    """Label the x axis of the current figure with ``n`` time-bin ticks.

    Ticks are placed at positions ``0 .. n-1``. Their labels are the same
    integers as strings, listed in descending order when ``reverse`` is true.
    The x-axis label is set to "Time bin".
    """
    positions = range(n)
    label_order = reversed(positions) if reverse else positions
    tick_labels = [str(k) for k in label_order]
    plt.xticks(ticks=positions, labels=tick_labels)  # pyright: ignore
    plt.xlabel("Time bin")  # pyright: ignore
15
+
16
+
17
def step(
    x: list[float],
    reverse_xticks: bool = False,
    label: str | None = None,
    color: str | None = None,
    linestyle: str | None = None,
):
    """Draw ``x`` as a step curve on the current figure and set time-bin ticks.

    The first value is duplicated at the front, so ``len(x) + 1`` points are
    plotted at integer positions ``0 .. len(x)`` and the same number of ticks
    is installed.

    Args:
        x: The values to plot; must be non-empty because ``x[0]`` is read.
        reverse_xticks: Forwarded to ``_set_time_bin_xticks`` as ``reverse``.
        label: Forwarded to ``plt.step``.
        color: Forwarded to ``plt.step``.
        linestyle: Forwarded to ``plt.step``.
    """
    # NOTE(review): the duplicated leading value presumably makes the first bin
    # visible with matplotlib's default "pre" step placement — confirm.
    levels = [x[0], *x]
    n_points = len(levels)
    positions = list(range(n_points))
    plt.step(  # pyright: ignore
        positions,
        levels,
        label=label,
        color=color,
        linestyle=linestyle,
    )
    _set_time_bin_xticks(n_points, reverse_xticks)
30
+
31
+
32
+ def _count_time_bins(true_values: dict[str, list[float]]) -> int:
33
+ assert (
34
+ len({len(true_value) for true_value in true_values.values()}) == 1
35
+ ), "All targets must have the same number of change times."
36
+ return len(next(iter((true_values.values()))))
37
+
38
+
39
def plot_maes_per_time_bin(
    logs_summaries: dict[str, pl.DataFrame],
    true_values: dict[str, list[float]],
    output_filepath: Path,
    reverse_xticks: bool = False,
):
    """Save a violin plot of absolute errors per time bin, split by model.

    For every time bin ``i`` and every model, each row of the model's summary
    contributes one "MAE" value: the row-wise mean over targets of
    ``|{target}i{i}_median - true_values[target][i]|``. The values are tagged
    with "Time bin" and "Model" columns, concatenated, and drawn with
    ``sns.violinplot``. The figure is written to ``output_filepath`` and closed.

    Args:
        logs_summaries: Maps a model name to its summary frame, which must
            contain the ``{target}i{i}_median`` columns.
        true_values: Maps each target name to its true value per time bin.
        output_filepath: Destination passed to ``plt.savefig``.
        reverse_xticks: Forwarded to ``_set_time_bin_xticks`` as ``reverse``.
    """
    n_time_bins = _count_time_bins(true_values)

    frames = []
    for time_bin in range(n_time_bins):
        # One absolute-error expression per target for this time bin.
        abs_errors = [
            (
                pl.col(f"{target}i{time_bin}_median") - true_values[target][time_bin]
            ).abs()
            for target in true_values
        ]
        for model_name in logs_summaries:
            errors = logs_summaries[model_name].select(
                pl.mean_horizontal(abs_errors).alias("MAE")
            )
            frames.append(
                errors.with_columns(
                    pl.lit(time_bin).alias("Time bin"),
                    pl.lit(model_name).alias("Model"),
                )
            )
    df = pl.concat(frames)

    violin_options = dict(inner=None, cut=0, density_norm="width", legend=False)
    sns.violinplot(x="Time bin", y="MAE", hue="Model", data=df, **violin_options)
    _set_time_bin_xticks(n_time_bins, reverse_xticks)
    plt.savefig(output_filepath)  # pyright: ignore
    plt.close()
71
+
72
+
73
def plot_coverage_per_time_bin(
    logs_summaries: dict[str, pl.DataFrame],
    true_values: dict[str, list[float]],
    output_filepath: Path,
    reverse_xticks: bool = False,
):
    """Save a line plot of interval coverage per time bin, one line per model.

    Coverage for a (model, target, time bin) triple is the fraction of rows in
    the model's summary whose ``[{target}i{i}_lower, {target}i{i}_upper]``
    interval contains the true value. It is averaged over targets for each
    time bin. The figure is written to ``output_filepath`` and closed.

    Args:
        logs_summaries: Maps a model name to its summary frame, which must
            contain the ``{target}i{i}_lower`` / ``_upper`` columns.
        true_values: Maps each target name to its true value per time bin.
        output_filepath: Destination passed to ``plt.savefig``.
        reverse_xticks: Forwarded to ``_set_time_bin_xticks`` as ``reverse``.
    """

    def _coverage(model: str, target: str, i: int) -> float:
        summary = logs_summaries[model]
        lower_bound = summary[f"{target}i{i}_lower"]
        upper_bound = summary[f"{target}i{i}_upper"]
        true_value = true_values[target][i]
        covered = (lower_bound <= true_value) & (true_value <= upper_bound)
        return covered.sum() / len(summary)

    # Computed once and reused for both the loop and the tick labels.
    n_time_bins = _count_time_bins(true_values)
    for model in logs_summaries:
        avg_coverage_by_time_bin = [
            np.mean([_coverage(model, target, i) for target in true_values])
            for i in range(n_time_bins)
        ]
        plt.plot(avg_coverage_by_time_bin, marker="o")  # pyright: ignore

    _set_time_bin_xticks(n_time_bins, reverse_xticks)
    plt.ylabel("Coverage")  # pyright: ignore
    plt.ylim((0, 1.05))  # pyright: ignore
    plt.savefig(output_filepath)  # pyright: ignore
    plt.close()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: bella-companion
3
- Version: 0.0.3
3
+ Version: 0.0.5
4
4
  Summary:
5
5
  Author: gabriele-marino
6
6
  Author-email: gabmarino.8601@gmail.com
@@ -10,6 +10,7 @@ Classifier: Programming Language :: Python :: 3.10
10
10
  Classifier: Programming Language :: Python :: 3.11
11
11
  Classifier: Programming Language :: Python :: 3.12
12
12
  Requires-Dist: arviz (>=0.22.0,<0.23.0)
13
- Requires-Dist: bella-lumiere (>=0.0.10,<0.0.11)
13
+ Requires-Dist: bella-lumiere (>=0.0.13,<0.0.14)
14
14
  Requires-Dist: dotenv (>=0.9.9,<0.10.0)
15
- Requires-Dist: phylogenie (>=2.1.27,<3.0.0)
15
+ Requires-Dist: phylogenie (>=2.1.28,<3.0.0)
16
+ Requires-Dist: seaborn (>=0.13.2,<0.14.0)
@@ -1,5 +1,5 @@
1
1
  bella_companion/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- bella_companion/cli.py,sha256=0sPnzGyUGo2OBZ0rj17ZGzMdwNH0o-BXKsYtCJjzGvQ,968
2
+ bella_companion/cli.py,sha256=fjCIIsguBUnm_u1dyuo42INGTDa5lh8laiW1yZN65So,1178
3
3
  bella_companion/fbd_empirical/data/body_mass.csv,sha256=-UkKNtm9m3g4PjY3BcfdP6z5nL_I6p9cq6cgZ-bWKI8,30360
4
4
  bella_companion/fbd_empirical/data/change_times.csv,sha256=zmc9_z91-XMwKyIoP9v9dVlLcf4MeIHkQiHLjoMriOo,120
5
5
  bella_companion/fbd_empirical/data/sampling_change_times.csv,sha256=Gwi9RcMFy89RyvfxKVZ_MoKVRHOZLuwB_3LEaq8asMQ,32
@@ -9,29 +9,34 @@ bella_companion/fbd_empirical/notbooks.ipynb,sha256=O45kmz0lZENRDFbKXEWPsIKATfF5
9
9
  bella_companion/fbd_empirical/params.json,sha256=hU23LniClZL_GSBAxIEJUJgMa93AM8zdtFOq6mt3vkI,311
10
10
  bella_companion/fbd_empirical/run_beast.py,sha256=2sV2UmxOfWmbueiU6D0p3lueMYiZyIkSKYoblTMrYuA,1935
11
11
  bella_companion/fbd_empirical/summarize_logs.py,sha256=O6rhE606Wa98a8b1KKlLPjUOro1pfyqVTLdQksQMG0g,1439
12
- bella_companion/simulations/__init__.py,sha256=i6Fe7l5sUJY9hPxdg6L_FVhwbSPhNxQNMb-m33JlfxI,258
12
+ bella_companion/simulations/__init__.py,sha256=EBZAcI8skNPKjrA7CjrqH9ea7DTntmydAD0RqsxNUMM,352
13
13
  bella_companion/simulations/features.py,sha256=DZOBpJGlQ0UinqUZYbEtoemZ2eQGVLV_i-DfpW31qJI,104
14
- bella_companion/simulations/figures/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
- bella_companion/simulations/figures/epi_explainations.py,sha256=omiJgyIY-I6zcJAcyOF7GJ2pba6pMZySLkWy7OFrjFY,3093
16
- bella_companion/simulations/figures/epi_predictions.py,sha256=4yXwOBKxUv4kgZdI9zAMEhZ0QCNKZdkAafRQ1RTeaWg,1835
17
- bella_companion/simulations/figures/fbd_explainations.py,sha256=9Uj7yttpn_TH5HqycW8R-Nlky9A9aFXDXRpXQuT1L4s,3037
18
- bella_companion/simulations/figures/fbd_predictions.py,sha256=jdXYCLledZEWoPCIuTLhHEPMdeG6YXvf5xZnEOslv-U,2119
19
- bella_companion/simulations/figures/scenarios.py,sha256=vyybn3Qhfq96N8tvW0wSzpFoHHP8EIc8dkOz63o_Atw,2492
20
- bella_companion/simulations/figures/utils.py,sha256=sY8wFBg02fv5ugpJ80EqQishD_HEdLwhqsw2LfM7wEo,8539
14
+ bella_companion/simulations/figures/__init__.py,sha256=aBYbJntH4egFmkSSWiVMYDEApXPYxJD7eA3TCPNNegM,658
15
+ bella_companion/simulations/figures/epi_multitype_results.py,sha256=j85WgvN5AyAtX-CalMegr2lwlAZBmzyJxkikBPXRjCc,2629
16
+ bella_companion/simulations/figures/epi_skyline_results.py,sha256=Ej1iGHLnLkUohx7oC9gRTLe8T--5WS-JtyViLbxraLA,1647
17
+ bella_companion/simulations/figures/explain/__init__.py,sha256=DnmVIWO65nTT5VsWnbS7NyYgKEY_eo4oMCtCY_ML2Vk,260
18
+ bella_companion/simulations/figures/explain/pdp.py,sha256=bSrH-TZcT2mVZIBXTUeoRvShsX2bIUyhUvgmbzXLIbg,3925
19
+ bella_companion/simulations/figures/explain/shap.py,sha256=tjYX5H3C-IiygG9scXhLgtlCRxEBlsQ2OW5gVWL15kY,1936
20
+ bella_companion/simulations/figures/fbd_2traits_results.py,sha256=YP_ksmh5Sj2WnoxqRuSLt-gd-W7k2lUHs4idAbvbdhw,2708
21
+ bella_companion/simulations/figures/fbd_no_traits_results.py,sha256=O6hx_OZVSHmW0xq9T4q4oz6eskC_xcFZGMEdX8ZilrU,1942
22
+ bella_companion/simulations/figures/scenarios.py,sha256=OKh9_-ZvzNgWsO3-Vd0Aw3ndjVf76i_OuCvsKI-5r2s,2795
21
23
  bella_companion/simulations/generate_data.py,sha256=H8OV4ZlTGZB-jXaROTPmOsK3UxRiU-GrX40l-shliw8,728
24
+ bella_companion/simulations/generate_figures.py,sha256=layMgoj3Bfl78Ceb1oE7YirAQ8zhjDyD9IrxDRXf6go,657
22
25
  bella_companion/simulations/run_beast.py,sha256=xOuwE0w4IbOqqCSym6kHsAEhfGT2mWdA-jmUZuviMbc,3121
23
26
  bella_companion/simulations/scenarios/__init__.py,sha256=3Kl1lKcFpfb3vLX64DmSW4XCF5kXU1ZoHtstFH-ZIzU,876
24
27
  bella_companion/simulations/scenarios/common.py,sha256=_ddaSuTvEVdttGkXB4HPc2B7IB1F_GBOCW3cVOPZ-ZM,807
25
28
  bella_companion/simulations/scenarios/epi_multitype.py,sha256=GWGIiqvYwX_FrT_3RXkZKYGDht9nZ7ceHRBKUvXDPnA,2432
26
29
  bella_companion/simulations/scenarios/epi_skyline.py,sha256=JqnOVATECxBUqEbkR5lBlMI2O8k4hO6ipR8k9cHUsm0,2365
27
- bella_companion/simulations/scenarios/fbd_2traits.py,sha256=sCtdWyV6GQQOIhnL9Dd8NIbAR-StTwUTD9-b_BalmFQ,3552
30
+ bella_companion/simulations/scenarios/fbd_2traits.py,sha256=G24WCAHrWPwvQeElsy4UMl1I9ALFnVQp6wXuc25Ie-g,3552
28
31
  bella_companion/simulations/scenarios/fbd_no_traits.py,sha256=R6CH0fVeQg-Iesl39pq2uY8ICVEO4VZbvUVUCGwauJU,2520
29
32
  bella_companion/simulations/scenarios/scenario.py,sha256=_FRWAyOFbw94lAzd3zCD-1ek4TrssoiXfXRQPShLiIA,620
30
33
  bella_companion/simulations/summarize_logs.py,sha256=5IdzR9IwjeF2LgZmzpuK0rPfYMct2OUgEp0QyUbUS7g,1263
31
- bella_companion/utils/__init__.py,sha256=_5tLPH_3GHtimNcH0Yd9Z6yIM3WkWkNApNGLzFnF6nY,222
32
- bella_companion/utils/beast.py,sha256=RG-iSEFuL92K6yxUV2nxdmcVqfrEiPhaYTmReW4ZoWk,2189
34
+ bella_companion/utils/__init__.py,sha256=bb0-pCjQNwGaUgG2v9htgOLlOMiB1-kdZnSQ_QY2QAo,647
35
+ bella_companion/utils/beast.py,sha256=5Vsv98VTE9HrY56WzUSMECjD_rIPxHTMRMD1ZzmA6wY,2181
36
+ bella_companion/utils/explain.py,sha256=uP7HPyn2YiykAI69BQV3RooDpC6qKoCLXfp3Uibp4zk,1475
37
+ bella_companion/utils/plots.py,sha256=dB_GiJ1HGrZ93cqODz6kB-HeDRPwlm2MkMe9rJZGnfs,3117
33
38
  bella_companion/utils/slurm.py,sha256=v5DaG7YHVyK8KRFptgGDC6I8jxEhyJuMVK9N08pZSAI,1812
34
- bella_companion-0.0.3.dist-info/METADATA,sha256=ARm9evUES-6JEG46mregbCHmiNOMW59VGT_msd2V6Bg,534
35
- bella_companion-0.0.3.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
36
- bella_companion-0.0.3.dist-info/entry_points.txt,sha256=rSeKoAhmjnQqAYFcXBv0gAM2ViJfJe0D8_dD-fWrXeg,50
37
- bella_companion-0.0.3.dist-info/RECORD,,
39
+ bella_companion-0.0.5.dist-info/METADATA,sha256=IIRIUzr8WcpcLOApct-WoFLrEBkzAxhjXPIY6HucbOk,576
40
+ bella_companion-0.0.5.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
41
+ bella_companion-0.0.5.dist-info/entry_points.txt,sha256=rSeKoAhmjnQqAYFcXBv0gAM2ViJfJe0D8_dD-fWrXeg,50
42
+ bella_companion-0.0.5.dist-info/RECORD,,
@@ -1,109 +0,0 @@
1
- import os
2
- from functools import partial
3
-
4
- import joblib
5
- import matplotlib.pyplot as plt
6
- import numpy as np
7
- import polars as pl
8
- import src.config as cfg
9
- from lumiere.backend import sigmoid
10
- from src.simulations.figures.utils import (
11
- plot_partial_dependencies,
12
- plot_shap_features_importance,
13
- )
14
- from src.simulations.scenarios.epi_multitype import (
15
- MIGRATION_PREDICTOR,
16
- MIGRATION_RATE_UPPER,
17
- MIGRATION_RATES,
18
- SCENARIO,
19
- )
20
- from src.utils import set_plt_rcparams
21
-
22
-
23
def set_plt_rcparams():
    """Apply the Matplotlib rcParams used by the figures in this script."""
    # NOTE(review): this definition shadows the ``set_plt_rcparams`` name that
    # the same file imports from ``src.utils`` — confirm which one is intended.
    plt.rcParams.update(
        {
            "pdf.fonttype": 42,
            "xtick.labelsize": 14,
            "ytick.labelsize": 14,
            "font.size": 14,
            "figure.constrained_layout.use": True,
            "lines.linewidth": 3,
        }
    )
30
-
31
-
32
def _plot_predictions(log_summary: pl.DataFrame, output_dir: str):
    """Plot estimated vs. true migration rates against the migration predictor.

    For every target in ``SCENARIO.targets["migrationRate"]`` the median of the
    ``{target}_median``, ``{target}_lower`` and ``{target}_upper`` columns is
    taken. Points are ordered by the flattened ``MIGRATION_PREDICTOR`` and
    drawn with error bars, with the flattened ``MIGRATION_RATES`` as a dashed
    black reference. The figure is saved as ``predictions.svg`` in
    ``output_dir`` and closed.
    """

    def _column_medians(suffix: str) -> np.ndarray:
        # Median of the ``{target}_{suffix}`` column for each migration-rate target.
        return np.array(
            [
                log_summary[f"{target}_{suffix}"].median()
                for target in SCENARIO.targets["migrationRate"]
            ]
        )

    predictor = MIGRATION_PREDICTOR.flatten()
    sort_idx = np.argsort(predictor)
    sorted_predictor = predictor[sort_idx]

    sorted_estimates = _column_medians("median")[sort_idx]
    sorted_lower = _column_medians("lower")[sort_idx]
    sorted_upper = _column_medians("upper")[sort_idx]

    plt.errorbar(
        sorted_predictor,
        sorted_estimates,
        yerr=[sorted_estimates - sorted_lower, sorted_upper - sorted_estimates],
        marker="o",
        color="C2",
    )
    plt.plot(sorted_predictor, sorted_estimates, marker="o", color="C2")
    plt.plot(
        sorted_predictor,
        MIGRATION_RATES.flatten()[sort_idx],
        linestyle="dashed",
        marker="o",
        color="k",
    )
    plt.xlabel("Migration predictor")
    plt.ylabel("Migration rate")
    plt.savefig(os.path.join(output_dir, "predictions.svg"))
    plt.close()
80
-
81
-
82
def main():
    """Generate the prediction, partial-dependence and SHAP figures for epi-multitype.

    Reads the ``MLP-32_16`` summary CSV and weights pickle from the
    ``epi-multitype`` directory under ``cfg.BEAST_LOGS_SUMMARIES_DIR`` and
    writes the figures to ``epi-explainations`` under ``cfg.FIGURES_DIR``.
    """
    output_dir = os.path.join(cfg.FIGURES_DIR, "epi-explainations")
    os.makedirs(output_dir, exist_ok=True)

    model = "MLP-32_16"
    log_dir = os.path.join(cfg.BEAST_LOGS_SUMMARIES_DIR, "epi-multitype")
    summary_path = os.path.join(log_dir, f"{model}.csv")
    weights_path = os.path.join(log_dir, f"{model}_weights.pkl")
    log_summary = pl.read_csv(summary_path)
    weights = joblib.load(weights_path)

    set_plt_rcparams()

    _plot_predictions(log_summary, output_dir)

    # Arguments shared by both explanation plots.
    explain_kwargs = dict(
        weights=weights["migrationRate"],
        features=SCENARIO.features["migrationRate"],
        output_activation=partial(sigmoid, upper=MIGRATION_RATE_UPPER),
    )
    plot_partial_dependencies(output_dir=output_dir, **explain_kwargs)
    plot_shap_features_importance(
        output_file=os.path.join(output_dir, "shap_values.svg"),
        **explain_kwargs,
    )


if __name__ == "__main__":
    main()
@@ -1,58 +0,0 @@
1
- import os
2
-
3
- import matplotlib.pyplot as plt
4
- import polars as pl
5
-
6
- import src.config as cfg
7
- from src.simulations.figures.utils import (
8
- plot_coverage_per_time_bin,
9
- plot_maes_per_time_bin,
10
- step,
11
- )
12
- from src.simulations.scenarios.epi_skyline import REPRODUCTION_NUMBERS
13
- from src.utils import set_plt_rcparams
14
-
15
-
16
def main():
    """Generate prediction, coverage and MAE figures for each epi-skyline scenario.

    Scenarios are numbered from 1 in the order of ``REPRODUCTION_NUMBERS``.
    For scenario ``k`` the summaries are read from ``epi-skyline_{k}`` under
    ``cfg.BEAST_LOGS_SUMMARIES_DIR`` and the figures are written to
    ``epi-predictions`` under ``cfg.FIGURES_DIR``.
    """
    output_dir = os.path.join(cfg.FIGURES_DIR, "epi-predictions")
    os.makedirs(output_dir, exist_ok=True)

    set_plt_rcparams()

    # Display name -> summary file name, in plotting order.
    summary_files = {
        "Nonparametric": "Nonparametric.csv",
        "GLM": "GLM.csv",
        "MLP": "MLP-16_8.csv",
    }

    for i, reproduction_number in enumerate(REPRODUCTION_NUMBERS, start=1):
        summaries_dir = os.path.join(cfg.BEAST_LOGS_SUMMARIES_DIR, f"epi-skyline_{i}")
        logs_summaries = {
            name: pl.read_csv(os.path.join(summaries_dir, filename))
            for name, filename in summary_files.items()
        }
        true_values = {"reproductionNumber": reproduction_number}

        n_bins = len(reproduction_number)
        for log_summary in logs_summaries.values():
            medians = [
                log_summary[f"reproductionNumberi{bin_idx}_median"].median()
                for bin_idx in range(n_bins)
            ]
            step(medians)
        step(reproduction_number, color="k", linestyle="--")
        plt.ylabel("Reproduction number")
        plt.savefig(os.path.join(output_dir, f"epi-skyline_{i}-predictions.svg"))
        plt.close()

        coverage_path = os.path.join(output_dir, f"epi-skyline_{i}-coverage.svg")
        plot_coverage_per_time_bin(logs_summaries, true_values, coverage_path)
        maes_path = os.path.join(output_dir, f"epi-skyline_{i}-maes.svg")
        plot_maes_per_time_bin(logs_summaries, true_values, maes_path)


if __name__ == "__main__":
    main()
@@ -1,99 +0,0 @@
1
- import ast
2
- import os
3
- from functools import partial
4
-
5
- import joblib
6
- import matplotlib.pyplot as plt
7
- import polars as pl
8
- from joblib import Parallel, delayed
9
- from lumiere.backend import sigmoid
10
-
11
- import src.config as cfg
12
- from src.figures.utils import (
13
- plot_partial_dependencies,
14
- plot_shap_features_importance,
15
- step,
16
- )
17
- from src.simulations.scenarios.fbd_2traits import (
18
- BIRTH_RATE_TRAIT1_SET,
19
- BIRTH_RATE_TRAIT1_UNSET,
20
- DEATH_RATE_TRAIT1_SET,
21
- DEATH_RATE_TRAIT1_UNSET,
22
- FBD_RATE_UPPER,
23
- N_TIME_BINS,
24
- SCENARIO,
25
- STATES,
26
- )
27
- from src.utils import set_plt_rcparams
28
-
29
-
30
def _plot_predictions(log_summary: pl.DataFrame, output_dir: str):
    """Plot estimated birth and death rates per state against two reference curves.

    For each of "birth" and "death", one step curve per entry of ``STATES`` is
    drawn from the median of ``{rate}Ratei{i}_{state}_median`` over
    ``N_TIME_BINS`` bins, followed by two dashed reference curves (black and
    gray). Each figure is saved as ``predictions.svg`` in the ``{rate}``
    sub-directory of ``output_dir`` and closed.
    """
    for rate in ["birth", "death"]:
        is_birth = rate == "birth"
        label = r"\lambda" if is_birth else r"\mu"

        # NOTE(review): the black curve is labelled with the (0,·) states but
        # uses the *_TRAIT1_SET constants, and the gray (1,·) curve uses the
        # *_TRAIT1_UNSET ones. This pairing is reproduced as found — confirm
        # against the scenario definitions whether it is intentional.
        black_reference = BIRTH_RATE_TRAIT1_SET if is_birth else DEATH_RATE_TRAIT1_SET
        gray_reference = (
            BIRTH_RATE_TRAIT1_UNSET if is_birth else DEATH_RATE_TRAIT1_UNSET
        )

        for state in STATES:
            estimates = []
            for i in range(N_TIME_BINS):
                estimates.append(log_summary[f"{rate}Ratei{i}_{state}_median"].median())
            step(
                estimates,
                label=rf"${label}_{{{state[0]},{state[1]}}}$",
                reverse_xticks=True,
            )

        references = (
            (black_reference, "k", rf"${label}_{{0,0}}$ = ${label}_{{0,1}}$"),
            (gray_reference, "gray", rf"${label}_{{1,0}}$ = ${label}_{{1,1}}$"),
        )
        for levels, color, legend_text in references:
            step(
                levels,
                color=color,
                linestyle="dashed",
                label=legend_text,
                reverse_xticks=True,
            )

        plt.legend()
        plt.ylabel(rf"${label}$")
        plt.savefig(os.path.join(output_dir, rate, "predictions.svg"))
        plt.close()
67
-
68
-
69
def main():
    """Generate the prediction, partial-dependence and SHAP figures for fbd-2traits.

    Reads the ``MLP-32_16`` summary CSV and weights pickle from the
    ``fbd-2traits`` directory under ``cfg.BEAST_LOGS_SUMMARIES_DIR`` and writes
    the figures to ``birth`` and ``death`` sub-directories of
    ``fbd-explainations`` under ``cfg.FIGURES_DIR``.
    """
    rates = ["birth", "death"]
    output_dir = os.path.join(cfg.FIGURES_DIR, "fbd-explainations")
    rate_dirs = {rate: os.path.join(output_dir, rate) for rate in rates}
    for rate_dir in rate_dirs.values():
        os.makedirs(rate_dir, exist_ok=True)

    model = "MLP-32_16"
    log_dir = os.path.join(cfg.BEAST_LOGS_SUMMARIES_DIR, "fbd-2traits")
    log_summary = pl.read_csv(os.path.join(log_dir, f"{model}.csv"))
    weights = joblib.load(os.path.join(log_dir, f"{model}_weights.pkl"))

    set_plt_rcparams()

    _plot_predictions(log_summary, output_dir)

    for rate, rate_dir in rate_dirs.items():
        target = f"{rate}Rate"
        # Arguments shared by both explanation plots for this rate.
        explain_kwargs = dict(
            weights=weights[target],
            features=SCENARIO.features[target],
            output_activation=partial(sigmoid, upper=FBD_RATE_UPPER),
        )
        plot_partial_dependencies(output_dir=rate_dir, **explain_kwargs)
        plot_shap_features_importance(
            output_file=os.path.join(output_dir, rate, "shap_values.svg"),
            **explain_kwargs,
        )


if __name__ == "__main__":
    main()
@@ -1,66 +0,0 @@
1
- import os
2
-
3
- import matplotlib.pyplot as plt
4
- import polars as pl
5
-
6
- import src.config as cfg
7
- from src.simulations.figures.utils import (
8
- plot_coverage_per_time_bin,
9
- plot_maes_per_time_bin,
10
- step,
11
- )
12
- from src.simulations.scenarios.fbd_no_traits import BIRTH_RATES, DEATH_RATES
13
- from src.utils import set_plt_rcparams
14
-
15
-
16
def main():
    """Generate prediction, coverage and MAE figures for each fbd-no-traits scenario.

    Scenarios are numbered from 1 over the paired ``BIRTH_RATES`` and
    ``DEATH_RATES``. For scenario ``k`` the summaries are read from
    ``fbd-no-traits_{k}`` under ``cfg.BEAST_LOGS_SUMMARIES_DIR`` and the
    figures are written to ``fbd-predictions`` under ``cfg.FIGURES_DIR``.
    """
    output_dir = os.path.join(cfg.FIGURES_DIR, "fbd-predictions")
    os.makedirs(output_dir, exist_ok=True)

    set_plt_rcparams()

    # Display name -> summary file name, in plotting order.
    summary_files = {
        "Nonparametric": "Nonparametric.csv",
        "GLM": "GLM.csv",
        "MLP": "MLP-16_8.csv",
    }

    for i, (birth_rate, death_rate) in enumerate(
        zip(BIRTH_RATES, DEATH_RATES), start=1
    ):
        summaries_dir = os.path.join(cfg.BEAST_LOGS_SUMMARIES_DIR, f"fbd-no-traits_{i}")
        logs_summaries = {
            name: pl.read_csv(os.path.join(summaries_dir, filename))
            for name, filename in summary_files.items()
        }
        true_values = {"birthRate": birth_rate, "deathRate": death_rate}

        for target, rate in true_values.items():
            n_bins = len(rate)
            for log_summary in logs_summaries.values():
                medians = [
                    log_summary[f"{target}i{bin_idx}_median"].median()
                    for bin_idx in range(n_bins)
                ]
                step(medians, reverse_xticks=True)
            step(rate, color="k", linestyle="--", reverse_xticks=True)
            plt.ylabel(r"$\lambda$" if target == "birthRate" else r"$\mu$")
            plt.savefig(
                os.path.join(output_dir, f"fbd-no-traits_{i}-predictions-{target}.svg")
            )
            plt.close()

        plot_coverage_per_time_bin(
            logs_summaries,
            true_values,
            os.path.join(output_dir, f"fbd-no-traits_{i}-coverage.svg"),
            reverse_xticks=True,
        )
        plot_maes_per_time_bin(
            logs_summaries,
            true_values,
            os.path.join(output_dir, f"fbd-no-traits_{i}-maes.svg"),
            reverse_xticks=True,
        )


if __name__ == "__main__":
    main()