bella-companion 0.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bella-companion might be problematic. Click here for more details.

Files changed (34)
  1. bella_companion/__init__.py +0 -0
  2. bella_companion/cli.py +24 -0
  3. bella_companion/fbd_empirical/data/body_mass.csv +1378 -0
  4. bella_companion/fbd_empirical/data/change_times.csv +22 -0
  5. bella_companion/fbd_empirical/data/sampling_change_times.csv +6 -0
  6. bella_companion/fbd_empirical/data/trees.nwk +100 -0
  7. bella_companion/fbd_empirical/figure.py +37 -0
  8. bella_companion/fbd_empirical/notbooks.ipynb +359 -0
  9. bella_companion/fbd_empirical/params.json +11 -0
  10. bella_companion/fbd_empirical/run_beast.py +54 -0
  11. bella_companion/fbd_empirical/summarize_logs.py +50 -0
  12. bella_companion/simulations/__init__.py +0 -0
  13. bella_companion/simulations/features.py +7 -0
  14. bella_companion/simulations/figures/__init__.py +0 -0
  15. bella_companion/simulations/figures/epi_explainations.py +101 -0
  16. bella_companion/simulations/figures/epi_predictions.py +58 -0
  17. bella_companion/simulations/figures/fbd_explainations.py +99 -0
  18. bella_companion/simulations/figures/fbd_predictions.py +66 -0
  19. bella_companion/simulations/figures/scenarios.py +87 -0
  20. bella_companion/simulations/figures/utils.py +250 -0
  21. bella_companion/simulations/generate_data.py +25 -0
  22. bella_companion/simulations/run_beast.py +92 -0
  23. bella_companion/simulations/scenarios/__init__.py +20 -0
  24. bella_companion/simulations/scenarios/common.py +29 -0
  25. bella_companion/simulations/scenarios/epi_multitype.py +68 -0
  26. bella_companion/simulations/scenarios/epi_skyline.py +65 -0
  27. bella_companion/simulations/scenarios/fbd_2traits.py +101 -0
  28. bella_companion/simulations/scenarios/fbd_no_traits.py +71 -0
  29. bella_companion/simulations/scenarios/scenario.py +26 -0
  30. bella_companion/simulations/summarize_logs.py +39 -0
  31. bella_companion/utils.py +164 -0
  32. bella_companion-0.0.0.dist-info/METADATA +13 -0
  33. bella_companion-0.0.0.dist-info/RECORD +34 -0
  34. bella_companion-0.0.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,54 @@
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import numpy as np
5
+ import polars as pl
6
+ from phylogenie import load_newick
7
+ from phylogenie.utils import get_node_depths
8
+ from tqdm import tqdm
9
+
10
+ import src.config as cfg
11
+ from src.utils import run_sbatch
12
+
13
+ THIS_DIR = Path(__file__).parent
14
+
15
+
16
def main():
    """Submit one BEAST job per (tree, model) pair for the empirical FBD data.

    Reads ``data/change_times.csv``, derives the space-separated
    ``timePredictor`` and ``bodyMassPredictor`` strings from it, then, for
    every tree in ``data/trees.nwk`` and each of the ``hidden-relu`` /
    ``hidden-tanh`` configurations, assembles a BEAST command line and hands
    it to ``run_sbatch``.
    """
    data_dir = THIS_DIR / "data"
    tree_file = data_dir / "trees.nwk"
    body_mass_file = data_dir / "body_mass.csv"
    change_times_file = data_dir / "change_times.csv"
    sampling_change_times_file = data_dir / "sampling_change_times.csv"

    output_dir = cfg.BEAST_OUTPUTS_DIR / "fbd-empirical"
    os.makedirs(output_dir, exist_ok=True)

    change_times_frame = pl.read_csv(change_times_file, has_header=False)
    change_times = change_times_frame.to_series().to_numpy()
    n_time_bins = len(change_times) + 1

    # Change times with a leading 0; every value is repeated 4 times so the
    # sequence has the same length as the "0 1 2 3" cycle built just below.
    times_with_origin = np.insert(change_times, 0, 0)
    time_predictor = " ".join(
        str(value) for value in np.repeat(times_with_origin, 4)
    )
    body_mass_predictor = " ".join(["0", "1", "2", "3"] * n_time_bins)

    trees = load_newick(str(tree_file))
    assert isinstance(trees, list)
    for tree_index, tree in enumerate(tqdm(trees)):
        process_length = max(get_node_depths(tree).values())

        # Comma-separated key=value pairs passed to BEAST through ``-D``;
        # they do not depend on the model, so build them once per tree.
        definitions = ",".join(
            [
                f"treeFile={tree_file}",
                f"treeIndex={tree_index}",
                f"typeTraitFile={body_mass_file}",
                f"changeTimesFile={change_times_file}",
                f"samplingChangeTimesFile={sampling_change_times_file}",
                f"processLength={process_length}",
                f'timePredictor="{time_predictor}"',
                f'bodyMassPredictor="{body_mass_predictor}"',
            ]
        )

        for model in ("hidden-relu", "hidden-tanh"):
            beast_args = [
                cfg.RUN_BEAST,
                f"-D {definitions}",
                f"-DF {THIS_DIR / 'params.json'}",
                f"-prefix {output_dir / model}",
                str(cfg.BEAST_CONFIGS_DIR / "fbd-empirical" / f"{model}.xml"),
            ]
            run_sbatch(
                " ".join(beast_args),
                cfg.SBATCH_LOGS_DIR / "fbd-empirical" / model / str(tree_index),
                mem_per_cpu="12000",
            )
