bella-companion 0.0.4__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bella-companion might be problematic. Click here for more details.

Files changed (26)
  1. bella_companion/cli.py +13 -4
  2. bella_companion/simulations/__init__.py +2 -1
  3. bella_companion/simulations/figures/__init__.py +21 -0
  4. bella_companion/simulations/figures/epi_multitype_results.py +81 -0
  5. bella_companion/simulations/figures/epi_skyline_results.py +46 -0
  6. bella_companion/simulations/figures/explain/__init__.py +6 -0
  7. bella_companion/simulations/figures/explain/pdp.py +101 -0
  8. bella_companion/simulations/figures/explain/shap.py +56 -0
  9. bella_companion/simulations/figures/fbd_2traits_results.py +83 -0
  10. bella_companion/simulations/figures/fbd_no_traits_results.py +58 -0
  11. bella_companion/simulations/figures/scenarios.py +38 -32
  12. bella_companion/simulations/generate_figures.py +24 -0
  13. bella_companion/simulations/scenarios/fbd_2traits.py +2 -2
  14. bella_companion/utils/__init__.py +20 -1
  15. bella_companion/utils/beast.py +1 -1
  16. bella_companion/utils/explain.py +45 -0
  17. bella_companion/utils/plots.py +98 -0
  18. {bella_companion-0.0.4.dist-info → bella_companion-0.0.5.dist-info}/METADATA +4 -3
  19. {bella_companion-0.0.4.dist-info → bella_companion-0.0.5.dist-info}/RECORD +21 -16
  20. bella_companion/simulations/figures/epi_explainations.py +0 -109
  21. bella_companion/simulations/figures/epi_predictions.py +0 -58
  22. bella_companion/simulations/figures/fbd_explainations.py +0 -99
  23. bella_companion/simulations/figures/fbd_predictions.py +0 -66
  24. bella_companion/simulations/figures/utils.py +0 -250
  25. {bella_companion-0.0.4.dist-info → bella_companion-0.0.5.dist-info}/WHEEL +0 -0
  26. {bella_companion-0.0.4.dist-info → bella_companion-0.0.5.dist-info}/entry_points.txt +0 -0
bella_companion/cli.py CHANGED
@@ -4,7 +4,12 @@ from pathlib import Path
4
4
 
5
5
  from dotenv import load_dotenv
6
6
 
7
- from bella_companion.simulations import generate_data, run_beast, summarize_logs
7
+ from bella_companion.simulations import (
8
+ generate_data,
9
+ generate_figures,
10
+ run_beast,
11
+ summarize_logs,
12
+ )
8
13
 
9
14
 
10
15
  def main():
@@ -18,16 +23,20 @@ def main():
18
23
  subparsers = parser.add_subparsers(dest="command", required=True)
19
24
 
20
25
  subparsers.add_parser(
21
- "generate-simulations-data", help="Generate simulated data."
26
+ "generate-simulation-data", help="Generate simulated data."
22
27
  ).set_defaults(func=generate_data)
23
28
 
24
29
  subparsers.add_parser(
25
- "run-beast-simulations", help="Run BEAST2 on simulated data."
30
+ "beast-run-simulations", help="Run BEAST2 on simulated data."
26
31
  ).set_defaults(func=run_beast)
27
32
 
28
33
  subparsers.add_parser(
29
- "summarize-simulation-logs", help="Summarize simulation logs."
34
+ "summarize-simulation-logs", help="Summarize BEAST2 logs for simulations data."
30
35
  ).set_defaults(func=summarize_logs)
31
36
 
37
+ subparsers.add_parser(
38
+ "generate-simulation-figures", help="Generate figures for simulations data."
39
+ ).set_defaults(func=generate_figures)
40
+
32
41
  args = parser.parse_args()
33
42
  args.func()
bella_companion/simulations/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
1
  from bella_companion.simulations.generate_data import generate_data
2
+ from bella_companion.simulations.generate_figures import generate_figures
2
3
  from bella_companion.simulations.run_beast import run_beast
3
4
  from bella_companion.simulations.summarize_logs import summarize_logs
