bella-companion 0.0.3__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bella-companion might be problematic. Click here for more details.
- bella_companion/cli.py +13 -4
- bella_companion/simulations/__init__.py +2 -1
- bella_companion/simulations/figures/__init__.py +21 -0
- bella_companion/simulations/figures/epi_multitype_results.py +81 -0
- bella_companion/simulations/figures/epi_skyline_results.py +46 -0
- bella_companion/simulations/figures/explain/__init__.py +6 -0
- bella_companion/simulations/figures/explain/pdp.py +101 -0
- bella_companion/simulations/figures/explain/shap.py +56 -0
- bella_companion/simulations/figures/fbd_2traits_results.py +83 -0
- bella_companion/simulations/figures/fbd_no_traits_results.py +58 -0
- bella_companion/simulations/figures/scenarios.py +38 -32
- bella_companion/simulations/generate_figures.py +24 -0
- bella_companion/simulations/scenarios/fbd_2traits.py +2 -2
- bella_companion/utils/__init__.py +20 -1
- bella_companion/utils/beast.py +1 -1
- bella_companion/utils/explain.py +45 -0
- bella_companion/utils/plots.py +98 -0
- {bella_companion-0.0.3.dist-info → bella_companion-0.0.5.dist-info}/METADATA +4 -3
- {bella_companion-0.0.3.dist-info → bella_companion-0.0.5.dist-info}/RECORD +21 -16
- bella_companion/simulations/figures/epi_explainations.py +0 -109
- bella_companion/simulations/figures/epi_predictions.py +0 -58
- bella_companion/simulations/figures/fbd_explainations.py +0 -99
- bella_companion/simulations/figures/fbd_predictions.py +0 -66
- bella_companion/simulations/figures/utils.py +0 -250
- {bella_companion-0.0.3.dist-info → bella_companion-0.0.5.dist-info}/WHEEL +0 -0
- {bella_companion-0.0.3.dist-info → bella_companion-0.0.5.dist-info}/entry_points.txt +0 -0
bella_companion/utils/beast.py
CHANGED
|
@@ -9,7 +9,7 @@ import numpy as np
|
|
|
9
9
|
import polars as pl
|
|
10
10
|
from joblib import Parallel, delayed
|
|
11
11
|
from lumiere import read_log_file, read_weights
|
|
12
|
-
from lumiere.
|
|
12
|
+
from lumiere.typing import Weights
|
|
13
13
|
from tqdm import tqdm
|
|
14
14
|
|
|
15
15
|
from bella_companion.utils.slurm import get_job_metadata
|
|
@@ -0,0 +1,45 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from lumiere import get_partial_dependence_values, get_shap_features_importance
|
|
3
|
+
from lumiere.typing import ActivationFunction, Weights
|
|
4
|
+
from numpy.typing import ArrayLike
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def get_median_partial_dependence_values(
|
|
8
|
+
weights: list[Weights], # shape: (n_weight_samples, ...)
|
|
9
|
+
features_grid: list[list[float]],
|
|
10
|
+
hidden_activation: ActivationFunction,
|
|
11
|
+
output_activation: ActivationFunction,
|
|
12
|
+
) -> list[list[float]]: # shape: (n_features, n_grid_points)
|
|
13
|
+
pdvalues = [
|
|
14
|
+
get_partial_dependence_values(
|
|
15
|
+
weights=w,
|
|
16
|
+
features_grid=features_grid,
|
|
17
|
+
hidden_activation=hidden_activation,
|
|
18
|
+
output_activation=output_activation,
|
|
19
|
+
)
|
|
20
|
+
for w in weights
|
|
21
|
+
]
|
|
22
|
+
return [
|
|
23
|
+
np.median([pd[feature_idx] for pd in pdvalues], axis=0).tolist()
|
|
24
|
+
for feature_idx in range(len(features_grid))
|
|
25
|
+
]
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def get_median_shap_features_importance(
|
|
29
|
+
weights: list[Weights],
|
|
30
|
+
inputs: ArrayLike,
|
|
31
|
+
hidden_activation: ActivationFunction,
|
|
32
|
+
output_activation: ActivationFunction,
|
|
33
|
+
) -> list[float]: # length: n_features
|
|
34
|
+
features_importance = np.array(
|
|
35
|
+
[
|
|
36
|
+
get_shap_features_importance(
|
|
37
|
+
weights=w,
|
|
38
|
+
inputs=inputs,
|
|
39
|
+
hidden_activation=hidden_activation,
|
|
40
|
+
output_activation=output_activation,
|
|
41
|
+
)
|
|
42
|
+
for w in weights
|
|
43
|
+
]
|
|
44
|
+
)
|
|
45
|
+
return np.median(features_importance, axis=0).tolist()
|
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
import matplotlib.pyplot as plt
|
|
4
|
+
import numpy as np
|
|
5
|
+
import polars as pl
|
|
6
|
+
import seaborn as sns
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def _set_time_bin_xticks(n: int, reverse: bool = False):
|
|
10
|
+
xticks_labels = range(n)
|
|
11
|
+
if reverse:
|
|
12
|
+
xticks_labels = reversed(xticks_labels)
|
|
13
|
+
plt.xticks(ticks=range(n), labels=list(map(str, xticks_labels))) # pyright: ignore
|
|
14
|
+
plt.xlabel("Time bin") # pyright: ignore
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def step(
|
|
18
|
+
x: list[float],
|
|
19
|
+
reverse_xticks: bool = False,
|
|
20
|
+
label: str | None = None,
|
|
21
|
+
color: str | None = None,
|
|
22
|
+
linestyle: str | None = None,
|
|
23
|
+
):
|
|
24
|
+
x = [x[0], *x]
|
|
25
|
+
n = len(x)
|
|
26
|
+
plt.step( # pyright: ignore
|
|
27
|
+
list(range(n)), x, label=label, color=color, linestyle=linestyle
|
|
28
|
+
)
|
|
29
|
+
_set_time_bin_xticks(n, reverse_xticks)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _count_time_bins(true_values: dict[str, list[float]]) -> int:
|
|
33
|
+
assert (
|
|
34
|
+
len({len(true_value) for true_value in true_values.values()}) == 1
|
|
35
|
+
), "All targets must have the same number of change times."
|
|
36
|
+
return len(next(iter((true_values.values()))))
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def plot_maes_per_time_bin(
|
|
40
|
+
logs_summaries: dict[str, pl.DataFrame],
|
|
41
|
+
true_values: dict[str, list[float]],
|
|
42
|
+
output_filepath: Path,
|
|
43
|
+
reverse_xticks: bool = False,
|
|
44
|
+
):
|
|
45
|
+
def _mae(target: str, i: int) -> pl.Expr:
|
|
46
|
+
return (pl.col(f"{target}i{i}_median") - true_values[target][i]).abs()
|
|
47
|
+
|
|
48
|
+
n_time_bins = _count_time_bins(true_values)
|
|
49
|
+
df = pl.concat(
|
|
50
|
+
logs_summaries[model]
|
|
51
|
+
.select(
|
|
52
|
+
pl.mean_horizontal([_mae(target, i) for target in true_values]).alias("MAE")
|
|
53
|
+
)
|
|
54
|
+
.with_columns(pl.lit(i).alias("Time bin"), pl.lit(model).alias("Model"))
|
|
55
|
+
for i in range(n_time_bins)
|
|
56
|
+
for model in logs_summaries
|
|
57
|
+
)
|
|
58
|
+
sns.violinplot(
|
|
59
|
+
x="Time bin",
|
|
60
|
+
y="MAE",
|
|
61
|
+
hue="Model",
|
|
62
|
+
data=df,
|
|
63
|
+
inner=None,
|
|
64
|
+
cut=0,
|
|
65
|
+
density_norm="width",
|
|
66
|
+
legend=False,
|
|
67
|
+
)
|
|
68
|
+
_set_time_bin_xticks(n_time_bins, reverse_xticks)
|
|
69
|
+
plt.savefig(output_filepath) # pyright: ignore
|
|
70
|
+
plt.close()
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def plot_coverage_per_time_bin(
|
|
74
|
+
logs_summaries: dict[str, pl.DataFrame],
|
|
75
|
+
true_values: dict[str, list[float]],
|
|
76
|
+
output_filepath: Path,
|
|
77
|
+
reverse_xticks: bool = False,
|
|
78
|
+
):
|
|
79
|
+
def _coverage(model: str, target: str, i: int) -> float:
|
|
80
|
+
lower_bound = logs_summaries[model][f"{target}i{i}_lower"]
|
|
81
|
+
upper_bound = logs_summaries[model][f"{target}i{i}_upper"]
|
|
82
|
+
true_value = true_values[target][i]
|
|
83
|
+
N = len(logs_summaries[model])
|
|
84
|
+
return ((lower_bound <= true_value) & (true_value <= upper_bound)).sum() / N
|
|
85
|
+
|
|
86
|
+
n_time_bins = _count_time_bins(true_values)
|
|
87
|
+
for model in logs_summaries:
|
|
88
|
+
avg_coverage_by_time_bin = [
|
|
89
|
+
np.mean([_coverage(model, target, i) for target in true_values])
|
|
90
|
+
for i in range(_count_time_bins(true_values))
|
|
91
|
+
]
|
|
92
|
+
plt.plot(avg_coverage_by_time_bin, marker="o") # pyright: ignore
|
|
93
|
+
|
|
94
|
+
_set_time_bin_xticks(n_time_bins, reverse_xticks)
|
|
95
|
+
plt.ylabel("Coverage") # pyright: ignore
|
|
96
|
+
plt.ylim((0, 1.05)) # pyright: ignore
|
|
97
|
+
plt.savefig(output_filepath) # pyright: ignore
|
|
98
|
+
plt.close()
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: bella-companion
|
|
3
|
-
Version: 0.0.
|
|
3
|
+
Version: 0.0.5
|
|
4
4
|
Summary:
|
|
5
5
|
Author: gabriele-marino
|
|
6
6
|
Author-email: gabmarino.8601@gmail.com
|
|
@@ -10,6 +10,7 @@ Classifier: Programming Language :: Python :: 3.10
|
|
|
10
10
|
Classifier: Programming Language :: Python :: 3.11
|
|
11
11
|
Classifier: Programming Language :: Python :: 3.12
|
|
12
12
|
Requires-Dist: arviz (>=0.22.0,<0.23.0)
|
|
13
|
-
Requires-Dist: bella-lumiere (>=0.0.
|
|
13
|
+
Requires-Dist: bella-lumiere (>=0.0.13,<0.0.14)
|
|
14
14
|
Requires-Dist: dotenv (>=0.9.9,<0.10.0)
|
|
15
|
-
Requires-Dist: phylogenie (>=2.1.
|
|
15
|
+
Requires-Dist: phylogenie (>=2.1.28,<3.0.0)
|
|
16
|
+
Requires-Dist: seaborn (>=0.13.2,<0.14.0)
|
|
@@ -1,5 +1,5 @@
|
|
|
1
1
|
bella_companion/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
-
bella_companion/cli.py,sha256=
|
|
2
|
+
bella_companion/cli.py,sha256=fjCIIsguBUnm_u1dyuo42INGTDa5lh8laiW1yZN65So,1178
|
|
3
3
|
bella_companion/fbd_empirical/data/body_mass.csv,sha256=-UkKNtm9m3g4PjY3BcfdP6z5nL_I6p9cq6cgZ-bWKI8,30360
|
|
4
4
|
bella_companion/fbd_empirical/data/change_times.csv,sha256=zmc9_z91-XMwKyIoP9v9dVlLcf4MeIHkQiHLjoMriOo,120
|
|
5
5
|
bella_companion/fbd_empirical/data/sampling_change_times.csv,sha256=Gwi9RcMFy89RyvfxKVZ_MoKVRHOZLuwB_3LEaq8asMQ,32
|
|
@@ -9,29 +9,34 @@ bella_companion/fbd_empirical/notbooks.ipynb,sha256=O45kmz0lZENRDFbKXEWPsIKATfF5
|
|
|
9
9
|
bella_companion/fbd_empirical/params.json,sha256=hU23LniClZL_GSBAxIEJUJgMa93AM8zdtFOq6mt3vkI,311
|
|
10
10
|
bella_companion/fbd_empirical/run_beast.py,sha256=2sV2UmxOfWmbueiU6D0p3lueMYiZyIkSKYoblTMrYuA,1935
|
|
11
11
|
bella_companion/fbd_empirical/summarize_logs.py,sha256=O6rhE606Wa98a8b1KKlLPjUOro1pfyqVTLdQksQMG0g,1439
|
|
12
|
-
bella_companion/simulations/__init__.py,sha256=
|
|
12
|
+
bella_companion/simulations/__init__.py,sha256=EBZAcI8skNPKjrA7CjrqH9ea7DTntmydAD0RqsxNUMM,352
|
|
13
13
|
bella_companion/simulations/features.py,sha256=DZOBpJGlQ0UinqUZYbEtoemZ2eQGVLV_i-DfpW31qJI,104
|
|
14
|
-
bella_companion/simulations/figures/__init__.py,sha256=
|
|
15
|
-
bella_companion/simulations/figures/
|
|
16
|
-
bella_companion/simulations/figures/
|
|
17
|
-
bella_companion/simulations/figures/
|
|
18
|
-
bella_companion/simulations/figures/
|
|
19
|
-
bella_companion/simulations/figures/
|
|
20
|
-
bella_companion/simulations/figures/
|
|
14
|
+
bella_companion/simulations/figures/__init__.py,sha256=aBYbJntH4egFmkSSWiVMYDEApXPYxJD7eA3TCPNNegM,658
|
|
15
|
+
bella_companion/simulations/figures/epi_multitype_results.py,sha256=j85WgvN5AyAtX-CalMegr2lwlAZBmzyJxkikBPXRjCc,2629
|
|
16
|
+
bella_companion/simulations/figures/epi_skyline_results.py,sha256=Ej1iGHLnLkUohx7oC9gRTLe8T--5WS-JtyViLbxraLA,1647
|
|
17
|
+
bella_companion/simulations/figures/explain/__init__.py,sha256=DnmVIWO65nTT5VsWnbS7NyYgKEY_eo4oMCtCY_ML2Vk,260
|
|
18
|
+
bella_companion/simulations/figures/explain/pdp.py,sha256=bSrH-TZcT2mVZIBXTUeoRvShsX2bIUyhUvgmbzXLIbg,3925
|
|
19
|
+
bella_companion/simulations/figures/explain/shap.py,sha256=tjYX5H3C-IiygG9scXhLgtlCRxEBlsQ2OW5gVWL15kY,1936
|
|
20
|
+
bella_companion/simulations/figures/fbd_2traits_results.py,sha256=YP_ksmh5Sj2WnoxqRuSLt-gd-W7k2lUHs4idAbvbdhw,2708
|
|
21
|
+
bella_companion/simulations/figures/fbd_no_traits_results.py,sha256=O6hx_OZVSHmW0xq9T4q4oz6eskC_xcFZGMEdX8ZilrU,1942
|
|
22
|
+
bella_companion/simulations/figures/scenarios.py,sha256=OKh9_-ZvzNgWsO3-Vd0Aw3ndjVf76i_OuCvsKI-5r2s,2795
|
|
21
23
|
bella_companion/simulations/generate_data.py,sha256=H8OV4ZlTGZB-jXaROTPmOsK3UxRiU-GrX40l-shliw8,728
|
|
24
|
+
bella_companion/simulations/generate_figures.py,sha256=layMgoj3Bfl78Ceb1oE7YirAQ8zhjDyD9IrxDRXf6go,657
|
|
22
25
|
bella_companion/simulations/run_beast.py,sha256=xOuwE0w4IbOqqCSym6kHsAEhfGT2mWdA-jmUZuviMbc,3121
|
|
23
26
|
bella_companion/simulations/scenarios/__init__.py,sha256=3Kl1lKcFpfb3vLX64DmSW4XCF5kXU1ZoHtstFH-ZIzU,876
|
|
24
27
|
bella_companion/simulations/scenarios/common.py,sha256=_ddaSuTvEVdttGkXB4HPc2B7IB1F_GBOCW3cVOPZ-ZM,807
|
|
25
28
|
bella_companion/simulations/scenarios/epi_multitype.py,sha256=GWGIiqvYwX_FrT_3RXkZKYGDht9nZ7ceHRBKUvXDPnA,2432
|
|
26
29
|
bella_companion/simulations/scenarios/epi_skyline.py,sha256=JqnOVATECxBUqEbkR5lBlMI2O8k4hO6ipR8k9cHUsm0,2365
|
|
27
|
-
bella_companion/simulations/scenarios/fbd_2traits.py,sha256=
|
|
30
|
+
bella_companion/simulations/scenarios/fbd_2traits.py,sha256=G24WCAHrWPwvQeElsy4UMl1I9ALFnVQp6wXuc25Ie-g,3552
|
|
28
31
|
bella_companion/simulations/scenarios/fbd_no_traits.py,sha256=R6CH0fVeQg-Iesl39pq2uY8ICVEO4VZbvUVUCGwauJU,2520
|
|
29
32
|
bella_companion/simulations/scenarios/scenario.py,sha256=_FRWAyOFbw94lAzd3zCD-1ek4TrssoiXfXRQPShLiIA,620
|
|
30
33
|
bella_companion/simulations/summarize_logs.py,sha256=5IdzR9IwjeF2LgZmzpuK0rPfYMct2OUgEp0QyUbUS7g,1263
|
|
31
|
-
bella_companion/utils/__init__.py,sha256=
|
|
32
|
-
bella_companion/utils/beast.py,sha256=
|
|
34
|
+
bella_companion/utils/__init__.py,sha256=bb0-pCjQNwGaUgG2v9htgOLlOMiB1-kdZnSQ_QY2QAo,647
|
|
35
|
+
bella_companion/utils/beast.py,sha256=5Vsv98VTE9HrY56WzUSMECjD_rIPxHTMRMD1ZzmA6wY,2181
|
|
36
|
+
bella_companion/utils/explain.py,sha256=uP7HPyn2YiykAI69BQV3RooDpC6qKoCLXfp3Uibp4zk,1475
|
|
37
|
+
bella_companion/utils/plots.py,sha256=dB_GiJ1HGrZ93cqODz6kB-HeDRPwlm2MkMe9rJZGnfs,3117
|
|
33
38
|
bella_companion/utils/slurm.py,sha256=v5DaG7YHVyK8KRFptgGDC6I8jxEhyJuMVK9N08pZSAI,1812
|
|
34
|
-
bella_companion-0.0.
|
|
35
|
-
bella_companion-0.0.
|
|
36
|
-
bella_companion-0.0.
|
|
37
|
-
bella_companion-0.0.
|
|
39
|
+
bella_companion-0.0.5.dist-info/METADATA,sha256=IIRIUzr8WcpcLOApct-WoFLrEBkzAxhjXPIY6HucbOk,576
|
|
40
|
+
bella_companion-0.0.5.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
|
41
|
+
bella_companion-0.0.5.dist-info/entry_points.txt,sha256=rSeKoAhmjnQqAYFcXBv0gAM2ViJfJe0D8_dD-fWrXeg,50
|
|
42
|
+
bella_companion-0.0.5.dist-info/RECORD,,
|
|
@@ -1,109 +0,0 @@
|
|
|
1
|
-
import os
|
|
2
|
-
from functools import partial
|
|
3
|
-
|
|
4
|
-
import joblib
|
|
5
|
-
import matplotlib.pyplot as plt
|
|
6
|
-
import numpy as np
|
|
7
|
-
import polars as pl
|
|
8
|
-
import src.config as cfg
|
|
9
|
-
from lumiere.backend import sigmoid
|
|
10
|
-
from src.simulations.figures.utils import (
|
|
11
|
-
plot_partial_dependencies,
|
|
12
|
-
plot_shap_features_importance,
|
|
13
|
-
)
|
|
14
|
-
from src.simulations.scenarios.epi_multitype import (
|
|
15
|
-
MIGRATION_PREDICTOR,
|
|
16
|
-
MIGRATION_RATE_UPPER,
|
|
17
|
-
MIGRATION_RATES,
|
|
18
|
-
SCENARIO,
|
|
19
|
-
)
|
|
20
|
-
from src.utils import set_plt_rcparams
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
def set_plt_rcparams():
|
|
24
|
-
plt.rcParams["pdf.fonttype"] = 42
|
|
25
|
-
plt.rcParams["xtick.labelsize"] = 14
|
|
26
|
-
plt.rcParams["ytick.labelsize"] = 14
|
|
27
|
-
plt.rcParams["font.size"] = 14
|
|
28
|
-
plt.rcParams["figure.constrained_layout.use"] = True
|
|
29
|
-
plt.rcParams["lines.linewidth"] = 3
|
|
30
|
-
|
|
31
|
-
|
|
32
|
-
def _plot_predictions(log_summary: pl.DataFrame, output_dir: str):
|
|
33
|
-
sort_idx = np.argsort(MIGRATION_PREDICTOR.flatten())
|
|
34
|
-
|
|
35
|
-
estimates = np.array(
|
|
36
|
-
[
|
|
37
|
-
log_summary[f"{target}_median"].median()
|
|
38
|
-
for target in SCENARIO.targets["migrationRate"]
|
|
39
|
-
]
|
|
40
|
-
)
|
|
41
|
-
lower = np.array(
|
|
42
|
-
[
|
|
43
|
-
log_summary[f"{target}_lower"].median()
|
|
44
|
-
for target in SCENARIO.targets["migrationRate"]
|
|
45
|
-
]
|
|
46
|
-
)
|
|
47
|
-
upper = np.array(
|
|
48
|
-
[
|
|
49
|
-
log_summary[f"{target}_upper"].median()
|
|
50
|
-
for target in SCENARIO.targets["migrationRate"]
|
|
51
|
-
]
|
|
52
|
-
)
|
|
53
|
-
plt.errorbar(
|
|
54
|
-
MIGRATION_PREDICTOR.flatten()[sort_idx],
|
|
55
|
-
estimates[sort_idx],
|
|
56
|
-
yerr=[
|
|
57
|
-
estimates[sort_idx] - lower[sort_idx],
|
|
58
|
-
upper[sort_idx] - estimates[sort_idx],
|
|
59
|
-
],
|
|
60
|
-
marker="o",
|
|
61
|
-
color="C2",
|
|
62
|
-
)
|
|
63
|
-
plt.plot(
|
|
64
|
-
MIGRATION_PREDICTOR.flatten()[sort_idx],
|
|
65
|
-
estimates[sort_idx],
|
|
66
|
-
marker="o",
|
|
67
|
-
color="C2",
|
|
68
|
-
)
|
|
69
|
-
plt.plot(
|
|
70
|
-
MIGRATION_PREDICTOR.flatten()[sort_idx],
|
|
71
|
-
MIGRATION_RATES.flatten()[sort_idx],
|
|
72
|
-
linestyle="dashed",
|
|
73
|
-
marker="o",
|
|
74
|
-
color="k",
|
|
75
|
-
)
|
|
76
|
-
plt.xlabel("Migration predictor")
|
|
77
|
-
plt.ylabel("Migration rate")
|
|
78
|
-
plt.savefig(os.path.join(output_dir, "predictions.svg"))
|
|
79
|
-
plt.close()
|
|
80
|
-
|
|
81
|
-
|
|
82
|
-
def main():
|
|
83
|
-
output_dir = os.path.join(cfg.FIGURES_DIR, "epi-explainations")
|
|
84
|
-
os.makedirs(output_dir, exist_ok=True)
|
|
85
|
-
|
|
86
|
-
log_dir = os.path.join(cfg.BEAST_LOGS_SUMMARIES_DIR, "epi-multitype")
|
|
87
|
-
model = "MLP-32_16"
|
|
88
|
-
log_summary = pl.read_csv(os.path.join(log_dir, f"{model}.csv"))
|
|
89
|
-
weights = joblib.load(os.path.join(log_dir, f"{model}_weights.pkl"))
|
|
90
|
-
|
|
91
|
-
set_plt_rcparams()
|
|
92
|
-
|
|
93
|
-
_plot_predictions(log_summary, output_dir)
|
|
94
|
-
plot_partial_dependencies(
|
|
95
|
-
weights=weights["migrationRate"],
|
|
96
|
-
features=SCENARIO.features["migrationRate"],
|
|
97
|
-
output_dir=output_dir,
|
|
98
|
-
output_activation=partial(sigmoid, upper=MIGRATION_RATE_UPPER),
|
|
99
|
-
)
|
|
100
|
-
plot_shap_features_importance(
|
|
101
|
-
weights=weights["migrationRate"],
|
|
102
|
-
features=SCENARIO.features["migrationRate"],
|
|
103
|
-
output_file=os.path.join(output_dir, "shap_values.svg"),
|
|
104
|
-
output_activation=partial(sigmoid, upper=MIGRATION_RATE_UPPER),
|
|
105
|
-
)
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
if __name__ == "__main__":
|
|
109
|
-
main()
|
|
@@ -1,58 +0,0 @@
|
|
|
1
|
-
import os
|
|
2
|
-
|
|
3
|
-
import matplotlib.pyplot as plt
|
|
4
|
-
import polars as pl
|
|
5
|
-
|
|
6
|
-
import src.config as cfg
|
|
7
|
-
from src.simulations.figures.utils import (
|
|
8
|
-
plot_coverage_per_time_bin,
|
|
9
|
-
plot_maes_per_time_bin,
|
|
10
|
-
step,
|
|
11
|
-
)
|
|
12
|
-
from src.simulations.scenarios.epi_skyline import REPRODUCTION_NUMBERS
|
|
13
|
-
from src.utils import set_plt_rcparams
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
def main():
|
|
17
|
-
output_dir = os.path.join(cfg.FIGURES_DIR, "epi-predictions")
|
|
18
|
-
os.makedirs(output_dir, exist_ok=True)
|
|
19
|
-
|
|
20
|
-
set_plt_rcparams()
|
|
21
|
-
|
|
22
|
-
for i, reproduction_number in enumerate(REPRODUCTION_NUMBERS, start=1):
|
|
23
|
-
summaries_dir = os.path.join(cfg.BEAST_LOGS_SUMMARIES_DIR, f"epi-skyline_{i}")
|
|
24
|
-
logs_summaries = {
|
|
25
|
-
"Nonparametric": pl.read_csv(
|
|
26
|
-
os.path.join(summaries_dir, "Nonparametric.csv")
|
|
27
|
-
),
|
|
28
|
-
"GLM": pl.read_csv(os.path.join(summaries_dir, "GLM.csv")),
|
|
29
|
-
"MLP": pl.read_csv(os.path.join(summaries_dir, "MLP-16_8.csv")),
|
|
30
|
-
}
|
|
31
|
-
true_values = {"reproductionNumber": reproduction_number}
|
|
32
|
-
|
|
33
|
-
for log_summary in logs_summaries.values():
|
|
34
|
-
step(
|
|
35
|
-
[
|
|
36
|
-
log_summary[f"reproductionNumberi{i}_median"].median()
|
|
37
|
-
for i in range(len(reproduction_number))
|
|
38
|
-
]
|
|
39
|
-
)
|
|
40
|
-
step(reproduction_number, color="k", linestyle="--")
|
|
41
|
-
plt.ylabel("Reproduction number")
|
|
42
|
-
plt.savefig(os.path.join(output_dir, f"epi-skyline_{i}-predictions.svg"))
|
|
43
|
-
plt.close()
|
|
44
|
-
|
|
45
|
-
plot_coverage_per_time_bin(
|
|
46
|
-
logs_summaries,
|
|
47
|
-
true_values,
|
|
48
|
-
os.path.join(output_dir, f"epi-skyline_{i}-coverage.svg"),
|
|
49
|
-
)
|
|
50
|
-
plot_maes_per_time_bin(
|
|
51
|
-
logs_summaries,
|
|
52
|
-
true_values,
|
|
53
|
-
os.path.join(output_dir, f"epi-skyline_{i}-maes.svg"),
|
|
54
|
-
)
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
if __name__ == "__main__":
|
|
58
|
-
main()
|
|
@@ -1,99 +0,0 @@
|
|
|
1
|
-
import ast
|
|
2
|
-
import os
|
|
3
|
-
from functools import partial
|
|
4
|
-
|
|
5
|
-
import joblib
|
|
6
|
-
import matplotlib.pyplot as plt
|
|
7
|
-
import polars as pl
|
|
8
|
-
from joblib import Parallel, delayed
|
|
9
|
-
from lumiere.backend import sigmoid
|
|
10
|
-
|
|
11
|
-
import src.config as cfg
|
|
12
|
-
from src.figures.utils import (
|
|
13
|
-
plot_partial_dependencies,
|
|
14
|
-
plot_shap_features_importance,
|
|
15
|
-
step,
|
|
16
|
-
)
|
|
17
|
-
from src.simulations.scenarios.fbd_2traits import (
|
|
18
|
-
BIRTH_RATE_TRAIT1_SET,
|
|
19
|
-
BIRTH_RATE_TRAIT1_UNSET,
|
|
20
|
-
DEATH_RATE_TRAIT1_SET,
|
|
21
|
-
DEATH_RATE_TRAIT1_UNSET,
|
|
22
|
-
FBD_RATE_UPPER,
|
|
23
|
-
N_TIME_BINS,
|
|
24
|
-
SCENARIO,
|
|
25
|
-
STATES,
|
|
26
|
-
)
|
|
27
|
-
from src.utils import set_plt_rcparams
|
|
28
|
-
|
|
29
|
-
|
|
30
|
-
def _plot_predictions(log_summary: pl.DataFrame, output_dir: str):
|
|
31
|
-
for rate in ["birth", "death"]:
|
|
32
|
-
label = r"\lambda" if rate == "birth" else r"\mu"
|
|
33
|
-
rate_trait_1_set = (
|
|
34
|
-
BIRTH_RATE_TRAIT1_UNSET if rate == "birth" else DEATH_RATE_TRAIT1_UNSET
|
|
35
|
-
)
|
|
36
|
-
rate_trait_1_unset = (
|
|
37
|
-
BIRTH_RATE_TRAIT1_SET if rate == "birth" else DEATH_RATE_TRAIT1_SET
|
|
38
|
-
)
|
|
39
|
-
for state in STATES:
|
|
40
|
-
estimates = [
|
|
41
|
-
log_summary[f"{rate}Ratei{i}_{state}_median"].median()
|
|
42
|
-
for i in range(N_TIME_BINS)
|
|
43
|
-
]
|
|
44
|
-
step(
|
|
45
|
-
estimates,
|
|
46
|
-
label=rf"${label}_{{{state[0]},{state[1]}}}$",
|
|
47
|
-
reverse_xticks=True,
|
|
48
|
-
)
|
|
49
|
-
step(
|
|
50
|
-
rate_trait_1_unset,
|
|
51
|
-
color="k",
|
|
52
|
-
linestyle="dashed",
|
|
53
|
-
label=rf"${label}_{{0,0}}$ = ${label}_{{0,1}}$",
|
|
54
|
-
reverse_xticks=True,
|
|
55
|
-
)
|
|
56
|
-
step(
|
|
57
|
-
rate_trait_1_set,
|
|
58
|
-
color="gray",
|
|
59
|
-
linestyle="dashed",
|
|
60
|
-
label=rf"${label}_{{1,0}}$ = ${label}_{{1,1}}$",
|
|
61
|
-
reverse_xticks=True,
|
|
62
|
-
)
|
|
63
|
-
plt.legend()
|
|
64
|
-
plt.ylabel(rf"${label}$")
|
|
65
|
-
plt.savefig(os.path.join(output_dir, rate, "predictions.svg"))
|
|
66
|
-
plt.close()
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
def main():
|
|
70
|
-
output_dir = os.path.join(cfg.FIGURES_DIR, "fbd-explainations")
|
|
71
|
-
for rate in ["birth", "death"]:
|
|
72
|
-
os.makedirs(os.path.join(output_dir, rate), exist_ok=True)
|
|
73
|
-
|
|
74
|
-
log_dir = os.path.join(cfg.BEAST_LOGS_SUMMARIES_DIR, "fbd-2traits")
|
|
75
|
-
model = "MLP-32_16"
|
|
76
|
-
log_summary = pl.read_csv(os.path.join(log_dir, f"{model}.csv"))
|
|
77
|
-
weights = joblib.load(os.path.join(log_dir, f"{model}_weights.pkl"))
|
|
78
|
-
|
|
79
|
-
set_plt_rcparams()
|
|
80
|
-
|
|
81
|
-
_plot_predictions(log_summary, output_dir)
|
|
82
|
-
|
|
83
|
-
for rate in ["birth", "death"]:
|
|
84
|
-
plot_partial_dependencies(
|
|
85
|
-
weights=weights[f"{rate}Rate"],
|
|
86
|
-
features=SCENARIO.features[f"{rate}Rate"],
|
|
87
|
-
output_dir=os.path.join(output_dir, rate),
|
|
88
|
-
output_activation=partial(sigmoid, upper=FBD_RATE_UPPER),
|
|
89
|
-
)
|
|
90
|
-
plot_shap_features_importance(
|
|
91
|
-
weights=weights[f"{rate}Rate"],
|
|
92
|
-
features=SCENARIO.features[f"{rate}Rate"],
|
|
93
|
-
output_file=os.path.join(output_dir, rate, "shap_values.svg"),
|
|
94
|
-
output_activation=partial(sigmoid, upper=FBD_RATE_UPPER),
|
|
95
|
-
)
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
if __name__ == "__main__":
|
|
99
|
-
main()
|
|
@@ -1,66 +0,0 @@
|
|
|
1
|
-
import os
|
|
2
|
-
|
|
3
|
-
import matplotlib.pyplot as plt
|
|
4
|
-
import polars as pl
|
|
5
|
-
|
|
6
|
-
import src.config as cfg
|
|
7
|
-
from src.simulations.figures.utils import (
|
|
8
|
-
plot_coverage_per_time_bin,
|
|
9
|
-
plot_maes_per_time_bin,
|
|
10
|
-
step,
|
|
11
|
-
)
|
|
12
|
-
from src.simulations.scenarios.fbd_no_traits import BIRTH_RATES, DEATH_RATES
|
|
13
|
-
from src.utils import set_plt_rcparams
|
|
14
|
-
|
|
15
|
-
|
|
16
|
-
def main():
|
|
17
|
-
output_dir = os.path.join(cfg.FIGURES_DIR, "fbd-predictions")
|
|
18
|
-
os.makedirs(output_dir, exist_ok=True)
|
|
19
|
-
|
|
20
|
-
set_plt_rcparams()
|
|
21
|
-
|
|
22
|
-
for i, (birth_rate, death_rate) in enumerate(
|
|
23
|
-
zip(BIRTH_RATES, DEATH_RATES), start=1
|
|
24
|
-
):
|
|
25
|
-
summaries_dir = os.path.join(cfg.BEAST_LOGS_SUMMARIES_DIR, f"fbd-no-traits_{i}")
|
|
26
|
-
logs_summaries = {
|
|
27
|
-
"Nonparametric": pl.read_csv(
|
|
28
|
-
os.path.join(summaries_dir, "Nonparametric.csv")
|
|
29
|
-
),
|
|
30
|
-
"GLM": pl.read_csv(os.path.join(summaries_dir, "GLM.csv")),
|
|
31
|
-
"MLP": pl.read_csv(os.path.join(summaries_dir, "MLP-16_8.csv")),
|
|
32
|
-
}
|
|
33
|
-
true_values = {"birthRate": birth_rate, "deathRate": death_rate}
|
|
34
|
-
|
|
35
|
-
for id, rate in true_values.items():
|
|
36
|
-
for log_summary in logs_summaries.values():
|
|
37
|
-
step(
|
|
38
|
-
[
|
|
39
|
-
log_summary[f"{id}i{i}_median"].median()
|
|
40
|
-
for i in range(len(rate))
|
|
41
|
-
],
|
|
42
|
-
reverse_xticks=True,
|
|
43
|
-
)
|
|
44
|
-
step(rate, color="k", linestyle="--", reverse_xticks=True)
|
|
45
|
-
plt.ylabel(r"$\lambda$" if id == "birthRate" else r"$\mu$")
|
|
46
|
-
plt.savefig(
|
|
47
|
-
os.path.join(output_dir, f"fbd-no-traits_{i}-predictions-{id}.svg")
|
|
48
|
-
)
|
|
49
|
-
plt.close()
|
|
50
|
-
|
|
51
|
-
plot_coverage_per_time_bin(
|
|
52
|
-
logs_summaries,
|
|
53
|
-
true_values,
|
|
54
|
-
os.path.join(output_dir, f"fbd-no-traits_{i}-coverage.svg"),
|
|
55
|
-
reverse_xticks=True,
|
|
56
|
-
)
|
|
57
|
-
plot_maes_per_time_bin(
|
|
58
|
-
logs_summaries,
|
|
59
|
-
true_values,
|
|
60
|
-
os.path.join(output_dir, f"fbd-no-traits_{i}-maes.svg"),
|
|
61
|
-
reverse_xticks=True,
|
|
62
|
-
)
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
if __name__ == "__main__":
|
|
66
|
-
main()
|