51
+
52
+
53
+ if __name__ == "__main__":
54
+ main()
@@ -0,0 +1,50 @@
1
+ import json
2
+ import os
3
+ from pathlib import Path
4
+
5
+ import joblib
6
+ import polars as pl
7
+
8
+ from src.config import BEAST_LOGS_SUMMARIES_DIR, BEAST_OUTPUTS_DIR
9
+ from src.utils import summarize_logs
10
+
11
+ THIS_DIR = Path(__file__).parent
12
+
13
+
14
def main():
    """Summarize the BEAST logs of the empirical FBD MLP run.

    Loads the model parameters from ``params/MLP.json``, derives the hidden
    layer sizes and the type names from them, summarizes the logs found under
    ``<BEAST_OUTPUTS_DIR>/fbd-empirical/MLP`` with ``summarize_logs`` and
    writes the results to ``<BEAST_LOGS_SUMMARIES_DIR>/fbd-empirical`` as
    ``MLP.csv`` (summary table) and ``weights.pkl`` (joblib dump).
    """
    summaries_dir = os.path.join(BEAST_LOGS_SUMMARIES_DIR, "fbd-empirical")
    os.makedirs(summaries_dir, exist_ok=True)

    # JSON is UTF-8 by specification; do not rely on the locale default.
    params_file = os.path.join(THIS_DIR, "params", "MLP.json")
    with open(params_file, "r", encoding="utf-8") as f:
        params = json.load(f)

    # "nodes" is a whitespace-separated list of integers; the last entry is
    # dropped (presumably the output layer size — confirm against the XML).
    hidden_nodes = list(map(int, params["nodes"].split()))[:-1]
    print(hidden_nodes)
    states = params["types"].split(",")
    logs_dir = os.path.join(BEAST_OUTPUTS_DIR, "fbd-empirical", "MLP")

    change_times = (
        pl.read_csv(
            os.path.join(THIS_DIR, "data", "change_times.csv"), has_header=False
        )
        .to_series()
        .to_list()
    )
    # k change times delimit k + 1 time bins.
    n_time_bins = len(change_times) + 1

    logs_summary, weights = summarize_logs(
        logs_dir,
        target_columns=[
            f"{rate}Ratei{i}_{s}"
            for rate in ["birth", "death"]
            for i in range(n_time_bins)
            for s in states
        ],
        hidden_nodes=hidden_nodes,
        n_features={f"{rate}Rate": 2 for rate in ["birth", "death"]},
        layers_range_start=0,
    )

    # Plain string literal: the original f-string had no placeholders.
    logs_summary.write_csv(os.path.join(summaries_dir, "MLP.csv"))
    joblib.dump(weights, os.path.join(summaries_dir, "weights.pkl"))