4
5
 
5
- __all__ = ["generate_data", "run_beast", "summarize_logs"]
6
+ __all__ = ["generate_data", "generate_figures", "run_beast", "summarize_logs"]
bella_companion/simulations/figures/__init__.py ADDED
@@ -0,0 +1,21 @@
1
+ from bella_companion.simulations.figures.epi_multitype_results import (
2
+ plot_epi_multitype_results,
3
+ )
4
+ from bella_companion.simulations.figures.epi_skyline_results import (
5
+ plot_epi_skyline_results,
6
+ )
7
+ from bella_companion.simulations.figures.fbd_2traits_results import (
8
+ plot_fbd_2traits_results,
9
+ )
10
+ from bella_companion.simulations.figures.fbd_no_traits_results import (
11
+ plot_fbd_no_traits_results,
12
+ )
13
+ from bella_companion.simulations.figures.scenarios import plot_scenarios
14
+
15
+ __all__ = [
16
+ "plot_epi_multitype_results",
17
+ "plot_epi_skyline_results",
18
+ "plot_fbd_2traits_results",
19
+ "plot_fbd_no_traits_results",
20
+ "plot_scenarios",
21
+ ]
bella_companion/simulations/figures/epi_multitype_results.py ADDED
@@ -0,0 +1,81 @@
1
+ import os
2
+ from functools import partial
3
+ from pathlib import Path
4
+
5
+ import joblib
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import polars as pl
9
+ from lumiere.backend import relu, sigmoid
10
+
11
+ from bella_companion.simulations.figures.explain import (
12
+ plot_partial_dependencies,
13
+ plot_shap_features_importance,
14
+ )
15
+ from bella_companion.simulations.scenarios.epi_multitype import (
16
+ MIGRATION_PREDICTOR,
17
+ MIGRATION_RATE_UPPER,
18
+ MIGRATION_RATES,
19
+ SCENARIO,
20
+ )
21
+
22
+
23
+ def _plot_predictions(log_summary: pl.DataFrame, output_dir: Path):
24
+ targets = SCENARIO.targets["migrationRate"]
25
+ estimates = np.array(
26
+ [log_summary[f"{target}_median"].median() for target in targets]
27
+ )
28
+ lower = np.array([log_summary[f"{target}_lower"].median() for target in targets])
29
+ upper = np.array([log_summary[f"{target}_upper"].median() for target in targets])
30
+
31
+ sort_idx = np.argsort(MIGRATION_PREDICTOR.flatten())
32
+ predictors = MIGRATION_PREDICTOR.flatten()[sort_idx]
33
+ rates = MIGRATION_RATES.flatten()[sort_idx]
34
+ estimates = estimates[sort_idx]
35
+ lower = lower[sort_idx]
36
+ upper = upper[sort_idx]
37
+
38
+ plt.errorbar( # pyright: ignore
39
+ predictors,
40
+ estimates,
41
+ yerr=[estimates - lower, upper - estimates],
42
+ fmt="o",
43
+ color="C2",
44
+ elinewidth=2,
45
+ capsize=5,
46
+ )
47
+ plt.plot( # pyright: ignore
48
+ predictors, rates, linestyle="--", marker="o", color="k"
49
+ )
50
+
51
+ plt.xlabel("Migration predictor") # pyright: ignore
52
+ plt.ylabel("Migration rate") # pyright: ignore
53
+ plt.savefig(output_dir / "predictions.svg") # pyright: ignore
54
+ plt.close()
55
+
56
+
57
+ def plot_epi_multitype_results():
58
+ output_dir = Path(os.environ["BELLA_FIGURES_DIR"]) / "epi-multitype"
59
+ os.makedirs(output_dir, exist_ok=True)
60
+
61
+ log_dir = Path(os.environ["BELLA_LOG_SUMMARIES_DIR"]) / "epi-multitype"
62
+ model = "MLP-32_16"
63
+ log_summary = pl.read_csv(log_dir / f"{model}.csv")
64
+ weights = joblib.load(log_dir / f"{model}.weights.pkl")
65
+ weights = [w["migrationRate"] for w in weights]
66
+
67
+ _plot_predictions(log_summary, output_dir)
68
+ plot_partial_dependencies(
69
+ weights=weights,
70
+ features=SCENARIO.features["migrationRate"],
71
+ output_dir=output_dir,
72
+ hidden_activation=relu,
73
+ output_activation=partial(sigmoid, upper=MIGRATION_RATE_UPPER),
74
+ )
75
+ plot_shap_features_importance(
76
+ weights=weights,
77
+ features=SCENARIO.features["migrationRate"],
78
+ output_file=output_dir / "shap_values.svg",
79
+ hidden_activation=relu,
80
+ output_activation=partial(sigmoid, upper=MIGRATION_RATE_UPPER),
81
+ )
bella_companion/simulations/figures/epi_skyline_results.py ADDED
@@ -0,0 +1,46 @@
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import polars as pl
7
+
8
+ from bella_companion.simulations.scenarios.epi_skyline import REPRODUCTION_NUMBERS
9
+ from bella_companion.utils import (
10
+ plot_coverage_per_time_bin,
11
+ plot_maes_per_time_bin,
12
+ step,
13
+ )
14
+
15
+
16
+ def plot_epi_skyline_results():
17
+ output_dir = Path(os.environ["BELLA_FIGURES_DIR"]) / "epi-skyline-results"
18
+ os.makedirs(output_dir, exist_ok=True)
19
+
20
+ for i, reproduction_number in enumerate(REPRODUCTION_NUMBERS, start=1):
21
+ summaries_dir = Path(os.environ["BELLA_LOG_SUMMARIES_DIR"]) / f"epi-skyline_{i}"
22
+ logs_summaries = {
23
+ "Nonparametric": pl.read_csv(summaries_dir / "Nonparametric.csv"),
24
+ "GLM": pl.read_csv(summaries_dir / "GLM.csv"),
25
+ "MLP": pl.read_csv(summaries_dir / "MLP-16_8.csv"),
26
+ }
27
+ true_values = {"reproductionNumberSP": reproduction_number}
28
+
29
+ for log_summary in logs_summaries.values():
30
+ step(
31
+ [
32
+ float(np.median(log_summary[f"reproductionNumberSPi{i}_median"]))
33
+ for i in range(len(reproduction_number))
34
+ ]
35
+ )
36
+ step(reproduction_number, color="k", linestyle="--")
37
+ plt.ylabel("Reproduction number") # pyright: ignore
38
+ plt.savefig(output_dir / f"predictions-{i}.svg") # pyright: ignore
39
+ plt.close()
40
+
41
+ plot_coverage_per_time_bin(
42
+ logs_summaries, true_values, output_dir / f"coverage-{i}.svg"
43
+ )
44
+ plot_maes_per_time_bin(
45
+ logs_summaries, true_values, output_dir / f"maes-{i}.svg"
46
+ )
bella_companion/simulations/figures/explain/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ from bella_companion.simulations.figures.explain.pdp import plot_partial_dependencies
2
+ from bella_companion.simulations.figures.explain.shap import (
3
+ plot_shap_features_importance,
4
+ )
5
+
6
+ __all__ = ["plot_partial_dependencies", "plot_shap_features_importance"]
bella_companion/simulations/figures/explain/pdp.py ADDED
@@ -0,0 +1,101 @@
1
+ import os
2
+ from functools import partial
3
+ from pathlib import Path
4
+
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import seaborn as sns
8
+ from joblib import Parallel, delayed
9
+ from lumiere.typing import ActivationFunction, Weights
10
+ from tqdm import tqdm
11
+
12
+ from bella_companion.simulations.features import Feature
13
+ from bella_companion.utils import get_median_partial_dependence_values
14
+
15
+
16
+ def plot_partial_dependencies(
17
+ weights: list[list[Weights]], # shape: (n_runs, n_weights_samples, ...)
18
+ features: dict[str, Feature],
19
+ output_dir: Path,
20
+ hidden_activation: ActivationFunction,
21
+ output_activation: ActivationFunction,
22
+ ):
23
+ os.makedirs(output_dir, exist_ok=True)
24
+
25
+ continuous_grid: list[float] = np.linspace(0, 1, 10).tolist()
26
+ features_grid: list[list[float]] = [
27
+ [0, 1] if feature.is_binary else continuous_grid
28
+ for feature in features.values()
29
+ ]
30
+ jobs = Parallel(n_jobs=-1, return_as="generator_unordered")(
31
+ delayed(
32
+ partial(
33
+ get_median_partial_dependence_values,
34
+ features_grid=features_grid,
35
+ hidden_activation=hidden_activation,
36
+ output_activation=output_activation,
37
+ )
38
+ )(w)
39
+ for w in weights
40
+ )
41
+ pdvalues = [
42
+ job for job in tqdm(jobs, total=len(weights), desc="Evaluating PDPs")
43
+ ] # shape: (n_runs, n_features, n_grid_points)
44
+ pdvalues = [
45
+ np.array(mcmc_pds).T for mcmc_pds in zip(*pdvalues)
46
+ ] # shape: (n_features, n_grid_points, n_runs)
47
+
48
+ if any(not feature.is_binary for feature in features.values()):
49
+ for (feature_name, feature), feature_pdvalues in zip(
50
+ features.items(), pdvalues
51
+ ):
52
+ if not feature.is_binary:
53
+ color = "red" if feature.is_relevant else "gray"
54
+ median = np.median(feature_pdvalues, axis=1)
55
+ lower = np.percentile(feature_pdvalues, 2.5, axis=1)
56
+ high = np.percentile(feature_pdvalues, 100 - 2.5, axis=1)
57
+ plt.fill_between( # pyright: ignore
58
+ continuous_grid, lower, high, alpha=0.25, color=color
59
+ )
60
+ for mcmc_pds in feature_pdvalues.T:
61
+ plt.plot( # pyright: ignore
62
+ continuous_grid, mcmc_pds, color=color, alpha=0.2, linewidth=1
63
+ )
64
+ plt.plot( # pyright: ignore
65
+ continuous_grid, median, color=color, label=feature_name
66
+ )
67
+ plt.xlabel("Feature value") # pyright: ignore
68
+ plt.ylabel("MLP Output") # pyright: ignore
69
+ plt.legend() # pyright: ignore
70
+ plt.savefig(output_dir / "PDPs-continuous.svg") # pyright: ignore
71
+ plt.close()
72
+
73
+ if any(feature.is_binary for feature in features.values()):
74
+ data: list[float] = []
75
+ grid_labels: list[int] = []
76
+ feature_labels: list[str] = []
77
+ for (feature_name, feature), feature_pdvalues in zip(
78
+ features.items(), pdvalues
79
+ ):
80
+ if feature.is_binary:
81
+ for i in [0, 1]:
82
+ data.extend(feature_pdvalues[i])
83
+ grid_labels.extend([i] * len(feature_pdvalues[i]))
84
+ feature_labels.extend([feature_name] * len(feature_pdvalues[i]))
85
+ sns.violinplot(
86
+ x=grid_labels,
87
+ y=data,
88
+ hue=feature_labels,
89
+ split=False,
90
+ cut=0,
91
+ palette={
92
+ feature_name: "red" if feature.is_relevant else "gray"
93
+ for feature_name, feature in features.items()
94
+ if feature.is_binary
95
+ },
96
+ )
97
+ plt.xlabel("Feature value") # pyright: ignore
98
+ plt.ylabel("MLP Output") # pyright: ignore
99
+ plt.legend() # pyright: ignore
100
+ plt.savefig(output_dir / "PDPs-categorical.svg") # pyright: ignore
101
+ plt.close()
bella_companion/simulations/figures/explain/shap.py ADDED
@@ -0,0 +1,56 @@
1
+ from functools import partial
2
+ from itertools import product
3
+ from pathlib import Path
4
+
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import seaborn as sns
8
+ from joblib import Parallel, delayed
9
+ from lumiere.typing import ActivationFunction, Weights
10
+ from tqdm import tqdm
11
+
12
+ from bella_companion.simulations.features import Feature
13
+ from bella_companion.utils import get_median_shap_features_importance
14
+
15
+
16
+ def plot_shap_features_importance(
17
+ weights: list[list[Weights]], # shape: (n_runs, n_weights_samples, ...)
18
+ features: dict[str, Feature],
19
+ output_file: Path,
20
+ hidden_activation: ActivationFunction,
21
+ output_activation: ActivationFunction,
22
+ ):
23
+ continuous_grid: list[float] = np.linspace(0, 1, 10).tolist()
24
+ features_grid: list[list[float]] = [
25
+ [0, 1] if feature.is_binary else continuous_grid
26
+ for feature in features.values()
27
+ ]
28
+ inputs = list(product(*features_grid))
29
+
30
+ jobs = Parallel(n_jobs=-1, return_as="generator_unordered")(
31
+ delayed(
32
+ partial(
33
+ get_median_shap_features_importance,
34
+ inputs=inputs,
35
+ hidden_activation=hidden_activation,
36
+ output_activation=output_activation,
37
+ )
38
+ )(w)
39
+ for w in weights
40
+ )
41
+ features_importances = np.array(
42
+ [job for job in tqdm(jobs, total=len(weights), desc="Evaluating SHAPs")]
43
+ ) # shape: (n_runs, n_features)
44
+ features_importances /= features_importances.sum(axis=1, keepdims=True)
45
+
46
+ for i, (feature_name, feature) in enumerate(features.items()):
47
+ sns.violinplot(
48
+ y=features_importances[:, i],
49
+ x=[feature_name] * len(features_importances),
50
+ cut=0,
51
+ color="red" if feature.is_relevant else "gray",
52
+ )
53
+ plt.xlabel("Feature") # pyright: ignore
54
+ plt.ylabel("Importance") # pyright: ignore
55
+ plt.savefig(output_file) # pyright: ignore
56
+ plt.close()
bella_companion/simulations/figures/fbd_2traits_results.py ADDED
@@ -0,0 +1,83 @@
1
+ import os
2
+ from functools import partial
3
+ from pathlib import Path
4
+
5
+ import joblib
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import polars as pl
9
+ from lumiere.backend import relu, sigmoid
10
+
11
+ from bella_companion.simulations.figures.explain import (
12
+ plot_partial_dependencies,
13
+ plot_shap_features_importance,
14
+ )
15
+ from bella_companion.simulations.scenarios.fbd_2traits import (
16
+ FBD_RATE_UPPER,
17
+ N_TIME_BINS,
18
+ RATES,
19
+ SCENARIO,
20
+ STATES,
21
+ )
22
+ from bella_companion.utils import step
23
+
24
+
25
+ def _plot_predictions(log_summary: pl.DataFrame, output_dir: Path):
26
+ for rate, state_rates in RATES.items():
27
+ label = r"\lambda" if rate == "birth" else r"\mu"
28
+ for state in STATES:
29
+ estimates = [
30
+ float(np.median(log_summary[f"{rate}RateSPi{i}_{state}_median"]))
31
+ for i in range(N_TIME_BINS)
32
+ ]
33
+ step(
34
+ estimates,
35
+ label=rf"${label}_{{{state[0]},{state[1]}}}$",
36
+ reverse_xticks=True,
37
+ )
38
+ step(
39
+ state_rates["00"],
40
+ color="k",
41
+ linestyle="dashed",
42
+ label=rf"${label}_{{0,0}}$ = ${label}_{{0,1}}$",
43
+ reverse_xticks=True,
44
+ )
45
+ step(
46
+ state_rates["10"],
47
+ color="gray",
48
+ linestyle="dashed",
49
+ label=rf"${label}_{{1,0}}$ = ${label}_{{1,1}}$",
50
+ reverse_xticks=True,
51
+ )
52
+ plt.legend() # pyright: ignore
53
+ plt.ylabel(rf"${label}$") # pyright: ignore
54
+ plt.savefig(output_dir / rate / "predictions.svg") # pyright: ignore
55
+ plt.close()
56
+
57
+
58
+ def plot_fbd_2traits_results():
59
+ output_dir = Path(os.environ["BELLA_FIGURES_DIR"]) / "fbd-2traits"
60
+
61
+ log_dir = Path(os.environ["BELLA_LOG_SUMMARIES_DIR"]) / "fbd-2traits"
62
+ model = "MLP-32_16"
63
+ log_summary = pl.read_csv(log_dir / f"{model}.csv")
64
+ weights = joblib.load(log_dir / f"{model}.weights.pkl")
65
+
66
+ for rate in RATES:
67
+ os.makedirs(output_dir / rate, exist_ok=True)
68
+ plot_partial_dependencies(
69
+ weights=[w[f"{rate}Rate"] for w in weights],
70
+ features=SCENARIO.features[f"{rate}Rate"],
71
+ output_dir=output_dir / rate,
72
+ hidden_activation=relu,
73
+ output_activation=partial(sigmoid, upper=FBD_RATE_UPPER),
74
+ )
75
+ plot_shap_features_importance(
76
+ weights=[w[f"{rate}Rate"] for w in weights],
77
+ features=SCENARIO.features[f"{rate}Rate"],
78
+ output_file=output_dir / rate / "shap_values.svg",
79
+ hidden_activation=relu,
80
+ output_activation=partial(sigmoid, upper=FBD_RATE_UPPER),
81
+ )
82
+
83
+ _plot_predictions(log_summary, output_dir)
bella_companion/simulations/figures/fbd_no_traits_results.py ADDED
@@ -0,0 +1,58 @@
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import polars as pl
7
+
8
+ from bella_companion.simulations.scenarios.fbd_no_traits import RATES
9
+ from bella_companion.utils import (
10
+ plot_coverage_per_time_bin,
11
+ plot_maes_per_time_bin,
12
+ step,
13
+ )
14
+
15
+
16
+ def plot_fbd_no_traits_results():
17
+ output_dir = Path(os.environ["BELLA_FIGURES_DIR"]) / "fbd-no-traits-predictions"
18
+ os.makedirs(output_dir, exist_ok=True)
19
+
20
+ for i, rates in enumerate(RATES, start=1):
21
+ summaries_dir = (
22
+ Path(os.environ["BELLA_LOG_SUMMARIES_DIR"]) / f"fbd-no-traits_{i}"
23
+ )
24
+ logs_summaries = {
25
+ "Nonparametric": pl.read_csv(summaries_dir / "Nonparametric.csv"),
26
+ "GLM": pl.read_csv(summaries_dir / "GLM.csv"),
27
+ "MLP": pl.read_csv(summaries_dir / "MLP-16_8.csv"),
28
+ }
29
+ true_values = {"birthRateSP": rates["birth"], "deathRateSP": rates["death"]}
30
+
31
+ for id, rate in true_values.items():
32
+ for log_summary in logs_summaries.values():
33
+ step(
34
+ [
35
+ float(np.median(log_summary[f"{id}i{i}_median"]))
36
+ for i in range(len(rate))
37
+ ],
38
+ reverse_xticks=True,
39
+ )
40
+ step(rate, color="k", linestyle="--", reverse_xticks=True)
41
+ plt.ylabel( # pyright: ignore
42
+ r"$\lambda$" if id == "birthRateSP" else r"$\mu$"
43
+ )
44
+ plt.savefig(output_dir / f"{id}-predictions-{i}-.svg") # pyright: ignore
45
+ plt.close()
46
+
47
+ plot_coverage_per_time_bin(
48
+ logs_summaries,
49
+ true_values,
50
+ output_dir / f"coverage-{i}.svg",
51
+ reverse_xticks=True,
52
+ )
53
+ plot_maes_per_time_bin(
54
+ logs_summaries,
55
+ true_values,
56
+ output_dir / f"maes-{i}.svg",
57
+ reverse_xticks=True,
58
+ )
bella_companion/simulations/figures/scenarios.py CHANGED
@@ -1,56 +1,66 @@
1
1
  import os
