bella_companion-0.0.0-py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects changes between package versions as they appear in the public registry.

This release of bella-companion has been flagged as potentially problematic.
Files changed (34)
  1. bella_companion/__init__.py +0 -0
  2. bella_companion/cli.py +24 -0
  3. bella_companion/fbd_empirical/data/body_mass.csv +1378 -0
  4. bella_companion/fbd_empirical/data/change_times.csv +22 -0
  5. bella_companion/fbd_empirical/data/sampling_change_times.csv +6 -0
  6. bella_companion/fbd_empirical/data/trees.nwk +100 -0
  7. bella_companion/fbd_empirical/figure.py +37 -0
  8. bella_companion/fbd_empirical/notbooks.ipynb +359 -0
  9. bella_companion/fbd_empirical/params.json +11 -0
  10. bella_companion/fbd_empirical/run_beast.py +54 -0
  11. bella_companion/fbd_empirical/summarize_logs.py +50 -0
  12. bella_companion/simulations/__init__.py +0 -0
  13. bella_companion/simulations/features.py +7 -0
  14. bella_companion/simulations/figures/__init__.py +0 -0
  15. bella_companion/simulations/figures/epi_explainations.py +101 -0
  16. bella_companion/simulations/figures/epi_predictions.py +58 -0
  17. bella_companion/simulations/figures/fbd_explainations.py +99 -0
  18. bella_companion/simulations/figures/fbd_predictions.py +66 -0
  19. bella_companion/simulations/figures/scenarios.py +87 -0
  20. bella_companion/simulations/figures/utils.py +250 -0
  21. bella_companion/simulations/generate_data.py +25 -0
  22. bella_companion/simulations/run_beast.py +92 -0
  23. bella_companion/simulations/scenarios/__init__.py +20 -0
  24. bella_companion/simulations/scenarios/common.py +29 -0
  25. bella_companion/simulations/scenarios/epi_multitype.py +68 -0
  26. bella_companion/simulations/scenarios/epi_skyline.py +65 -0
  27. bella_companion/simulations/scenarios/fbd_2traits.py +101 -0
  28. bella_companion/simulations/scenarios/fbd_no_traits.py +71 -0
  29. bella_companion/simulations/scenarios/scenario.py +26 -0
  30. bella_companion/simulations/summarize_logs.py +39 -0
  31. bella_companion/utils.py +164 -0
  32. bella_companion-0.0.0.dist-info/METADATA +13 -0
  33. bella_companion-0.0.0.dist-info/RECORD +34 -0
  34. bella_companion-0.0.0.dist-info/WHEEL +4 -0
bella_companion/simulations/scenarios/fbd_no_traits.py
@@ -0,0 +1,71 @@
+ from functools import partial
+
+ import numpy as np
+ from phylogenie import SkylineParameter, get_canonical_events
+
+ from bella_companion.simulations.features import Feature
+ from bella_companion.simulations.scenarios.common import (
+     FBD_MAX_TIME,
+     FBD_RATE_UPPER,
+     FBD_SAMPLING_RATE,
+     get_prior_params,
+     get_random_time_series_predictor,
+ )
+ from bella_companion.simulations.scenarios.scenario import Scenario, ScenarioType
+
+
+ def _get_scenario(rates: dict[str, list[float]]) -> Scenario:
+     if len(rates["birth"]) != len(rates["death"]):
+         raise ValueError("Birth rate and death rate lists must have the same length.")
+     n_time_bins = len(rates["birth"])
+     change_times = np.linspace(0, FBD_MAX_TIME, n_time_bins + 1)[1:-1].tolist()
+
+     return Scenario(
+         type=ScenarioType.FBD,
+         max_time=FBD_MAX_TIME,
+         events=get_canonical_events(
+             states=["X"],
+             sampling_rates=FBD_SAMPLING_RATE,
+             remove_after_sampling=False,
+             birth_rates=SkylineParameter(rates["birth"], change_times),
+             death_rates=SkylineParameter(rates["death"], change_times),
+         ),
+         get_random_predictor=partial(
+             get_random_time_series_predictor, n_time_bins=n_time_bins
+         ),
+         beast_args={
+             "processLength": FBD_MAX_TIME,
+             "changeTimes": " ".join(map(str, change_times)),
+             **get_prior_params("birthRate", FBD_RATE_UPPER, n_time_bins),
+             **get_prior_params("deathRate", FBD_RATE_UPPER, n_time_bins),
+             "samplingRate": FBD_SAMPLING_RATE,
+             "timePredictor": " ".join(map(str, np.linspace(0, 1, n_time_bins))),
+         },
+         targets={
+             f"{rate}Rate": {f"{rate}RateSPi{i}": values[i] for i in range(n_time_bins)}
+             for rate, values in rates.items()
+         },
+         features={
+             f"{rate}Rate": {
+                 "timePredictor": Feature(
+                     is_binary=False, is_relevant=len(set(values)) > 1
+                 ),
+                 "randomPredictor": Feature(is_binary=False, is_relevant=False),
+             }
+             for rate, values in rates.items()
+         },
+     )
+
+
+ RATES = [
+     {"birth": [0.2] * 10, "death": [0.1] * 10},
+     {
+         "birth": np.linspace(0.4, 0.1, 10).tolist(),
+         "death": np.linspace(0.1, 0.2, 10).tolist(),
+     },
+     {
+         "birth": [0.4] * 5 + [0.1] * 3 + [0.01] * 2,
+         "death": [0.05] * 7 + [0.3] * 1 + [0.01] * 2,
+     },
+ ]
+ SCENARIOS = [_get_scenario(r) for r in RATES]
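
For orientation, here is a minimal standalone sketch of how `_get_scenario` partitions the process into equal-width time bins. `FBD_MAX_TIME` is defined in `common.py`, which this diff does not show, so the value below is a placeholder assumption:

    # Sketch of the change-time computation in _get_scenario.
    import numpy as np

    FBD_MAX_TIME = 100.0  # hypothetical value; the real constant lives in common.py
    rates = {"birth": [0.4] * 5 + [0.1] * 3 + [0.01] * 2}
    n_time_bins = len(rates["birth"])  # 10 bins
    # linspace returns n_time_bins + 1 boundaries; slicing [1:-1] keeps only the
    # interior boundaries, i.e. the 9 change times separating 10 bins.
    change_times = np.linspace(0, FBD_MAX_TIME, n_time_bins + 1)[1:-1].tolist()
    assert len(change_times) == n_time_bins - 1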
bella_companion/simulations/scenarios/scenario.py
@@ -0,0 +1,26 @@
+ from collections.abc import Callable
+ from dataclasses import dataclass
+ from enum import Enum
+ from typing import Any
+
+ from numpy.random import Generator
+ from phylogenie.treesimulator import Event
+
+ from bella_companion.simulations.features import Feature
+
+
+ class ScenarioType(Enum):
+     EPI = "epi"
+     FBD = "fbd"
+
+
+ @dataclass
+ class Scenario:
+     type: ScenarioType
+     max_time: float
+     events: list[Event]
+     get_random_predictor: Callable[[Generator], list[float]]
+     beast_args: dict[str, Any]
+     targets: dict[str, dict[str, float]]
+     features: dict[str, dict[str, Feature]]
+     init_state: str | None = None
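
A minimal instantiation sketch of the dataclass above; the empty events list, the dummy predictor, and all field values are illustrative stand-ins, not values used anywhere in the package:

    from numpy.random import Generator

    from bella_companion.simulations.scenarios.scenario import Scenario, ScenarioType

    def dummy_predictor(rng: Generator) -> list[float]:
        # Matches the Callable[[Generator], list[float]] field type.
        return rng.uniform(0, 1, size=10).tolist()

    scenario = Scenario(
        type=ScenarioType.FBD,
        max_time=100.0,
        events=[],  # real scenarios pass phylogenie treesimulator Events
        get_random_predictor=dummy_predictor,
        beast_args={"processLength": 100.0},
        targets={"birthRate": {"birthRateSPi0": 0.2}},
        features={"birthRate": {}},
    )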
bella_companion/simulations/summarize_logs.py
@@ -0,0 +1,39 @@
+ import json
+ import os
+
+ import joblib
+
+ from src.config import BEAST_LOGS_SUMMARIES_DIR, BEAST_OUTPUTS_DIR
+ from src.simulations.scenarios import SCENARIOS
+ from src.utils import summarize_logs
+
+
+ def main():
+     with open(BEAST_OUTPUTS_DIR / "simulations_job_ids.json", "r") as f:
+         job_ids: dict[str, dict[str, dict[str, str]]] = json.load(f)
+
+     for scenario_name, scenario in SCENARIOS.items():
+         summaries_dir = BEAST_LOGS_SUMMARIES_DIR / scenario_name
+         os.makedirs(summaries_dir, exist_ok=True)
+         for model in job_ids[scenario_name]:
+             hidden_nodes = (
+                 list(map(int, model.split("-")[1].split("_")))
+                 if model.startswith("MLP")
+                 else None
+             )
+             logs_dir = BEAST_OUTPUTS_DIR / scenario_name / model
+             print(f"Summarizing {scenario_name} - {model}")
+             logs_summary, weights = summarize_logs(
+                 logs_dir,
+                 target_columns=[c for t in scenario.targets.values() for c in t],
+                 hidden_nodes=hidden_nodes,
+                 n_features={t: len(fs) for t, fs in scenario.features.items()},
+                 job_ids=job_ids[scenario_name][model],
+             )
+             logs_summary.write_csv(summaries_dir / f"{model}.csv")
+             if weights is not None:
+                 joblib.dump(weights, summaries_dir / f"{model}.weights.pkl")
+
+
+ if __name__ == "__main__":
+     main()
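
Note that this script imports from a top-level `src` package that is not shipped in this wheel (the file list above contains no `src/` modules), so it appears to target the project's development layout rather than the installed package. The JSON it reads maps scenario to model to log id to job id; the concrete names in this sketch are assumptions, only the nesting follows from the `dict[str, dict[str, dict[str, str]]]` annotation in `main()`:

    example_job_ids = {
        "fbd_no_traits": {
            "GLM": {"0": "1234567", "1": "1234568"},
            "MLP-16_8": {"0": "1234569"},  # "MLP-16_8" parses to hidden_nodes [16, 8]
        }
    }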
bella_companion/utils.py
@@ -0,0 +1,164 @@
+ import os
+ import re
+ import subprocess
+ from glob import glob
+ from pathlib import Path
+ from typing import Any
+
+ import arviz as az
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import polars as pl
+ from joblib import Parallel, delayed
+ from lumiere.backend.typings import Weights
+ from tqdm import tqdm
+
+
+ def run_sbatch(
+     command: str,
+     log_dir: Path,
+     time: str = "240:00:00",
+     mem_per_cpu: str = "2000",
+     overwrite: bool = False,
+ ) -> str | None:
+     if not overwrite and log_dir.exists():
+         print(f"Log directory {log_dir} already exists. Skipping.")
+         return
+     cmd = " ".join(
+         [
+             "sbatch",
+             f"-J {log_dir}",
+             f"-o {log_dir / 'output.out'}",
+             f"-e {log_dir / 'error.err'}",
+             f"--time {time}",
+             f"--mem-per-cpu={mem_per_cpu}",
+             f"--wrap='{command}'",
+         ]
+     )
+     output = subprocess.run(cmd, shell=True, capture_output=True, text=True)
+     job_id = re.search(r"Submitted batch job (\d+)", output.stdout)
+     if job_id is None:
+         raise RuntimeError(
+             f"Failed to submit job.\nCommand: {cmd}\nOutput: {output.stdout}\nError: {output.stderr}"
+         )
+     return job_id.group(1)
+
+
+ def get_job_metadata(job_id: str):
+     output = subprocess.run(
+         f"myjobs -j {job_id}", shell=True, capture_output=True, text=True
+     ).stdout
+
+     status = re.search(r"Status\s+:\s+(\w+)", output)
+     if status is None:
+         raise RuntimeError(f"Failed to get job status for job {job_id}")
+     status = status.group(1)
+
+     wall_clock = re.search(r"Wall-clock\s+:\s+([\d\-:]+)", output)
+     if wall_clock is None:
+         raise RuntimeError(f"Failed to get wall-clock time for job {job_id}")
+     wall_clock = wall_clock.group(1)
+
+     if "-" in wall_clock:
+         days, wall_clock = wall_clock.split("-")
+         days = int(days)
+     else:
+         days = 0
+     hours, minutes, seconds = map(int, wall_clock.split(":"))
+     total_hours = days * 24 + hours + minutes / 60 + seconds / 3600
+
+     return {"status": status, "total_hours": total_hours}
+
+
+ def summarize_log(
+     log_file: str,
+     target_columns: list[str],
+     burn_in: float = 0.1,
+     hdi_prob: float = 0.95,
+     hidden_nodes: list[int] | None = None,
+     n_weights_samples: int = 100,
+     n_features: dict[str, int] | None = None,
+     job_id: str | None = None,
+ ) -> tuple[dict[str, Any], dict[str, list[Weights]] | None]:
+     df = pl.read_csv(log_file, separator="\t", comment_prefix="#")
+     df = df.filter(pl.col("Sample") > burn_in * len(df))
+     targets_df = df.select(target_columns)
+     summary: dict[str, Any] = {"n_samples": len(df)}
+     for column in targets_df.columns:
+         summary[f"{column}_median"] = targets_df[column].median()
+         summary[f"{column}_ess"] = az.ess(  # pyright: ignore[reportUnknownMemberType]
+             np.array(targets_df[column])
+         )
+         lower, upper = az.hdi(  # pyright: ignore[reportUnknownMemberType]
+             np.array(targets_df[column]), hdi_prob=hdi_prob
+         )
+         summary[f"{column}_lower"] = lower
+         summary[f"{column}_upper"] = upper
+     if job_id is not None:
+         summary.update(get_job_metadata(job_id))
+     if hidden_nodes is not None:
+         if n_features is None:
+             raise ValueError("`n_features` must be provided to summarize log weights.")
+         weights: dict[str, list[Weights]] = {}
+         for target, n in n_features.items():
+             nodes = [n, *hidden_nodes, 1]
+             layer_weights = [
+                 np.array(
+                     df.tail(n_weights_samples).select(
+                         c for c in df.columns if c.startswith(f"{target}W.{i}")
+                     )
+                 ).reshape(-1, n_inputs + 1, n_outputs)
+                 for i, (n_inputs, n_outputs) in enumerate(zip(nodes[:-1], nodes[1:]))
+             ]
+             weights[target] = [
+                 list(sample_weights) for sample_weights in zip(*layer_weights)
+             ]
+         return summary, weights
+     return summary, None
+
+
+ def summarize_logs(
+     logs_dir: Path,
+     target_columns: list[str],
+     burn_in: float = 0.1,
+     hdi_prob: float = 0.95,
+     hidden_nodes: list[int] | None = None,
+     n_weights_samples: int = 100,
+     n_features: dict[str, int] | None = None,
+     job_ids: dict[str, str] | None = None,
+ ) -> tuple[pl.DataFrame, dict[str, list[list[Weights]]] | None]:
+     def _get_log_summary(
+         log_file: str,
+     ) -> tuple[dict[str, Any], dict[str, list[Weights]] | None]:
+         log_id = Path(log_file).stem
+         summary, weights = summarize_log(
+             log_file=log_file,
+             target_columns=target_columns,
+             burn_in=burn_in,
+             hdi_prob=hdi_prob,
+             hidden_nodes=hidden_nodes,
+             n_weights_samples=n_weights_samples,
+             n_features=n_features,
+             job_id=job_ids[log_id] if job_ids is not None else None,
+         )
+         return {"id": log_id, **summary}, weights
+
+     os.environ["POLARS_MAX_THREADS"] = "1"
+     summaries = Parallel(n_jobs=-1)(
+         delayed(_get_log_summary)(log_file)
+         for log_file in tqdm(glob(str(logs_dir / "*.log")))
+     )
+     data, weights = zip(*summaries)
+     if any(w is not None for w in weights):
+         assert n_features is not None
+         return pl.DataFrame(data), {t: [w[t] for w in weights] for t in n_features}
+     return pl.DataFrame(data), None
+
+
+ def set_plt_rcparams():
+     plt.rcParams["pdf.fonttype"] = 42
+     plt.rcParams["xtick.labelsize"] = 14
+     plt.rcParams["ytick.labelsize"] = 14
+     plt.rcParams["font.size"] = 14
+     plt.rcParams["figure.constrained_layout.use"] = True
+     plt.rcParams["lines.linewidth"] = 3
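
A usage sketch for the two main helpers above, assuming a SLURM cluster (`run_sbatch` shells out to `sbatch`, and `get_job_metadata` to a site-specific `myjobs` command). The paths, target columns, and layer sizes are illustrative:

    from pathlib import Path

    from bella_companion.utils import run_sbatch, summarize_logs

    # Submit a wrapped shell command; returns the SLURM job id, or None if the
    # log directory already exists and overwrite is False.
    job_id = run_sbatch("beast analysis.xml", log_dir=Path("logs/example"))

    # Summarize all *.log files in a directory into one polars DataFrame; the
    # optional weights are only extracted when hidden_nodes is given.
    summary, weights = summarize_logs(
        logs_dir=Path("logs/example"),
        target_columns=["birthRateSPi0", "deathRateSPi0"],
        hidden_nodes=[16, 8],  # None for models without MLP weight columns
        n_features={"birthRate": 2, "deathRate": 2},
    )
    summary.write_csv("example_summary.csv")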
bella_companion-0.0.0.dist-info/METADATA
@@ -0,0 +1,13 @@
+ Metadata-Version: 2.1
+ Name: bella-companion
+ Version: 0.0.0
+ Summary:
+ Author: gabriele-marino
+ Author-email: gabmarino.8601@gmail.com
+ Requires-Python: >=3.10,<4.0
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Requires-Dist: dotenv (>=0.9.9,<0.10.0)
+ Requires-Dist: phylogenie (>=2.1.21,<3.0.0)
bella_companion-0.0.0.dist-info/RECORD
@@ -0,0 +1,34 @@
+ bella_companion/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ bella_companion/cli.py,sha256=IUODGLiDcxrF40ZjL-SeQtEQhoPgB989KJiXXU0-Pik,576
+ bella_companion/fbd_empirical/data/body_mass.csv,sha256=-UkKNtm9m3g4PjY3BcfdP6z5nL_I6p9cq6cgZ-bWKI8,30360
+ bella_companion/fbd_empirical/data/change_times.csv,sha256=zmc9_z91-XMwKyIoP9v9dVlLcf4MeIHkQiHLjoMriOo,120
+ bella_companion/fbd_empirical/data/sampling_change_times.csv,sha256=Gwi9RcMFy89RyvfxKVZ_MoKVRHOZLuwB_3LEaq8asMQ,32
+ bella_companion/fbd_empirical/data/trees.nwk,sha256=zhvLvPLZelhMThVmvOENkmi3p2aPAARb8KMdHTm6mss,4645318
+ bella_companion/fbd_empirical/figure.py,sha256=4paOXCB1EcxuHzLPxDSleQU2AQ_ndTedtzS1ugiKICs,1018
+ bella_companion/fbd_empirical/notbooks.ipynb,sha256=O45kmz0lZENRDFbKXEWPsIKATfF5GVeS5tCYmrGLnqk,83326
+ bella_companion/fbd_empirical/params.json,sha256=hU23LniClZL_GSBAxIEJUJgMa93AM8zdtFOq6mt3vkI,311
+ bella_companion/fbd_empirical/run_beast.py,sha256=2sV2UmxOfWmbueiU6D0p3lueMYiZyIkSKYoblTMrYuA,1935
+ bella_companion/fbd_empirical/summarize_logs.py,sha256=O6rhE606Wa98a8b1KKlLPjUOro1pfyqVTLdQksQMG0g,1439
+ bella_companion/simulations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ bella_companion/simulations/features.py,sha256=DZOBpJGlQ0UinqUZYbEtoemZ2eQGVLV_i-DfpW31qJI,104
+ bella_companion/simulations/figures/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ bella_companion/simulations/figures/epi_explainations.py,sha256=RL9fyjl0a_zPhrGdUXqbMMu6471su8B-O6LyuFlHknw,2816
+ bella_companion/simulations/figures/epi_predictions.py,sha256=4yXwOBKxUv4kgZdI9zAMEhZ0QCNKZdkAafRQ1RTeaWg,1835
+ bella_companion/simulations/figures/fbd_explainations.py,sha256=9Uj7yttpn_TH5HqycW8R-Nlky9A9aFXDXRpXQuT1L4s,3037
+ bella_companion/simulations/figures/fbd_predictions.py,sha256=jdXYCLledZEWoPCIuTLhHEPMdeG6YXvf5xZnEOslv-U,2119
+ bella_companion/simulations/figures/scenarios.py,sha256=vyybn3Qhfq96N8tvW0wSzpFoHHP8EIc8dkOz63o_Atw,2492
+ bella_companion/simulations/figures/utils.py,sha256=sY8wFBg02fv5ugpJ80EqQishD_HEdLwhqsw2LfM7wEo,8539
+ bella_companion/simulations/generate_data.py,sha256=H8OV4ZlTGZB-jXaROTPmOsK3UxRiU-GrX40l-shliw8,728
+ bella_companion/simulations/run_beast.py,sha256=NBGfb5ZvtrLX5sA6Ku4SNHqmPGoEXFj5DmV54ZR4zVs,3411
+ bella_companion/simulations/scenarios/__init__.py,sha256=3Kl1lKcFpfb3vLX64DmSW4XCF5kXU1ZoHtstFH-ZIzU,876
+ bella_companion/simulations/scenarios/common.py,sha256=_ddaSuTvEVdttGkXB4HPc2B7IB1F_GBOCW3cVOPZ-ZM,807
+ bella_companion/simulations/scenarios/epi_multitype.py,sha256=GWGIiqvYwX_FrT_3RXkZKYGDht9nZ7ceHRBKUvXDPnA,2432
+ bella_companion/simulations/scenarios/epi_skyline.py,sha256=JqnOVATECxBUqEbkR5lBlMI2O8k4hO6ipR8k9cHUsm0,2365
+ bella_companion/simulations/scenarios/fbd_2traits.py,sha256=sCtdWyV6GQQOIhnL9Dd8NIbAR-StTwUTD9-b_BalmFQ,3552
+ bella_companion/simulations/scenarios/fbd_no_traits.py,sha256=R6CH0fVeQg-Iesl39pq2uY8ICVEO4VZbvUVUCGwauJU,2520
+ bella_companion/simulations/scenarios/scenario.py,sha256=_FRWAyOFbw94lAzd3zCD-1ek4TrssoiXfXRQPShLiIA,620
+ bella_companion/simulations/summarize_logs.py,sha256=TXaO9cjzl5O1u0fPZpRl-9txzoN-p-fkhoAHoRXTfm8,1433
+ bella_companion/utils.py,sha256=26cF3oVBbsahYPO9rcK69l43ybg5AjS12IyfucgyVIM,5666
+ bella_companion-0.0.0.dist-info/METADATA,sha256=j55dzUiDk-NtHXDt3bAQ3MYH3fkMDKNmwZ4OD71TAm4,446
+ bella_companion-0.0.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+ bella_companion-0.0.0.dist-info/RECORD,,
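
Each RECORD row follows the wheel convention `path,sha256=<digest>,<size in bytes>`, where the digest is the urlsafe-base64 SHA-256 of the file with trailing `=` padding stripped (the RECORD file itself is listed with empty fields). A quick sketch for recomputing an entry's digest:

    import base64
    import hashlib

    def record_digest(path: str) -> str:
        # Reproduces the RECORD hash format: urlsafe base64, padding removed.
        with open(path, "rb") as f:
            raw = hashlib.sha256(f.read()).digest()
        return base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")

    # e.g. record_digest("bella_companion/utils.py") should print
    # 26cF3oVBbsahYPO9rcK69l43ybg5AjS12IyfucgyVIM for the wheel's copy.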
bella_companion-0.0.0.dist-info/WHEEL
@@ -0,0 +1,4 @@
+ Wheel-Version: 1.0
+ Generator: poetry-core 1.9.0
+ Root-Is-Purelib: true
+ Tag: py3-none-any