47
+
48
+
49
+ if __name__ == "__main__":
50
+ main()
File without changes
@@ -0,0 +1,7 @@
1
+ from dataclasses import dataclass
2
+
3
+
4
@dataclass
class Feature:
    """Pair of boolean flags describing one feature.

    NOTE(review): presumably these are the objects exposed through
    ``SCENARIO.features[...]`` and consumed by the plotting helpers —
    confirm against the scenario definitions.
    """

    # True when the feature is flagged as binary.
    is_binary: bool
    # True when the feature is flagged as relevant.
    is_relevant: bool
File without changes
@@ -0,0 +1,101 @@
1
+ import os
2
+ from functools import partial
3
+
4
+ import joblib
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import polars as pl
8
+ from lumiere.backend import sigmoid
9
+
10
+ import src.config as cfg
11
+ from src.simulations.figures.utils import (
12
+ plot_partial_dependencies,
13
+ plot_shap_features_importance,
14
+ )
15
+ from src.simulations.scenarios.epi_multitype import (
16
+ MIGRATION_PREDICTOR,
17
+ MIGRATION_RATE_UPPER,
18
+ MIGRATION_RATES,
19
+ SCENARIO,
20
+ )
21
+ from src.utils import set_plt_rcparams
22
+
23
+
24
def _plot_predictions(log_summary: pl.DataFrame, output_dir: str):
    """Plot estimated and true migration rates against the migration predictor.

    For every migration-rate target of the scenario, the column medians of
    the ``<target>_median``, ``<target>_lower`` and ``<target>_upper``
    columns of ``log_summary`` are taken. Points are ordered by increasing
    predictor value, drawn with error bars, and overlaid with the true rates
    as a dashed black line. The figure is saved as ``predictions.svg`` in
    ``output_dir``.
    """
    targets = SCENARIO.targets["migrationRate"]

    def _column_medians(suffix):
        # One value per target: the median of its "<target>_<suffix>" column.
        return np.array(
            [log_summary[f"{target}_{suffix}"].median() for target in targets]
        )

    order = np.argsort(MIGRATION_PREDICTOR.flatten())
    predictor = MIGRATION_PREDICTOR.flatten()[order]
    true_rates = MIGRATION_RATES.flatten()[order]

    point = _column_medians("median")[order]
    low = _column_medians("lower")[order]
    high = _column_medians("upper")[order]

    # Asymmetric error bars: distance from the point estimate to each bound.
    plt.errorbar(
        predictor,
        point,
        yerr=[point - low, high - point],
        marker="o",
        color="C2",
    )
    plt.plot(predictor, point, marker="o", color="C2")
    plt.plot(predictor, true_rates, linestyle="dashed", marker="o", color="k")

    plt.xlabel("Migration predictor")
    plt.ylabel("Migration rate")
    plt.savefig(os.path.join(output_dir, "predictions.svg"))
    plt.close()
72
+
73
+
74
def main():
    """Produce the explanation figures for the ``epi-multitype`` scenario.

    Loads the ``MLP-32_16`` log summary and its weights from the
    ``epi-multitype`` summaries directory, then writes the prediction plot,
    the partial-dependence plots and the SHAP feature-importance plot to
    ``<FIGURES_DIR>/epi-explainations``.
    """
    figures_dir = os.path.join(cfg.FIGURES_DIR, "epi-explainations")
    os.makedirs(figures_dir, exist_ok=True)

    summaries_dir = os.path.join(cfg.BEAST_LOGS_SUMMARIES_DIR, "epi-multitype")
    model = "MLP-32_16"
    summary_file = os.path.join(summaries_dir, f"{model}.csv")
    weights_file = os.path.join(summaries_dir, f"{model}_weights.pkl")
    log_summary = pl.read_csv(summary_file)
    # joblib.load unpickles: only point this at files produced by this project.
    weights = joblib.load(weights_file)

    set_plt_rcparams()

    _plot_predictions(log_summary, figures_dir)

    target = "migrationRate"
    # Output activation used by both explanation plots.
    activation = partial(sigmoid, upper=MIGRATION_RATE_UPPER)
    plot_partial_dependencies(
        weights=weights[target],
        features=SCENARIO.features[target],
        output_dir=figures_dir,
        output_activation=activation,
    )
    plot_shap_features_importance(
        weights=weights[target],
        features=SCENARIO.features[target],
        output_file=os.path.join(figures_dir, "shap_values.svg"),
        output_activation=activation,
    )