2
+ from pathlib import Path
2
3
 
3
4
  import matplotlib.pyplot as plt
4
5
  import numpy as np
5
6
 
6
- import src.config as cfg
7
- from src.simulations.figures.utils import step
8
- from src.simulations.scenarios.epi_multitype import MIGRATION_PREDICTOR, MIGRATION_RATES
9
- from src.simulations.scenarios.epi_skyline import REPRODUCTION_NUMBERS
10
- from src.simulations.scenarios.fbd_2traits import (
7
+ from bella_companion.simulations.scenarios.epi_multitype import (
8
+ MIGRATION_PREDICTOR,
9
+ MIGRATION_RATES,
10
+ )
11
+ from bella_companion.simulations.scenarios.epi_skyline import REPRODUCTION_NUMBERS
12
+ from bella_companion.simulations.scenarios.fbd_2traits import (
11
13
  BIRTH_RATE_TRAIT1_SET,
12
14
  BIRTH_RATE_TRAIT1_UNSET,
13
15
  DEATH_RATE_TRAIT1_SET,
14
16
  DEATH_RATE_TRAIT1_UNSET,
15
17
  )
16
- from src.simulations.scenarios.fbd_no_traits import BIRTH_RATES, DEATH_RATES
17
- from src.utils import set_plt_rcparams
18
+ from bella_companion.simulations.scenarios.fbd_no_traits import RATES
19
+ from bella_companion.utils import step
18
20
 
19
21
 
20
- def main():
21
- output_dir = os.path.join(cfg.FIGURES_DIR, "scenarios")
22
+ def plot_scenarios():
23
+ output_dir = Path(os.environ["BELLA_FIGURES_DIR"]) / "targets"
22
24
  os.makedirs(output_dir, exist_ok=True)
23
25
 
24
- set_plt_rcparams()
25
-
26
+ # -----------
27
+ # epi-skyline
28
+ # -----------
26
29
  for i, reproduction_number in enumerate(REPRODUCTION_NUMBERS, start=1):
27
30
  step(reproduction_number, color="k")
28
- plt.ylabel("Reproduction number")
29
- plt.savefig(os.path.join(output_dir, f"epi-skyline_{i}.svg"))
31
+ plt.ylabel("Reproduction number") # pyright: ignore
32
+ plt.savefig(output_dir / f"epi-skyline_{i}.svg") # pyright: ignore
30
33
  plt.close()
31
34
 
35
+ # -------------
36
+ # epi-multitype
37
+ # -------------
32
38
  sort_idx = np.argsort(MIGRATION_PREDICTOR.flatten())
33
- plt.plot(
39
+ plt.plot( # pyright: ignore
34
40
  MIGRATION_PREDICTOR.flatten()[sort_idx],
35
41
  MIGRATION_RATES.flatten()[sort_idx],
36
42
  marker="o",
37
43
  color="k",
38
44
  )
39
- plt.xlabel("Migration predictor")
40
- plt.ylabel("Migration rate")
41
- plt.savefig(os.path.join(output_dir, "epi-multitype.svg"))
45
+ plt.xlabel("Migration predictor") # pyright: ignore
46
+ plt.ylabel("Migration rate") # pyright: ignore
47
+ plt.savefig(output_dir / "epi-multitype.svg") # pyright: ignore
42
48
  plt.close()
43
49
 
44
- for i, (birth_rate, death_rate) in enumerate(
45
- zip(BIRTH_RATES, DEATH_RATES), start=1
46
- ):
47
- step(birth_rate, label=r"$\lambda$", reverse_xticks=True)
48
- step(death_rate, label=r"$\mu$", reverse_xticks=True)
49
- plt.ylabel("Rate")
50
- plt.legend()
51
- plt.savefig(os.path.join(output_dir, f"fbd-no-traits_{i}.svg"))
50
+ # -------------
51
+ # fbd-no-traits
52
+ # -------------
53
+ for i, rates in enumerate(RATES, start=1):
54
+ step(rates["birth"], label=r"$\lambda$", reverse_xticks=True)
55
+ step(rates["death"], label=r"$\mu$", reverse_xticks=True)
56
+ plt.ylabel("Rate") # pyright: ignore
57
+ plt.legend() # pyright: ignore
58
+ plt.savefig(output_dir / f"fbd-no-traits_{i}.svg") # pyright: ignore
52
59
  plt.close()
53
60
 
61
+ # -----------
62
+ # fbd-2traits
63
+ # -----------
54
64
  step(
55
65
  BIRTH_RATE_TRAIT1_UNSET,
56
66
  label=r"$\lambda_{0,0} = \lambda_{0,1}$",
@@ -77,11 +87,7 @@ def main():
77
87
  linestyle="dashed",
78
88
  reverse_xticks=True,
79
89
  )
80
- plt.ylabel("Rate")
81
- plt.legend()
82
- plt.savefig(os.path.join(output_dir, "fbd-2traits.svg"))
90
+ plt.ylabel("Rate") # pyright: ignore
91
+ plt.legend() # pyright: ignore
92
+ plt.savefig(output_dir / "fbd-2traits.svg") # pyright: ignore
83
93
  plt.close()
84
-
85
-
86
- if __name__ == "__main__":
87
- main()
bella_companion/simulations/generate_figures.py ADDED
@@ -0,0 +1,24 @@
1
+ import matplotlib.pyplot as plt
2
+
3
+ from bella_companion.simulations.figures import (
4
+ plot_epi_multitype_results,
5
+ plot_epi_skyline_results,
6
+ plot_fbd_2traits_results,
7
+ plot_fbd_no_traits_results,
8
+ plot_scenarios,
9
+ )
10
+
11
+
12
+ def generate_figures():
13
+ plt.rcParams["pdf.fonttype"] = 42
14
+ plt.rcParams["xtick.labelsize"] = 14
15
+ plt.rcParams["ytick.labelsize"] = 14
16
+ plt.rcParams["font.size"] = 14
17
+ plt.rcParams["figure.constrained_layout.use"] = True
18
+ plt.rcParams["lines.linewidth"] = 3
19
+
20
+ plot_scenarios()
21
+ plot_epi_skyline_results()
22
+ plot_epi_multitype_results()
23
+ plot_fbd_no_traits_results()
24
+ plot_fbd_2traits_results()
bella_companion/simulations/scenarios/fbd_2traits.py CHANGED
@@ -38,8 +38,8 @@ BIRTH_RATES = {
38
38
  DEATH_RATES = {
39
39
  "00": DEATH_RATE_TRAIT1_UNSET,
40
40
  "01": DEATH_RATE_TRAIT1_UNSET,
41
- "10": BIRTH_RATE_TRAIT1_SET,
42
- "11": BIRTH_RATE_TRAIT1_SET,
41
+ "10": DEATH_RATE_TRAIT1_SET,
42
+ "11": DEATH_RATE_TRAIT1_SET,
43
43
  }
44
44
  RATES = {
45
45
  "birth": BIRTH_RATES,
bella_companion/utils/__init__.py CHANGED
@@ -1,4 +1,23 @@
1
1
  from bella_companion.utils.beast import summarize_log, summarize_logs, summarize_weights
2
+ from bella_companion.utils.explain import (
3
+ get_median_partial_dependence_values,
4
+ get_median_shap_features_importance,
5
+ )
6
+ from bella_companion.utils.plots import (
7
+ plot_coverage_per_time_bin,
8
+ plot_maes_per_time_bin,
9
+ step,
10
+ )
2
11
  from bella_companion.utils.slurm import submit_job
3
12
 
4
- __all__ = ["submit_job", "summarize_log", "summarize_logs", "summarize_weights"]
13
+ __all__ = [
14
+ "summarize_log",
15
+ "summarize_logs",
16
+ "summarize_weights",
17
+ "get_median_partial_dependence_values",
18
+ "get_median_shap_features_importance",
19
+ "plot_coverage_per_time_bin",
20
+ "plot_maes_per_time_bin",
21
+ "step",
22
+ "submit_job",
23
+ ]