bella-companion 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of bella-companion might be problematic. Click here for more details.
- bella_companion/cli.py +13 -4
- bella_companion/simulations/__init__.py +2 -1
- bella_companion/simulations/figures/__init__.py +21 -0
- bella_companion/simulations/figures/epi_multitype_results.py +81 -0
- bella_companion/simulations/figures/epi_skyline_results.py +46 -0
- bella_companion/simulations/figures/explain/__init__.py +6 -0
- bella_companion/simulations/figures/explain/pdp.py +101 -0
- bella_companion/simulations/figures/explain/shap.py +56 -0
- bella_companion/simulations/figures/fbd_2traits_results.py +83 -0
- bella_companion/simulations/figures/fbd_no_traits_results.py +58 -0
- bella_companion/simulations/figures/scenarios.py +38 -32
- bella_companion/simulations/generate_figures.py +24 -0
- bella_companion/simulations/scenarios/fbd_2traits.py +2 -2
- bella_companion/utils/__init__.py +20 -1
- bella_companion/utils/beast.py +1 -1
- bella_companion/utils/explain.py +45 -0
- bella_companion/utils/plots.py +98 -0
- {bella_companion-0.0.4.dist-info → bella_companion-0.0.5.dist-info}/METADATA +4 -3
- {bella_companion-0.0.4.dist-info → bella_companion-0.0.5.dist-info}/RECORD +21 -16
- bella_companion/simulations/figures/epi_explainations.py +0 -109
- bella_companion/simulations/figures/epi_predictions.py +0 -58
- bella_companion/simulations/figures/fbd_explainations.py +0 -99
- bella_companion/simulations/figures/fbd_predictions.py +0 -66
- bella_companion/simulations/figures/utils.py +0 -250
- {bella_companion-0.0.4.dist-info → bella_companion-0.0.5.dist-info}/WHEEL +0 -0
- {bella_companion-0.0.4.dist-info → bella_companion-0.0.5.dist-info}/entry_points.txt +0 -0
bella_companion/cli.py
CHANGED
|
@@ -4,7 +4,12 @@ from pathlib import Path
|
|
|
4
4
|
|
|
5
5
|
from dotenv import load_dotenv
|
|
6
6
|
|
|
7
|
-
from bella_companion.simulations import
|
|
7
|
+
from bella_companion.simulations import (
|
|
8
|
+
generate_data,
|
|
9
|
+
generate_figures,
|
|
10
|
+
run_beast,
|
|
11
|
+
summarize_logs,
|
|
12
|
+
)
|
|
8
13
|
|
|
9
14
|
|
|
10
15
|
def main():
|
|
@@ -18,16 +23,20 @@ def main():
|
|
|
18
23
|
subparsers = parser.add_subparsers(dest="command", required=True)
|
|
19
24
|
|
|
20
25
|
subparsers.add_parser(
|
|
21
|
-
"generate-
|
|
26
|
+
"generate-simulation-data", help="Generate simulated data."
|
|
22
27
|
).set_defaults(func=generate_data)
|
|
23
28
|
|
|
24
29
|
subparsers.add_parser(
|
|
25
|
-
"run-
|
|
30
|
+
"beast-run-simulations", help="Run BEAST2 on simulated data."
|
|
26
31
|
).set_defaults(func=run_beast)
|
|
27
32
|
|
|
28
33
|
subparsers.add_parser(
|
|
29
|
-
"summarize-simulation-logs", help="Summarize
|
|
34
|
+
"summarize-simulation-logs", help="Summarize BEAST2 logs for simulations data."
|
|
30
35
|
).set_defaults(func=summarize_logs)
|
|
31
36
|
|
|
37
|
+
subparsers.add_parser(
|
|
38
|
+
"generate-simulation-figures", help="Generate figures for simulations data."
|
|
39
|
+
).set_defaults(func=generate_figures)
|
|
40
|
+
|
|
32
41
|
args = parser.parse_args()
|
|
33
42
|
args.func()
|
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
from bella_companion.simulations.generate_data import generate_data
|
|
2
|
+
from bella_companion.simulations.generate_figures import generate_figures
|
|
2
3
|
from bella_companion.simulations.run_beast import run_beast
|
|
3
4
|
from bella_companion.simulations.summarize_logs import summarize_logs
|
|
4
5
|
|
|
5
|
-
__all__ = ["generate_data", "run_beast", "summarize_logs"]
|
|
6
|
+
__all__ = ["generate_data", "generate_figures", "run_beast", "summarize_logs"]
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
from bella_companion.simulations.figures.epi_multitype_results import (
|
|
2
|
+
plot_epi_multitype_results,
|
|
3
|
+
)
|
|
4
|
+
from bella_companion.simulations.figures.epi_skyline_results import (
|
|
5
|
+
plot_epi_skyline_results,
|
|
6
|
+
)
|
|
7
|
+
from bella_companion.simulations.figures.fbd_2traits_results import (
|
|
8
|
+
plot_fbd_2traits_results,
|
|
9
|
+
)
|
|
10
|
+
from bella_companion.simulations.figures.fbd_no_traits_results import (
|
|
11
|
+
plot_fbd_no_traits_results,
|
|
12
|
+
)
|
|
13
|
+
from bella_companion.simulations.figures.scenarios import plot_scenarios
|
|
14
|
+
|
|
15
|
+
__all__ = [
|
|
16
|
+
"plot_epi_multitype_results",
|
|
17
|
+
"plot_epi_skyline_results",
|
|
18
|
+
"plot_fbd_2traits_results",
|
|
19
|
+
"plot_fbd_no_traits_results",
|
|
20
|
+
"plot_scenarios",
|
|
21
|
+
]
|
|
@@ -0,0 +1,81 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from functools import partial
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import joblib
|
|
6
|
+
import matplotlib.pyplot as plt
|
|
7
|
+
import numpy as np
|
|
8
|
+
import polars as pl
|
|
9
|
+
from lumiere.backend import relu, sigmoid
|
|
10
|
+
|
|
11
|
+
from bella_companion.simulations.figures.explain import (
|
|
12
|
+
plot_partial_dependencies,
|
|
13
|
+
plot_shap_features_importance,
|
|
14
|
+
)
|
|
15
|
+
from bella_companion.simulations.scenarios.epi_multitype import (
|
|
16
|
+
MIGRATION_PREDICTOR,
|
|
17
|
+
MIGRATION_RATE_UPPER,
|
|
18
|
+
MIGRATION_RATES,
|
|
19
|
+
SCENARIO,
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _plot_predictions(log_summary: pl.DataFrame, output_dir: Path):
|
|
24
|
+
targets = SCENARIO.targets["migrationRate"]
|
|
25
|
+
estimates = np.array(
|
|
26
|
+
[log_summary[f"{target}_median"].median() for target in targets]
|
|
27
|
+
)
|
|
28
|
+
lower = np.array([log_summary[f"{target}_lower"].median() for target in targets])
|
|
29
|
+
upper = np.array([log_summary[f"{target}_upper"].median() for target in targets])
|
|
30
|
+
|
|
31
|
+
sort_idx = np.argsort(MIGRATION_PREDICTOR.flatten())
|
|
32
|
+
predictors = MIGRATION_PREDICTOR.flatten()[sort_idx]
|
|
33
|
+
rates = MIGRATION_RATES.flatten()[sort_idx]
|
|
34
|
+
estimates = estimates[sort_idx]
|
|
35
|
+
lower = lower[sort_idx]
|
|
36
|
+
upper = upper[sort_idx]
|
|
37
|
+
|
|
38
|
+
plt.errorbar( # pyright: ignore
|
|
39
|
+
predictors,
|
|
40
|
+
estimates,
|
|
41
|
+
yerr=[estimates - lower, upper - estimates],
|
|
42
|
+
fmt="o",
|
|
43
|
+
color="C2",
|
|
44
|
+
elinewidth=2,
|
|
45
|
+
capsize=5,
|
|
46
|
+
)
|
|
47
|
+
plt.plot( # pyright: ignore
|
|
48
|
+
predictors, rates, linestyle="--", marker="o", color="k"
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
plt.xlabel("Migration predictor") # pyright: ignore
|
|
52
|
+
plt.ylabel("Migration rate") # pyright: ignore
|
|
53
|
+
plt.savefig(output_dir / "predictions.svg") # pyright: ignore
|
|
54
|
+
plt.close()
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def plot_epi_multitype_results():
|
|
58
|
+
output_dir = Path(os.environ["BELLA_FIGURES_DIR"]) / "epi-multitype"
|
|
59
|
+
os.makedirs(output_dir, exist_ok=True)
|
|
60
|
+
|
|
61
|
+
log_dir = Path(os.environ["BELLA_LOG_SUMMARIES_DIR"]) / "epi-multitype"
|
|
62
|
+
model = "MLP-32_16"
|
|
63
|
+
log_summary = pl.read_csv(log_dir / f"{model}.csv")
|
|
64
|
+
weights = joblib.load(log_dir / f"{model}.weights.pkl")
|
|
65
|
+
weights = [w["migrationRate"] for w in weights]
|
|
66
|
+
|
|
67
|
+
_plot_predictions(log_summary, output_dir)
|
|
68
|
+
plot_partial_dependencies(
|
|
69
|
+
weights=weights,
|
|
70
|
+
features=SCENARIO.features["migrationRate"],
|
|
71
|
+
output_dir=output_dir,
|
|
72
|
+
hidden_activation=relu,
|
|
73
|
+
output_activation=partial(sigmoid, upper=MIGRATION_RATE_UPPER),
|
|
74
|
+
)
|
|
75
|
+
plot_shap_features_importance(
|
|
76
|
+
weights=weights,
|
|
77
|
+
features=SCENARIO.features["migrationRate"],
|
|
78
|
+
output_file=output_dir / "shap_values.svg",
|
|
79
|
+
hidden_activation=relu,
|
|
80
|
+
output_activation=partial(sigmoid, upper=MIGRATION_RATE_UPPER),
|
|
81
|
+
)
|
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
import matplotlib.pyplot as plt
|
|
5
|
+
import numpy as np
|
|
6
|
+
import polars as pl
|
|
7
|
+
|
|
8
|
+
from bella_companion.simulations.scenarios.epi_skyline import REPRODUCTION_NUMBERS
|
|
9
|
+
from bella_companion.utils import (
|
|
10
|
+
plot_coverage_per_time_bin,
|
|
11
|
+
plot_maes_per_time_bin,
|
|
12
|
+
step,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def plot_epi_skyline_results():
|
|
17
|
+
output_dir = Path(os.environ["BELLA_FIGURES_DIR"]) / "epi-skyline-results"
|
|
18
|
+
os.makedirs(output_dir, exist_ok=True)
|
|
19
|
+
|
|
20
|
+
for i, reproduction_number in enumerate(REPRODUCTION_NUMBERS, start=1):
|
|
21
|
+
summaries_dir = Path(os.environ["BELLA_LOG_SUMMARIES_DIR"]) / f"epi-skyline_{i}"
|
|
22
|
+
logs_summaries = {
|
|
23
|
+
"Nonparametric": pl.read_csv(summaries_dir / "Nonparametric.csv"),
|
|
24
|
+
"GLM": pl.read_csv(summaries_dir / "GLM.csv"),
|
|
25
|
+
"MLP": pl.read_csv(summaries_dir / "MLP-16_8.csv"),
|
|
26
|
+
}
|
|
27
|
+
true_values = {"reproductionNumberSP": reproduction_number}
|
|
28
|
+
|
|
29
|
+
for log_summary in logs_summaries.values():
|
|
30
|
+
step(
|
|
31
|
+
[
|
|
32
|
+
float(np.median(log_summary[f"reproductionNumberSPi{i}_median"]))
|
|
33
|
+
for i in range(len(reproduction_number))
|
|
34
|
+
]
|
|
35
|
+
)
|
|
36
|
+
step(reproduction_number, color="k", linestyle="--")
|
|
37
|
+
plt.ylabel("Reproduction number") # pyright: ignore
|
|
38
|
+
plt.savefig(output_dir / f"predictions-{i}.svg") # pyright: ignore
|
|
39
|
+
plt.close()
|
|
40
|
+
|
|
41
|
+
plot_coverage_per_time_bin(
|
|
42
|
+
logs_summaries, true_values, output_dir / f"coverage-{i}.svg"
|
|
43
|
+
)
|
|
44
|
+
plot_maes_per_time_bin(
|
|
45
|
+
logs_summaries, true_values, output_dir / f"maes-{i}.svg"
|
|
46
|
+
)
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from functools import partial
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import matplotlib.pyplot as plt
|
|
6
|
+
import numpy as np
|
|
7
|
+
import seaborn as sns
|
|
8
|
+
from joblib import Parallel, delayed
|
|
9
|
+
from lumiere.typing import ActivationFunction, Weights
|
|
10
|
+
from tqdm import tqdm
|
|
11
|
+
|
|
12
|
+
from bella_companion.simulations.features import Feature
|
|
13
|
+
from bella_companion.utils import get_median_partial_dependence_values
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def plot_partial_dependencies(
|
|
17
|
+
weights: list[list[Weights]], # shape: (n_runs, n_weights_samples, ...)
|
|
18
|
+
features: dict[str, Feature],
|
|
19
|
+
output_dir: Path,
|
|
20
|
+
hidden_activation: ActivationFunction,
|
|
21
|
+
output_activation: ActivationFunction,
|
|
22
|
+
):
|
|
23
|
+
os.makedirs(output_dir, exist_ok=True)
|
|
24
|
+
|
|
25
|
+
continuous_grid: list[float] = np.linspace(0, 1, 10).tolist()
|
|
26
|
+
features_grid: list[list[float]] = [
|
|
27
|
+
[0, 1] if feature.is_binary else continuous_grid
|
|
28
|
+
for feature in features.values()
|
|
29
|
+
]
|
|
30
|
+
jobs = Parallel(n_jobs=-1, return_as="generator_unordered")(
|
|
31
|
+
delayed(
|
|
32
|
+
partial(
|
|
33
|
+
get_median_partial_dependence_values,
|
|
34
|
+
features_grid=features_grid,
|
|
35
|
+
hidden_activation=hidden_activation,
|
|
36
|
+
output_activation=output_activation,
|
|
37
|
+
)
|
|
38
|
+
)(w)
|
|
39
|
+
for w in weights
|
|
40
|
+
)
|
|
41
|
+
pdvalues = [
|
|
42
|
+
job for job in tqdm(jobs, total=len(weights), desc="Evaluating PDPs")
|
|
43
|
+
] # shape: (n_runs, n_features, n_grid_points)
|
|
44
|
+
pdvalues = [
|
|
45
|
+
np.array(mcmc_pds).T for mcmc_pds in zip(*pdvalues)
|
|
46
|
+
] # shape: (n_features, n_grid_points, n_runs)
|
|
47
|
+
|
|
48
|
+
if any(not feature.is_binary for feature in features.values()):
|
|
49
|
+
for (feature_name, feature), feature_pdvalues in zip(
|
|
50
|
+
features.items(), pdvalues
|
|
51
|
+
):
|
|
52
|
+
if not feature.is_binary:
|
|
53
|
+
color = "red" if feature.is_relevant else "gray"
|
|
54
|
+
median = np.median(feature_pdvalues, axis=1)
|
|
55
|
+
lower = np.percentile(feature_pdvalues, 2.5, axis=1)
|
|
56
|
+
high = np.percentile(feature_pdvalues, 100 - 2.5, axis=1)
|
|
57
|
+
plt.fill_between( # pyright: ignore
|
|
58
|
+
continuous_grid, lower, high, alpha=0.25, color=color
|
|
59
|
+
)
|
|
60
|
+
for mcmc_pds in feature_pdvalues.T:
|
|
61
|
+
plt.plot( # pyright: ignore
|
|
62
|
+
continuous_grid, mcmc_pds, color=color, alpha=0.2, linewidth=1
|
|
63
|
+
)
|
|
64
|
+
plt.plot( # pyright: ignore
|
|
65
|
+
continuous_grid, median, color=color, label=feature_name
|
|
66
|
+
)
|
|
67
|
+
plt.xlabel("Feature value") # pyright: ignore
|
|
68
|
+
plt.ylabel("MLP Output") # pyright: ignore
|
|
69
|
+
plt.legend() # pyright: ignore
|
|
70
|
+
plt.savefig(output_dir / "PDPs-continuous.svg") # pyright: ignore
|
|
71
|
+
plt.close()
|
|
72
|
+
|
|
73
|
+
if any(feature.is_binary for feature in features.values()):
|
|
74
|
+
data: list[float] = []
|
|
75
|
+
grid_labels: list[int] = []
|
|
76
|
+
feature_labels: list[str] = []
|
|
77
|
+
for (feature_name, feature), feature_pdvalues in zip(
|
|
78
|
+
features.items(), pdvalues
|
|
79
|
+
):
|
|
80
|
+
if feature.is_binary:
|
|
81
|
+
for i in [0, 1]:
|
|
82
|
+
data.extend(feature_pdvalues[i])
|
|
83
|
+
grid_labels.extend([i] * len(feature_pdvalues[i]))
|
|
84
|
+
feature_labels.extend([feature_name] * len(feature_pdvalues[i]))
|
|
85
|
+
sns.violinplot(
|
|
86
|
+
x=grid_labels,
|
|
87
|
+
y=data,
|
|
88
|
+
hue=feature_labels,
|
|
89
|
+
split=False,
|
|
90
|
+
cut=0,
|
|
91
|
+
palette={
|
|
92
|
+
feature_name: "red" if feature.is_relevant else "gray"
|
|
93
|
+
for feature_name, feature in features.items()
|
|
94
|
+
if feature.is_binary
|
|
95
|
+
},
|
|
96
|
+
)
|
|
97
|
+
plt.xlabel("Feature value") # pyright: ignore
|
|
98
|
+
plt.ylabel("MLP Output") # pyright: ignore
|
|
99
|
+
plt.legend() # pyright: ignore
|
|
100
|
+
plt.savefig(output_dir / "PDPs-categorical.svg") # pyright: ignore
|
|
101
|
+
plt.close()
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
from functools import partial
|
|
2
|
+
from itertools import product
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import matplotlib.pyplot as plt
|
|
6
|
+
import numpy as np
|
|
7
|
+
import seaborn as sns
|
|
8
|
+
from joblib import Parallel, delayed
|
|
9
|
+
from lumiere.typing import ActivationFunction, Weights
|
|
10
|
+
from tqdm import tqdm
|
|
11
|
+
|
|
12
|
+
from bella_companion.simulations.features import Feature
|
|
13
|
+
from bella_companion.utils import get_median_shap_features_importance
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def plot_shap_features_importance(
|
|
17
|
+
weights: list[list[Weights]], # shape: (n_runs, n_weights_samples, ...)
|
|
18
|
+
features: dict[str, Feature],
|
|
19
|
+
output_file: Path,
|
|
20
|
+
hidden_activation: ActivationFunction,
|
|
21
|
+
output_activation: ActivationFunction,
|
|
22
|
+
):
|
|
23
|
+
continuous_grid: list[float] = np.linspace(0, 1, 10).tolist()
|
|
24
|
+
features_grid: list[list[float]] = [
|
|
25
|
+
[0, 1] if feature.is_binary else continuous_grid
|
|
26
|
+
for feature in features.values()
|
|
27
|
+
]
|
|
28
|
+
inputs = list(product(*features_grid))
|
|
29
|
+
|
|
30
|
+
jobs = Parallel(n_jobs=-1, return_as="generator_unordered")(
|
|
31
|
+
delayed(
|
|
32
|
+
partial(
|
|
33
|
+
get_median_shap_features_importance,
|
|
34
|
+
inputs=inputs,
|
|
35
|
+
hidden_activation=hidden_activation,
|
|
36
|
+
output_activation=output_activation,
|
|
37
|
+
)
|
|
38
|
+
)(w)
|
|
39
|
+
for w in weights
|
|
40
|
+
)
|
|
41
|
+
features_importances = np.array(
|
|
42
|
+
[job for job in tqdm(jobs, total=len(weights), desc="Evaluating SHAPs")]
|
|
43
|
+
) # shape: (n_runs, n_features)
|
|
44
|
+
features_importances /= features_importances.sum(axis=1, keepdims=True)
|
|
45
|
+
|
|
46
|
+
for i, (feature_name, feature) in enumerate(features.items()):
|
|
47
|
+
sns.violinplot(
|
|
48
|
+
y=features_importances[:, i],
|
|
49
|
+
x=[feature_name] * len(features_importances),
|
|
50
|
+
cut=0,
|
|
51
|
+
color="red" if feature.is_relevant else "gray",
|
|
52
|
+
)
|
|
53
|
+
plt.xlabel("Feature") # pyright: ignore
|
|
54
|
+
plt.ylabel("Importance") # pyright: ignore
|
|
55
|
+
plt.savefig(output_file) # pyright: ignore
|
|
56
|
+
plt.close()
|
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from functools import partial
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
import joblib
|
|
6
|
+
import matplotlib.pyplot as plt
|
|
7
|
+
import numpy as np
|
|
8
|
+
import polars as pl
|
|
9
|
+
from lumiere.backend import relu, sigmoid
|
|
10
|
+
|
|
11
|
+
from bella_companion.simulations.figures.explain import (
|
|
12
|
+
plot_partial_dependencies,
|
|
13
|
+
plot_shap_features_importance,
|
|
14
|
+
)
|
|
15
|
+
from bella_companion.simulations.scenarios.fbd_2traits import (
|
|
16
|
+
FBD_RATE_UPPER,
|
|
17
|
+
N_TIME_BINS,
|
|
18
|
+
RATES,
|
|
19
|
+
SCENARIO,
|
|
20
|
+
STATES,
|
|
21
|
+
)
|
|
22
|
+
from bella_companion.utils import step
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _plot_predictions(log_summary: pl.DataFrame, output_dir: Path):
|
|
26
|
+
for rate, state_rates in RATES.items():
|
|
27
|
+
label = r"\lambda" if rate == "birth" else r"\mu"
|
|
28
|
+
for state in STATES:
|
|
29
|
+
estimates = [
|
|
30
|
+
float(np.median(log_summary[f"{rate}RateSPi{i}_{state}_median"]))
|
|
31
|
+
for i in range(N_TIME_BINS)
|
|
32
|
+
]
|
|
33
|
+
step(
|
|
34
|
+
estimates,
|
|
35
|
+
label=rf"${label}_{{{state[0]},{state[1]}}}$",
|
|
36
|
+
reverse_xticks=True,
|
|
37
|
+
)
|
|
38
|
+
step(
|
|
39
|
+
state_rates["00"],
|
|
40
|
+
color="k",
|
|
41
|
+
linestyle="dashed",
|
|
42
|
+
label=rf"${label}_{{0,0}}$ = ${label}_{{0,1}}$",
|
|
43
|
+
reverse_xticks=True,
|
|
44
|
+
)
|
|
45
|
+
step(
|
|
46
|
+
state_rates["10"],
|
|
47
|
+
color="gray",
|
|
48
|
+
linestyle="dashed",
|
|
49
|
+
label=rf"${label}_{{1,0}}$ = ${label}_{{1,1}}$",
|
|
50
|
+
reverse_xticks=True,
|
|
51
|
+
)
|
|
52
|
+
plt.legend() # pyright: ignore
|
|
53
|
+
plt.ylabel(rf"${label}$") # pyright: ignore
|
|
54
|
+
plt.savefig(output_dir / rate / "predictions.svg") # pyright: ignore
|
|
55
|
+
plt.close()
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def plot_fbd_2traits_results():
|
|
59
|
+
output_dir = Path(os.environ["BELLA_FIGURES_DIR"]) / "fbd-2traits"
|
|
60
|
+
|
|
61
|
+
log_dir = Path(os.environ["BELLA_LOG_SUMMARIES_DIR"]) / "fbd-2traits"
|
|
62
|
+
model = "MLP-32_16"
|
|
63
|
+
log_summary = pl.read_csv(log_dir / f"{model}.csv")
|
|
64
|
+
weights = joblib.load(log_dir / f"{model}.weights.pkl")
|
|
65
|
+
|
|
66
|
+
for rate in RATES:
|
|
67
|
+
os.makedirs(output_dir / rate, exist_ok=True)
|
|
68
|
+
plot_partial_dependencies(
|
|
69
|
+
weights=[w[f"{rate}Rate"] for w in weights],
|
|
70
|
+
features=SCENARIO.features[f"{rate}Rate"],
|
|
71
|
+
output_dir=output_dir / rate,
|
|
72
|
+
hidden_activation=relu,
|
|
73
|
+
output_activation=partial(sigmoid, upper=FBD_RATE_UPPER),
|
|
74
|
+
)
|
|
75
|
+
plot_shap_features_importance(
|
|
76
|
+
weights=[w[f"{rate}Rate"] for w in weights],
|
|
77
|
+
features=SCENARIO.features[f"{rate}Rate"],
|
|
78
|
+
output_file=output_dir / rate / "shap_values.svg",
|
|
79
|
+
hidden_activation=relu,
|
|
80
|
+
output_activation=partial(sigmoid, upper=FBD_RATE_UPPER),
|
|
81
|
+
)
|
|
82
|
+
|
|
83
|
+
_plot_predictions(log_summary, output_dir)
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
import matplotlib.pyplot as plt
|
|
5
|
+
import numpy as np
|
|
6
|
+
import polars as pl
|
|
7
|
+
|
|
8
|
+
from bella_companion.simulations.scenarios.fbd_no_traits import RATES
|
|
9
|
+
from bella_companion.utils import (
|
|
10
|
+
plot_coverage_per_time_bin,
|
|
11
|
+
plot_maes_per_time_bin,
|
|
12
|
+
step,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def plot_fbd_no_traits_results():
|
|
17
|
+
output_dir = Path(os.environ["BELLA_FIGURES_DIR"]) / "fbd-no-traits-predictions"
|
|
18
|
+
os.makedirs(output_dir, exist_ok=True)
|
|
19
|
+
|
|
20
|
+
for i, rates in enumerate(RATES, start=1):
|
|
21
|
+
summaries_dir = (
|
|
22
|
+
Path(os.environ["BELLA_LOG_SUMMARIES_DIR"]) / f"fbd-no-traits_{i}"
|
|
23
|
+
)
|
|
24
|
+
logs_summaries = {
|
|
25
|
+
"Nonparametric": pl.read_csv(summaries_dir / "Nonparametric.csv"),
|
|
26
|
+
"GLM": pl.read_csv(summaries_dir / "GLM.csv"),
|
|
27
|
+
"MLP": pl.read_csv(summaries_dir / "MLP-16_8.csv"),
|
|
28
|
+
}
|
|
29
|
+
true_values = {"birthRateSP": rates["birth"], "deathRateSP": rates["death"]}
|
|
30
|
+
|
|
31
|
+
for id, rate in true_values.items():
|
|
32
|
+
for log_summary in logs_summaries.values():
|
|
33
|
+
step(
|
|
34
|
+
[
|
|
35
|
+
float(np.median(log_summary[f"{id}i{i}_median"]))
|
|
36
|
+
for i in range(len(rate))
|
|
37
|
+
],
|
|
38
|
+
reverse_xticks=True,
|
|
39
|
+
)
|
|
40
|
+
step(rate, color="k", linestyle="--", reverse_xticks=True)
|
|
41
|
+
plt.ylabel( # pyright: ignore
|
|
42
|
+
r"$\lambda$" if id == "birthRateSP" else r"$\mu$"
|
|
43
|
+
)
|
|
44
|
+
plt.savefig(output_dir / f"{id}-predictions-{i}-.svg") # pyright: ignore
|
|
45
|
+
plt.close()
|
|
46
|
+
|
|
47
|
+
plot_coverage_per_time_bin(
|
|
48
|
+
logs_summaries,
|
|
49
|
+
true_values,
|
|
50
|
+
output_dir / f"coverage-{i}.svg",
|
|
51
|
+
reverse_xticks=True,
|
|
52
|
+
)
|
|
53
|
+
plot_maes_per_time_bin(
|
|
54
|
+
logs_summaries,
|
|
55
|
+
true_values,
|
|
56
|
+
output_dir / f"maes-{i}.svg",
|
|
57
|
+
reverse_xticks=True,
|
|
58
|
+
)
|
|
@@ -1,56 +1,66 @@
|
|
|
1
1
|
import os
|
|
2
|
+
from pathlib import Path
|
|
2
3
|
|
|
3
4
|
import matplotlib.pyplot as plt
|
|
4
5
|
import numpy as np
|
|
5
6
|
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
|
|
9
|
-
|
|
10
|
-
from
|
|
7
|
+
from bella_companion.simulations.scenarios.epi_multitype import (
|
|
8
|
+
MIGRATION_PREDICTOR,
|
|
9
|
+
MIGRATION_RATES,
|
|
10
|
+
)
|
|
11
|
+
from bella_companion.simulations.scenarios.epi_skyline import REPRODUCTION_NUMBERS
|
|
12
|
+
from bella_companion.simulations.scenarios.fbd_2traits import (
|
|
11
13
|
BIRTH_RATE_TRAIT1_SET,
|
|
12
14
|
BIRTH_RATE_TRAIT1_UNSET,
|
|
13
15
|
DEATH_RATE_TRAIT1_SET,
|
|
14
16
|
DEATH_RATE_TRAIT1_UNSET,
|
|
15
17
|
)
|
|
16
|
-
from
|
|
17
|
-
from
|
|
18
|
+
from bella_companion.simulations.scenarios.fbd_no_traits import RATES
|
|
19
|
+
from bella_companion.utils import step
|
|
18
20
|
|
|
19
21
|
|
|
20
|
-
def
|
|
21
|
-
output_dir = os.
|
|
22
|
+
def plot_scenarios():
|
|
23
|
+
output_dir = Path(os.environ["BELLA_FIGURES_DIR"]) / "targets"
|
|
22
24
|
os.makedirs(output_dir, exist_ok=True)
|
|
23
25
|
|
|
24
|
-
|
|
25
|
-
|
|
26
|
+
# -----------
|
|
27
|
+
# epi-skyline
|
|
28
|
+
# -----------
|
|
26
29
|
for i, reproduction_number in enumerate(REPRODUCTION_NUMBERS, start=1):
|
|
27
30
|
step(reproduction_number, color="k")
|
|
28
|
-
plt.ylabel("Reproduction number")
|
|
29
|
-
plt.savefig(
|
|
31
|
+
plt.ylabel("Reproduction number") # pyright: ignore
|
|
32
|
+
plt.savefig(output_dir / f"epi-skyline_{i}.svg") # pyright: ignore
|
|
30
33
|
plt.close()
|
|
31
34
|
|
|
35
|
+
# -------------
|
|
36
|
+
# epi-multitype
|
|
37
|
+
# -------------
|
|
32
38
|
sort_idx = np.argsort(MIGRATION_PREDICTOR.flatten())
|
|
33
|
-
plt.plot(
|
|
39
|
+
plt.plot( # pyright: ignore
|
|
34
40
|
MIGRATION_PREDICTOR.flatten()[sort_idx],
|
|
35
41
|
MIGRATION_RATES.flatten()[sort_idx],
|
|
36
42
|
marker="o",
|
|
37
43
|
color="k",
|
|
38
44
|
)
|
|
39
|
-
plt.xlabel("Migration predictor")
|
|
40
|
-
plt.ylabel("Migration rate")
|
|
41
|
-
plt.savefig(
|
|
45
|
+
plt.xlabel("Migration predictor") # pyright: ignore
|
|
46
|
+
plt.ylabel("Migration rate") # pyright: ignore
|
|
47
|
+
plt.savefig(output_dir / "epi-multitype.svg") # pyright: ignore
|
|
42
48
|
plt.close()
|
|
43
49
|
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
step(
|
|
49
|
-
|
|
50
|
-
plt.
|
|
51
|
-
plt.
|
|
50
|
+
# -------------
|
|
51
|
+
# fbd-no-traits
|
|
52
|
+
# -------------
|
|
53
|
+
for i, rates in enumerate(RATES, start=1):
|
|
54
|
+
step(rates["birth"], label=r"$\lambda$", reverse_xticks=True)
|
|
55
|
+
step(rates["death"], label=r"$\mu$", reverse_xticks=True)
|
|
56
|
+
plt.ylabel("Rate") # pyright: ignore
|
|
57
|
+
plt.legend() # pyright: ignore
|
|
58
|
+
plt.savefig(output_dir / f"fbd-no-traits_{i}.svg") # pyright: ignore
|
|
52
59
|
plt.close()
|
|
53
60
|
|
|
61
|
+
# -----------
|
|
62
|
+
# fbd-2traits
|
|
63
|
+
# -----------
|
|
54
64
|
step(
|
|
55
65
|
BIRTH_RATE_TRAIT1_UNSET,
|
|
56
66
|
label=r"$\lambda_{0,0} = \lambda_{0,1}$",
|
|
@@ -77,11 +87,7 @@ def main():
|
|
|
77
87
|
linestyle="dashed",
|
|
78
88
|
reverse_xticks=True,
|
|
79
89
|
)
|
|
80
|
-
plt.ylabel("Rate")
|
|
81
|
-
plt.legend()
|
|
82
|
-
plt.savefig(
|
|
90
|
+
plt.ylabel("Rate") # pyright: ignore
|
|
91
|
+
plt.legend() # pyright: ignore
|
|
92
|
+
plt.savefig(output_dir / "fbd-2traits.svg") # pyright: ignore
|
|
83
93
|
plt.close()
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
if __name__ == "__main__":
|
|
87
|
-
main()
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
import matplotlib.pyplot as plt
|
|
2
|
+
|
|
3
|
+
from bella_companion.simulations.figures import (
|
|
4
|
+
plot_epi_multitype_results,
|
|
5
|
+
plot_epi_skyline_results,
|
|
6
|
+
plot_fbd_2traits_results,
|
|
7
|
+
plot_fbd_no_traits_results,
|
|
8
|
+
plot_scenarios,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def generate_figures():
|
|
13
|
+
plt.rcParams["pdf.fonttype"] = 42
|
|
14
|
+
plt.rcParams["xtick.labelsize"] = 14
|
|
15
|
+
plt.rcParams["ytick.labelsize"] = 14
|
|
16
|
+
plt.rcParams["font.size"] = 14
|
|
17
|
+
plt.rcParams["figure.constrained_layout.use"] = True
|
|
18
|
+
plt.rcParams["lines.linewidth"] = 3
|
|
19
|
+
|
|
20
|
+
plot_scenarios()
|
|
21
|
+
plot_epi_skyline_results()
|
|
22
|
+
plot_epi_multitype_results()
|
|
23
|
+
plot_fbd_no_traits_results()
|
|
24
|
+
plot_fbd_2traits_results()
|
|
@@ -38,8 +38,8 @@ BIRTH_RATES = {
|
|
|
38
38
|
DEATH_RATES = {
|
|
39
39
|
"00": DEATH_RATE_TRAIT1_UNSET,
|
|
40
40
|
"01": DEATH_RATE_TRAIT1_UNSET,
|
|
41
|
-
"10":
|
|
42
|
-
"11":
|
|
41
|
+
"10": DEATH_RATE_TRAIT1_SET,
|
|
42
|
+
"11": DEATH_RATE_TRAIT1_SET,
|
|
43
43
|
}
|
|
44
44
|
RATES = {
|
|
45
45
|
"birth": BIRTH_RATES,
|
|
@@ -1,4 +1,23 @@
|
|
|
1
1
|
from bella_companion.utils.beast import summarize_log, summarize_logs, summarize_weights
|
|
2
|
+
from bella_companion.utils.explain import (
|
|
3
|
+
get_median_partial_dependence_values,
|
|
4
|
+
get_median_shap_features_importance,
|
|
5
|
+
)
|
|
6
|
+
from bella_companion.utils.plots import (
|
|
7
|
+
plot_coverage_per_time_bin,
|
|
8
|
+
plot_maes_per_time_bin,
|
|
9
|
+
step,
|
|
10
|
+
)
|
|
2
11
|
from bella_companion.utils.slurm import submit_job
|
|
3
12
|
|
|
4
|
-
__all__ = [
|
|
13
|
+
__all__ = [
|
|
14
|
+
"summarize_log",
|
|
15
|
+
"summarize_logs",
|
|
16
|
+
"summarize_weights",
|
|
17
|
+
"get_median_partial_dependence_values",
|
|
18
|
+
"get_median_shap_features_importance",
|
|
19
|
+
"plot_coverage_per_time_bin",
|
|
20
|
+
"plot_maes_per_time_bin",
|
|
21
|
+
"step",
|
|
22
|
+
"submit_job",
|
|
23
|
+
]
|