98
+
99
+
100
+ if __name__ == "__main__":
101
+ main()
@@ -0,0 +1,58 @@
1
+ import os
2
+
3
+ import matplotlib.pyplot as plt
4
+ import polars as pl
5
+
6
+ import src.config as cfg
7
+ from src.simulations.figures.utils import (
8
+ plot_coverage_per_time_bin,
9
+ plot_maes_per_time_bin,
10
+ step,
11
+ )
12
+ from src.simulations.scenarios.epi_skyline import REPRODUCTION_NUMBERS
13
+ from src.utils import set_plt_rcparams
14
+
15
+
16
def main():
    """Produce prediction, coverage and MAE figures for the epi-skyline runs.

    For each entry of ``REPRODUCTION_NUMBERS`` (numbered from 1), loads the
    Nonparametric, GLM and MLP log summaries from the matching
    ``epi-skyline_<n>`` directory, draws one step curve of median estimates
    per model together with the true values (dashed black), and then writes
    the coverage and MAE plots. All figures go to
    ``<FIGURES_DIR>/epi-predictions``.
    """
    figures_dir = os.path.join(cfg.FIGURES_DIR, "epi-predictions")
    os.makedirs(figures_dir, exist_ok=True)

    set_plt_rcparams()

    # Display name -> summary file; insertion order fixes the plotting order.
    summary_files = {
        "Nonparametric": "Nonparametric.csv",
        "GLM": "GLM.csv",
        "MLP": "MLP-16_8.csv",
    }

    for scenario_id, reproduction_number in enumerate(REPRODUCTION_NUMBERS, start=1):
        prefix = f"epi-skyline_{scenario_id}"
        summaries_dir = os.path.join(cfg.BEAST_LOGS_SUMMARIES_DIR, prefix)
        logs_summaries = {
            name: pl.read_csv(os.path.join(summaries_dir, file_name))
            for name, file_name in summary_files.items()
        }
        true_values = {"reproductionNumber": reproduction_number}

        n_time_bins = len(reproduction_number)
        for summary in logs_summaries.values():
            medians = [
                summary[f"reproductionNumberi{time_bin}_median"].median()
                for time_bin in range(n_time_bins)
            ]
            step(medians)
        step(reproduction_number, color="k", linestyle="--")
        plt.ylabel("Reproduction number")
        plt.savefig(os.path.join(figures_dir, f"{prefix}-predictions.svg"))
        plt.close()

        plot_coverage_per_time_bin(
            logs_summaries,
            true_values,
            os.path.join(figures_dir, f"{prefix}-coverage.svg"),
        )
        plot_maes_per_time_bin(
            logs_summaries,
            true_values,
            os.path.join(figures_dir, f"{prefix}-maes.svg"),
        )
55
+
56
+
57
+ if __name__ == "__main__":
58
+ main()
@@ -0,0 +1,99 @@
1
+ import ast
2
+ import os
3
+ from functools import partial
4
+
5
+ import joblib
6
+ import matplotlib.pyplot as plt
7
+ import polars as pl
8
+ from joblib import Parallel, delayed
9
+ from lumiere.backend import sigmoid
10
+
11
+ import src.config as cfg
12
+ from src.figures.utils import (
13
+ plot_partial_dependencies,
14
+ plot_shap_features_importance,
15
+ step,
16
+ )
17
+ from src.simulations.scenarios.fbd_2traits import (
18
+ BIRTH_RATE_TRAIT1_SET,
19
+ BIRTH_RATE_TRAIT1_UNSET,
20
+ DEATH_RATE_TRAIT1_SET,
21
+ DEATH_RATE_TRAIT1_UNSET,
22
+ FBD_RATE_UPPER,
23
+ N_TIME_BINS,
24
+ SCENARIO,
25
+ STATES,
26
+ )
27
+ from src.utils import set_plt_rcparams
28
+
29
+
30
def _plot_predictions(log_summary: pl.DataFrame, output_dir: str):
    """Plot estimated vs. true birth and death rates for the 2-traits scenario.

    For each rate (``birth``, ``death``) one step curve per state is drawn
    from the column medians of ``<rate>Ratei<bin>_<state>_median``, followed
    by the two true curves (dashed): the trait-1-unset rates, labelled
    ``{0,0} = {0,1}``, and the trait-1-set rates, labelled ``{1,0} = {1,1}``.
    Each figure is saved as ``<output_dir>/<rate>/predictions.svg``; the
    per-rate sub-directories must already exist.
    """
    for rate in ["birth", "death"]:
        label = r"\lambda" if rate == "birth" else r"\mu"

        # Fix: the SET/UNSET constants were previously assigned to the
        # opposite variables, so the true curves carried swapped legend
        # labels. UNSET pairs with the {0,*} states, SET with {1,*}.
        if rate == "birth":
            rate_trait_1_unset = BIRTH_RATE_TRAIT1_UNSET
            rate_trait_1_set = BIRTH_RATE_TRAIT1_SET
        else:
            rate_trait_1_unset = DEATH_RATE_TRAIT1_UNSET
            rate_trait_1_set = DEATH_RATE_TRAIT1_SET

        for state in STATES:
            estimates = [
                log_summary[f"{rate}Ratei{i}_{state}_median"].median()
                for i in range(N_TIME_BINS)
            ]
            step(
                estimates,
                # The first two characters of the state are used as the
                # two subscripts of the legend entry.
                label=rf"${label}_{{{state[0]},{state[1]}}}$",
                reverse_xticks=True,
            )
        step(
            rate_trait_1_unset,
            color="k",
            linestyle="dashed",
            label=rf"${label}_{{0,0}}$ = ${label}_{{0,1}}$",
            reverse_xticks=True,
        )
        step(
            rate_trait_1_set,
            color="gray",
            linestyle="dashed",
            label=rf"${label}_{{1,0}}$ = ${label}_{{1,1}}$",
            reverse_xticks=True,
        )
        plt.legend()
        plt.ylabel(rf"${label}$")
        plt.savefig(os.path.join(output_dir, rate, "predictions.svg"))
        plt.close()
67
+
68
+
69
def main():
    """Produce the explanation figures for the ``fbd-2traits`` scenario.

    Creates ``<FIGURES_DIR>/fbd-explainations/{birth,death}``, loads the
    ``MLP-32_16`` log summary and weights from the ``fbd-2traits`` summaries
    directory, and writes the prediction plots followed by, for each rate,
    the partial-dependence plots and the SHAP feature-importance plot.
    """
    rates = ("birth", "death")

    figures_dir = os.path.join(cfg.FIGURES_DIR, "fbd-explainations")
    rate_dirs = {rate: os.path.join(figures_dir, rate) for rate in rates}
    for rate_dir in rate_dirs.values():
        os.makedirs(rate_dir, exist_ok=True)

    summaries_dir = os.path.join(cfg.BEAST_LOGS_SUMMARIES_DIR, "fbd-2traits")
    model = "MLP-32_16"
    log_summary = pl.read_csv(os.path.join(summaries_dir, f"{model}.csv"))
    # joblib.load unpickles: only point this at files produced by this project.
    weights = joblib.load(os.path.join(summaries_dir, f"{model}_weights.pkl"))

    set_plt_rcparams()

    _plot_predictions(log_summary, figures_dir)

    # Output activation used by every explanation plot.
    activation = partial(sigmoid, upper=FBD_RATE_UPPER)
    for rate in rates:
        target = f"{rate}Rate"
        plot_partial_dependencies(
            weights=weights[target],
            features=SCENARIO.features[target],
            output_dir=rate_dirs[rate],
            output_activation=activation,
        )
        plot_shap_features_importance(
            weights=weights[target],
            features=SCENARIO.features[target],
            output_file=os.path.join(figures_dir, rate, "shap_values.svg"),
            output_activation=activation,
        )
96
+
97
+
98
+ if __name__ == "__main__":
99
+ main()
@@ -0,0 +1,66 @@
1
+ import os
2
+
3
+ import matplotlib.pyplot as plt
4
+ import polars as pl
5
+
6
+ import src.config as cfg
7
+ from src.simulations.figures.utils import (
8
+ plot_coverage_per_time_bin,
9
+ plot_maes_per_time_bin,
10
+ step,
11
+ )
12
+ from src.simulations.scenarios.fbd_no_traits import BIRTH_RATES, DEATH_RATES
13
+ from src.utils import set_plt_rcparams
14
+
15
+
16
def main():
    """Produce prediction, coverage and MAE figures for the fbd-no-traits runs.

    For each (birth, death) pair of ``BIRTH_RATES`` / ``DEATH_RATES``
    (numbered from 1), loads the Nonparametric, GLM and MLP log summaries
    from the matching ``fbd-no-traits_<n>`` directory. Per rate it draws one
    step curve of median estimates per model plus the true values (dashed
    black), then writes the coverage and MAE plots. All figures go to
    ``<FIGURES_DIR>/fbd-predictions`` and use reversed x ticks.
    """
    figures_dir = os.path.join(cfg.FIGURES_DIR, "fbd-predictions")
    os.makedirs(figures_dir, exist_ok=True)

    set_plt_rcparams()

    # Display name -> summary file; insertion order fixes the plotting order.
    summary_files = {
        "Nonparametric": "Nonparametric.csv",
        "GLM": "GLM.csv",
        "MLP": "MLP-16_8.csv",
    }

    scenarios = enumerate(zip(BIRTH_RATES, DEATH_RATES), start=1)
    for scenario_id, (birth_rate, death_rate) in scenarios:
        prefix = f"fbd-no-traits_{scenario_id}"
        summaries_dir = os.path.join(cfg.BEAST_LOGS_SUMMARIES_DIR, prefix)
        logs_summaries = {
            name: pl.read_csv(os.path.join(summaries_dir, file_name))
            for name, file_name in summary_files.items()
        }
        true_values = {"birthRate": birth_rate, "deathRate": death_rate}

        for rate_name, true_rate in true_values.items():
            n_time_bins = len(true_rate)
            for summary in logs_summaries.values():
                medians = [
                    summary[f"{rate_name}i{time_bin}_median"].median()
                    for time_bin in range(n_time_bins)
                ]
                step(medians, reverse_xticks=True)
            step(true_rate, color="k", linestyle="--", reverse_xticks=True)

            if rate_name == "birthRate":
                plt.ylabel(r"$\lambda$")
            else:
                plt.ylabel(r"$\mu$")
            plt.savefig(
                os.path.join(figures_dir, f"{prefix}-predictions-{rate_name}.svg")
            )
            plt.close()

        plot_coverage_per_time_bin(
            logs_summaries,
            true_values,
            os.path.join(figures_dir, f"{prefix}-coverage.svg"),
            reverse_xticks=True,
        )
        plot_maes_per_time_bin(
            logs_summaries,
            true_values,
            os.path.join(figures_dir, f"{prefix}-maes.svg"),
            reverse_xticks=True,
        )
63
+
64
+
65
+ if __name__ == "__main__":
66
+ main()
@@ -0,0 +1,87 @@
1
+ import os
2
+
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+
6
+ import src.config as cfg
7
+ from src.simulations.figures.utils import step
8
+ from src.simulations.scenarios.epi_multitype import MIGRATION_PREDICTOR, MIGRATION_RATES
9
+ from src.simulations.scenarios.epi_skyline import REPRODUCTION_NUMBERS
10
+ from src.simulations.scenarios.fbd_2traits import (
11
+ BIRTH_RATE_TRAIT1_SET,
12
+ BIRTH_RATE_TRAIT1_UNSET,
13
+ DEATH_RATE_TRAIT1_SET,
14
+ DEATH_RATE_TRAIT1_UNSET,
15
+ )
16
+ from src.simulations.scenarios.fbd_no_traits import BIRTH_RATES, DEATH_RATES
17
+ from src.utils import set_plt_rcparams
18
+
19
+
20
def main():
    """Draw the ground-truth parameter curves of every simulation scenario.

    Writes to ``<FIGURES_DIR>/scenarios``: one ``epi-skyline_<n>.svg`` per
    reproduction-number trajectory, ``epi-multitype.svg`` (migration rate vs.
    predictor, sorted by predictor), one ``fbd-no-traits_<n>.svg`` per
    birth/death pair, and ``fbd-2traits.svg`` with the four trait-dependent
    rate curves.
    """
    figures_dir = os.path.join(cfg.FIGURES_DIR, "scenarios")
    os.makedirs(figures_dir, exist_ok=True)

    set_plt_rcparams()

    # --- epi-skyline: one figure per reproduction-number trajectory ---
    for scenario_id, reproduction_number in enumerate(REPRODUCTION_NUMBERS, start=1):
        step(reproduction_number, color="k")
        plt.ylabel("Reproduction number")
        plt.savefig(os.path.join(figures_dir, f"epi-skyline_{scenario_id}.svg"))
        plt.close()

    # --- epi-multitype: migration rate as a function of the predictor ---
    flat_predictor = MIGRATION_PREDICTOR.flatten()
    order = np.argsort(flat_predictor)
    plt.plot(
        flat_predictor[order],
        MIGRATION_RATES.flatten()[order],
        marker="o",
        color="k",
    )
    plt.xlabel("Migration predictor")
    plt.ylabel("Migration rate")
    plt.savefig(os.path.join(figures_dir, "epi-multitype.svg"))
    plt.close()

    # --- fbd-no-traits: birth and death rates share one figure ---
    rate_pairs = enumerate(zip(BIRTH_RATES, DEATH_RATES), start=1)
    for scenario_id, (birth_rate, death_rate) in rate_pairs:
        step(birth_rate, label=r"$\lambda$", reverse_xticks=True)
        step(death_rate, label=r"$\mu$", reverse_xticks=True)
        plt.ylabel("Rate")
        plt.legend()
        plt.savefig(os.path.join(figures_dir, f"fbd-no-traits_{scenario_id}.svg"))
        plt.close()

    # --- fbd-2traits: colour encodes the rate, dashes the trait-1-set case ---
    two_trait_curves = [
        (
            BIRTH_RATE_TRAIT1_UNSET,
            r"$\lambda_{0,0} = \lambda_{0,1}$",
            {"color": "C0"},
        ),
        (
            BIRTH_RATE_TRAIT1_SET,
            r"$\lambda_{1,0} = \lambda_{1,1}$",
            {"color": "C0", "linestyle": "dashed"},
        ),
        (
            DEATH_RATE_TRAIT1_UNSET,
            r"$\mu_{0,0} = \mu_{0,1}$",
            {"color": "C1"},
        ),
        (
            DEATH_RATE_TRAIT1_SET,
            r"$\mu_{1,0} = \mu_{1,1}$",
            {"color": "C1", "linestyle": "dashed"},
        ),
    ]
    for values, legend_label, line_style in two_trait_curves:
        step(values, label=legend_label, reverse_xticks=True, **line_style)
    plt.ylabel("Rate")
    plt.legend()
    plt.savefig(os.path.join(figures_dir, "fbd-2traits.svg"))
    plt.close()
84
+
85
+
86
+ if __name__ == "__main__":
87
+ main()