bella-companion 0.0.0-py3-none-any.whl → 0.0.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of bella-companion might be problematic.

bella_companion/cli.py CHANGED
@@ -1,24 +1,33 @@
  import argparse
+ import os
+ from pathlib import Path

  from dotenv import load_dotenv

+ from bella_companion.simulations import generate_data, run_beast, summarize_logs
+

  def main():
-     load_dotenv()
+     load_dotenv(Path(os.getcwd()) / ".env")

-     parser = argparse.ArgumentParser(prog="bella")
-     subparsers = parser.add_subparsers(dest="command")
+     parser = argparse.ArgumentParser(
+         prog="bella",
+         description="Companion tool with experiments and evaluation for Bayesian Evolutionary Layered Learning Architectures (BELLA) BEAST2 package.",
+     )

-     gen_sim_data_parser = subparsers.add_parser("generate-simulations-data")
-     generate_simulations_data_parser.set_defaults(func=generate_simulations_data)
+     subparsers = parser.add_subparsers(dest="command", required=True)

-     args = parser.parse_args()
-     if hasattr(args, "func"):
-         args.func(args)
-     else:
-         parser.print_help()
+     subparsers.add_parser(
+         "generate-simulations-data", help="Generate simulated data."
+     ).set_defaults(func=generate_data)

+     subparsers.add_parser(
+         "run-beast-simulations", help="Run BEAST2 on simulated data."
+     ).set_defaults(func=run_beast)

- def generate_simulations_data(args):
-     print("Generating simulations data...")
-     # your logic here
+     subparsers.add_parser(
+         "summarize-simulation-logs", help="Summarize simulation logs."
+     ).set_defaults(func=summarize_logs)
+
+     args = parser.parse_args()
+     args.func()
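The rewritten CLI wires each subcommand directly to a function re-exported from bella_companion.simulations, and `args.func()` is called with no arguments: the handlers read their configuration from environment variables (loaded from ./.env) rather than from argparse. With the console script declared in entry_points.txt below, typical invocations would be:

    bella generate-simulations-data
    bella run-beast-simulations
    bella summarize-simulation-logs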
bella_companion/simulations/__init__.py CHANGED
@@ -0,0 +1,5 @@
+ from bella_companion.simulations.generate_data import generate_data
+ from bella_companion.simulations.run_beast import run_beast
+ from bella_companion.simulations.summarize_logs import summarize_logs
+
+ __all__ = ["generate_data", "run_beast", "summarize_logs"]
bella_companion/simulations/figures/epi_explainations.py CHANGED
@@ -5,9 +5,8 @@ import joblib
  import matplotlib.pyplot as plt
  import numpy as np
  import polars as pl
- from lumiere.backend import sigmoid
-
  import src.config as cfg
+ from lumiere.backend import sigmoid
  from src.simulations.figures.utils import (
      plot_partial_dependencies,
      plot_shap_features_importance,
@@ -21,6 +20,15 @@ from src.simulations.scenarios.epi_multitype import (
  from src.utils import set_plt_rcparams


+ def set_plt_rcparams():
+     plt.rcParams["pdf.fonttype"] = 42
+     plt.rcParams["xtick.labelsize"] = 14
+     plt.rcParams["ytick.labelsize"] = 14
+     plt.rcParams["font.size"] = 14
+     plt.rcParams["figure.constrained_layout.use"] = True
+     plt.rcParams["lines.linewidth"] = 3
+
+
  def _plot_predictions(log_summary: pl.DataFrame, output_dir: str):
      sort_idx = np.argsort(MIGRATION_PREDICTOR.flatten())

bella_companion/simulations/run_beast.py CHANGED
@@ -9,20 +9,23 @@ from phylogenie import Tree, load_newick
  from phylogenie.utils import get_node_depths
  from tqdm import tqdm

- import config as cfg
  from bella_companion.simulations.scenarios import SCENARIOS, ScenarioType
- from bella_companion.utils import run_sbatch
+ from bella_companion.utils import submit_job


- def main():
+ def run_beast():
      rng = default_rng(42)
+     base_data_dir = Path(os.environ["BELLA_SIMULATIONS_DATA_DIR"])
+     base_output_dir = Path(os.environ["BELLA_BEAST_OUTPUT_DIR"])
+
      job_ids = {}
      for scenario_name, scenario in SCENARIOS.items():
          job_ids[scenario_name] = defaultdict(dict)
-         data_dir = cfg.SIMULATED_DATA_DIR / scenario_name
-         inference_configs_dir = (
+         data_dir = base_data_dir / scenario_name
+         inference_configs_dir = Path(os.environ["BELLA_BEAST_CONFIGS_DIR"]) / (
              scenario_name.split("_")[0] if "_" in scenario_name else scenario_name
          )
+         log_dir = Path(os.environ["BELLA_SBATCH_LOG_DIR"]) / scenario_name
          for tree_file in tqdm(
              glob(str(data_dir / "*.nwk")),
              desc=f"Submitting BEAST2 jobs for {scenario_name}",
@@ -31,11 +34,13 @@ def main():
             for model in ["Nonparametric", "GLM"] + [
                 f"MLP-{hidden_nodes}" for hidden_nodes in ["3_2", "16_8", "32_16"]
             ]:
-                 outputs_dir = cfg.BEAST_OUTPUTS_DIR / scenario_name / model
-                 os.makedirs(outputs_dir, exist_ok=True)
+                 output_dir = base_output_dir / scenario_name / model
+                 os.makedirs(output_dir, exist_ok=True)
+
                  beast_args = [
                      f"-D treeFile={tree_file},treeID={tree_id}",
-                     f"-prefix {outputs_dir}{os.sep}",
+                     f"-prefix {output_dir}{os.sep}",
+                     f'-D randomPredictor="{" ".join(map(str, scenario.get_random_predictor(rng)))}"',
                  ]
                  beast_args.extend(
                      [
@@ -43,9 +48,6 @@ def main():
                          for key, value in scenario.beast_args.items()
                      ]
                  )
-                 beast_args.append(
-                     f'-D randomPredictor="{" ".join(map(str, scenario.get_random_predictor(rng)))}"'
-                 )
                  if scenario.type == ScenarioType.EPI:
                      tree = load_newick(tree_file)
                      assert isinstance(tree, Tree)
@@ -53,40 +55,24 @@ def main():
                          f"-D lastSampleTime={max(get_node_depths(tree).values())}"
                      )

+                 base_command = [os.environ["BELLA_RUN_BEAST_CMD"], *beast_args]
                  if model in ["Nonparametric", "GLM"]:
                      command = " ".join(
-                         [
-                             cfg.RUN_BEAST,
-                             *beast_args,
-                             str(
-                                 cfg.BEAST_CONFIGS_DIR
-                                 / inference_configs_dir
-                                 / f"{model}.xml"
-                             ),
-                         ]
+                         [*base_command, str(inference_configs_dir / f"{model}.xml")]
                      )
                  else:
                      nodes = model.split("-")[1].split("_")
                      command = " ".join(
                          [
-                             cfg.RUN_BEAST,
-                             *beast_args,
+                             *base_command,
                              f'-D nodes="{" ".join(map(str, nodes))}"',
-                             str(
-                                 cfg.BEAST_CONFIGS_DIR
-                                 / inference_configs_dir
-                                 / "MLP.xml"
-                             ),
+                             str(inference_configs_dir / "MLP.xml"),
                          ]
                      )

-                 job_ids[scenario_name][model][tree_id] = run_sbatch(
-                     command, cfg.SBATCH_LOGS_DIR / scenario_name / model / tree_id
+                 job_ids[scenario_name][model][tree_id] = submit_job(
+                     command, log_dir / model / tree_id
                  )

-     with open(cfg.BEAST_OUTPUTS_DIR / "simulations_job_ids.json", "w") as f:
+     with open(base_output_dir / "simulations_job_ids.json", "w") as f:
          json.dump(job_ids, f)
-
-
- if __name__ == "__main__":
-     main()
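For orientation, the string handed to submit_job concatenates the launcher named by BELLA_RUN_BEAST_CMD, the -D definitions, and the model's XML config. For an MLP-16_8 run it would look roughly like the following (all paths and values are illustrative, not taken from the package):

    run-beast -D treeFile=/data/epi_skyline/0.nwk,treeID=0 -prefix /out/epi_skyline/MLP-16_8/ -D randomPredictor="0.4 1.3 0.9" -D nodes="16 8" /configs/epi/MLP.xml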
bella_companion/simulations/summarize_logs.py CHANGED
@@ -1,39 +1,31 @@
  import json
  import os
+ from pathlib import Path

  import joblib

- from src.config import BEAST_LOGS_SUMMARIES_DIR, BEAST_OUTPUTS_DIR
- from src.simulations.scenarios import SCENARIOS
- from src.utils import summarize_logs
+ from bella_companion.simulations.scenarios import SCENARIOS
+ from bella_companion.utils import summarize_logs as _summarize_logs
+ from bella_companion.utils import summarize_weights


- def main():
-     with open(BEAST_OUTPUTS_DIR / "simulations_job_ids.json", "r") as f:
+ def summarize_logs():
+     output_dir = Path(os.environ["BELLA_BEAST_OUTPUT_DIR"])
+     with open(output_dir / "simulations_job_ids.json", "r") as f:
          job_ids: dict[str, dict[str, dict[str, str]]] = json.load(f)

      for scenario_name, scenario in SCENARIOS.items():
-         summaries_dir = BEAST_LOGS_SUMMARIES_DIR / scenario_name
+         summaries_dir = Path(os.environ["BEAST_LOGS_SUMMARIES_DIR"]) / scenario_name
          os.makedirs(summaries_dir, exist_ok=True)
          for model in job_ids[scenario_name]:
-             hidden_nodes = (
-                 list(map(int, model.split("-")[1].split("_")))
-                 if model.startswith("MLP")
-                 else None
-             )
-             logs_dir = BEAST_OUTPUTS_DIR / scenario_name / model
+             logs_dir = output_dir / scenario_name / model
              print(f"Summarizing {scenario_name} - {model}")
-             logs_summary, weights = summarize_logs(
+             summary = _summarize_logs(
                  logs_dir,
                  target_columns=[c for t in scenario.targets.values() for c in t],
-                 hidden_nodes=hidden_nodes,
-                 n_features={t: len(fs) for t, fs in scenario.features.items()},
                  job_ids=job_ids[scenario_name][model],
              )
-             logs_summary.write_csv(summaries_dir / f"{model}.csv")
-             if weights is not None:
+             summary.write_csv(summaries_dir / f"{model}.csv")
+             if model.startswith("MLP"):
+                 weights = summarize_weights(logs_dir)
                  joblib.dump(weights, summaries_dir / f"{model}.weights.pkl")
-
-
- if __name__ == "__main__":
-     main()
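All of the cfg.* constants from the old src.config module are now environment variables, which the CLI loads from ./.env via load_dotenv. A hypothetical .env covering every variable referenced in this diff (placeholder values):

    BELLA_SIMULATIONS_DATA_DIR=/scratch/bella/simulated_data
    BELLA_BEAST_CONFIGS_DIR=/scratch/bella/beast_configs
    BELLA_BEAST_OUTPUT_DIR=/scratch/bella/beast_outputs
    BELLA_SBATCH_LOG_DIR=/scratch/bella/sbatch_logs
    BELLA_RUN_BEAST_CMD=run-beast
    BEAST_LOGS_SUMMARIES_DIR=/scratch/bella/log_summaries

Note that summarize_logs reads BEAST_LOGS_SUMMARIES_DIR without the BELLA_ prefix used by the other variables.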
bella_companion/utils/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from bella_companion.utils.beast import summarize_log, summarize_logs, summarize_weights
+ from bella_companion.utils.slurm import submit_job
+
+ __all__ = ["submit_job", "summarize_log", "summarize_logs", "summarize_weights"]
bella_companion/utils/beast.py ADDED
@@ -0,0 +1,69 @@
+ import os
+ from functools import partial
+ from glob import glob
+ from pathlib import Path
+ from typing import Any
+
+ import arviz as az
+ import numpy as np
+ import polars as pl
+ from joblib import Parallel, delayed
+ from lumiere import read_log_file, read_weights
+ from lumiere.backend.typing import Weights
+ from tqdm import tqdm
+
+ from bella_companion.utils.slurm import get_job_metadata
+
+
+ def summarize_log(
+     log_file: str,
+     target_columns: list[str],
+     burn_in: int | float = 0.1,
+     hdi_prob: float = 0.95,
+     job_id: str | None = None,
+ ) -> dict[str, Any]:
+     log = read_log_file(log_file, burn_in=burn_in)
+     log = log.select(target_columns)
+     summary: dict[str, Any] = {"id": Path(log_file).stem, "n_samples": len(log)}
+     for column in log.columns:
+         summary[f"{column}_median"] = log[column].median()
+         summary[f"{column}_ess"] = az.ess(np.array(log[column]))  # pyright: ignore
+         lower, upper = az.hdi(np.array(log[column]), hdi_prob)  # pyright: ignore
+         summary[f"{column}_lower"] = lower
+         summary[f"{column}_upper"] = upper
+     if job_id is not None:
+         summary.update(get_job_metadata(job_id))
+     return summary
+
+
+ def summarize_logs(
+     logs_dir: Path,
+     target_columns: list[str],
+     burn_in: float = 0.1,
+     hdi_prob: float = 0.95,
+     job_ids: dict[str, str] | None = None,
+ ) -> pl.DataFrame:
+     os.environ["POLARS_MAX_THREADS"] = "1"
+     summaries = Parallel(n_jobs=-1)(
+         delayed(
+             partial(
+                 summarize_log,
+                 target_columns=target_columns,
+                 burn_in=burn_in,
+                 hdi_prob=hdi_prob,
+                 job_id=None if job_ids is None else job_ids[Path(log_file).stem],
+             )
+         )(log_file)
+         for log_file in tqdm(glob(str(logs_dir / "*.log")))
+     )
+     return pl.DataFrame(summaries)
+
+
+ def summarize_weights(
+     logs_dir: Path, n_samples: int = 100, burn_in: float = 0.1
+ ) -> list[dict[str, list[Weights]]]:
+     os.environ["POLARS_MAX_THREADS"] = "1"
+     return Parallel(n_jobs=-1)(
+         delayed(partial(read_weights, burn_in=burn_in, n_samples=n_samples))(log_file)
+         for log_file in tqdm(glob(str(logs_dir / "*.log")))
+     )
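A minimal usage sketch of the new helpers (the logs directory and target column names are hypothetical; the signatures are those defined above):

    from pathlib import Path

    from bella_companion.utils import summarize_logs, summarize_weights

    # One row per *.log file: median, ESS, and 95% HDI bounds for each target column.
    summary = summarize_logs(
        Path("/out/epi_skyline/MLP-16_8"),
        target_columns=["birthRate.1", "birthRate.2"],  # hypothetical column names
    )
    summary.write_csv("MLP-16_8.csv")

    # Posterior weight samples, one entry per log file (MLP models only).
    weights = summarize_weights(Path("/out/epi_skyline/MLP-16_8"))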
bella_companion/utils/slurm.py ADDED
@@ -0,0 +1,58 @@
+ import re
+ import subprocess
+ from pathlib import Path
+
+
+ def submit_job(
+     command: str, log_dir: Path, time: str = "240:00:00", mem_per_cpu: str = "2000"
+ ) -> str | None:
+     if log_dir.exists():
+         print(f"Log directory {log_dir} already exists. Skipping.")
+         return
+     cmd = " ".join(
+         [
+             "sbatch",
+             f"-J {log_dir}",
+             f"-o {log_dir / 'output.out'}",
+             f"-e {log_dir / 'error.err'}",
+             f"--time {time}",
+             f"--mem-per-cpu={mem_per_cpu}",
+             f"--wrap='{command}'",
+         ]
+     )
+     output = subprocess.run(cmd, shell=True, capture_output=True, text=True)
+     job_id = re.search(r"Submitted batch job (\d+)", output.stdout)
+     if job_id is None:
+         raise RuntimeError(
+             f"Failed to submit job.\n"
+             f"Command: {cmd}\n"
+             f"Output: {output.stdout}\n"
+             f"Error: {output.stderr}"
+         )
+     return job_id.group(1)
+
+
+ def get_job_metadata(job_id: str):
+     output = subprocess.run(
+         f"myjobs -j {job_id}", shell=True, capture_output=True, text=True
+     ).stdout
+
+     status = re.search(r"Status\s+:\s+(\w+)", output)
+     if status is None:
+         raise ValueError(f"Failed to get job status for job {job_id}")
+     status = status.group(1)
+
+     wall_clock = re.search(r"Wall-clock\s+:\s+([\d\-:]+)", output)
+     if wall_clock is None:
+         raise ValueError(f"Failed to get wall-clock time for job {job_id}")
+     wall_clock = wall_clock.group(1)
+
+     if "-" in wall_clock:
+         days, wall_clock = wall_clock.split("-")
+         days = int(days)
+     else:
+         days = 0
+     hours, minutes, seconds = map(int, wall_clock.split(":"))
+     total_hours = days * 24 + hours + minutes / 60 + seconds / 3600
+
+     return {"status": status, "total_hours": total_hours}
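submit_job shells out to sbatch with --wrap, so for a log_dir of logs/epi/GLM/0 the generated submission line would be:

    sbatch -J logs/epi/GLM/0 -o logs/epi/GLM/0/output.out -e logs/epi/GLM/0/error.err --time 240:00:00 --mem-per-cpu=2000 --wrap='<command>'

The job id is then parsed from the "Submitted batch job <id>" line on stdout. get_job_metadata parses the output of myjobs, which appears to be a cluster-specific status command rather than a standard Slurm tool.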
bella_companion-0.0.2.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: bella-companion
- Version: 0.0.0
+ Version: 0.0.2
  Summary:
  Author: gabriele-marino
  Author-email: gabmarino.8601@gmail.com
@@ -9,5 +9,7 @@ Classifier: Programming Language :: Python :: 3
  Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
  Classifier: Programming Language :: Python :: 3.12
+ Requires-Dist: arviz (>=0.22.0,<0.23.0)
+ Requires-Dist: bella-lumiere (>=0.0.10,<0.0.11)
  Requires-Dist: dotenv (>=0.9.9,<0.10.0)
- Requires-Dist: phylogenie (>=2.1.21,<3.0.0)
+ Requires-Dist: phylogenie (>=2.1.27,<3.0.0)
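The new requirements track the new code: arviz backs the az.ess and az.hdi calls in utils/beast.py, and bella-lumiere presumably provides the lumiere module (read_log_file, read_weights) imported there.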
bella_companion-0.0.2.dist-info/RECORD CHANGED
@@ -1,5 +1,5 @@
  bella_companion/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- bella_companion/cli.py,sha256=IUODGLiDcxrF40ZjL-SeQtEQhoPgB989KJiXXU0-Pik,576
+ bella_companion/cli.py,sha256=0sPnzGyUGo2OBZ0rj17ZGzMdwNH0o-BXKsYtCJjzGvQ,968
  bella_companion/fbd_empirical/data/body_mass.csv,sha256=-UkKNtm9m3g4PjY3BcfdP6z5nL_I6p9cq6cgZ-bWKI8,30360
  bella_companion/fbd_empirical/data/change_times.csv,sha256=zmc9_z91-XMwKyIoP9v9dVlLcf4MeIHkQiHLjoMriOo,120
  bella_companion/fbd_empirical/data/sampling_change_times.csv,sha256=Gwi9RcMFy89RyvfxKVZ_MoKVRHOZLuwB_3LEaq8asMQ,32
@@ -9,17 +9,17 @@ bella_companion/fbd_empirical/notbooks.ipynb,sha256=O45kmz0lZENRDFbKXEWPsIKATfF5
  bella_companion/fbd_empirical/params.json,sha256=hU23LniClZL_GSBAxIEJUJgMa93AM8zdtFOq6mt3vkI,311
  bella_companion/fbd_empirical/run_beast.py,sha256=2sV2UmxOfWmbueiU6D0p3lueMYiZyIkSKYoblTMrYuA,1935
  bella_companion/fbd_empirical/summarize_logs.py,sha256=O6rhE606Wa98a8b1KKlLPjUOro1pfyqVTLdQksQMG0g,1439
- bella_companion/simulations/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ bella_companion/simulations/__init__.py,sha256=i6Fe7l5sUJY9hPxdg6L_FVhwbSPhNxQNMb-m33JlfxI,258
  bella_companion/simulations/features.py,sha256=DZOBpJGlQ0UinqUZYbEtoemZ2eQGVLV_i-DfpW31qJI,104
  bella_companion/simulations/figures/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- bella_companion/simulations/figures/epi_explainations.py,sha256=RL9fyjl0a_zPhrGdUXqbMMu6471su8B-O6LyuFlHknw,2816
+ bella_companion/simulations/figures/epi_explainations.py,sha256=omiJgyIY-I6zcJAcyOF7GJ2pba6pMZySLkWy7OFrjFY,3093
  bella_companion/simulations/figures/epi_predictions.py,sha256=4yXwOBKxUv4kgZdI9zAMEhZ0QCNKZdkAafRQ1RTeaWg,1835
  bella_companion/simulations/figures/fbd_explainations.py,sha256=9Uj7yttpn_TH5HqycW8R-Nlky9A9aFXDXRpXQuT1L4s,3037
  bella_companion/simulations/figures/fbd_predictions.py,sha256=jdXYCLledZEWoPCIuTLhHEPMdeG6YXvf5xZnEOslv-U,2119
  bella_companion/simulations/figures/scenarios.py,sha256=vyybn3Qhfq96N8tvW0wSzpFoHHP8EIc8dkOz63o_Atw,2492
  bella_companion/simulations/figures/utils.py,sha256=sY8wFBg02fv5ugpJ80EqQishD_HEdLwhqsw2LfM7wEo,8539
  bella_companion/simulations/generate_data.py,sha256=H8OV4ZlTGZB-jXaROTPmOsK3UxRiU-GrX40l-shliw8,728
- bella_companion/simulations/run_beast.py,sha256=NBGfb5ZvtrLX5sA6Ku4SNHqmPGoEXFj5DmV54ZR4zVs,3411
+ bella_companion/simulations/run_beast.py,sha256=xOuwE0w4IbOqqCSym6kHsAEhfGT2mWdA-jmUZuviMbc,3121
  bella_companion/simulations/scenarios/__init__.py,sha256=3Kl1lKcFpfb3vLX64DmSW4XCF5kXU1ZoHtstFH-ZIzU,876
  bella_companion/simulations/scenarios/common.py,sha256=_ddaSuTvEVdttGkXB4HPc2B7IB1F_GBOCW3cVOPZ-ZM,807
  bella_companion/simulations/scenarios/epi_multitype.py,sha256=GWGIiqvYwX_FrT_3RXkZKYGDht9nZ7ceHRBKUvXDPnA,2432
@@ -27,8 +27,11 @@ bella_companion/simulations/scenarios/epi_skyline.py,sha256=JqnOVATECxBUqEbkR5lB
  bella_companion/simulations/scenarios/fbd_2traits.py,sha256=sCtdWyV6GQQOIhnL9Dd8NIbAR-StTwUTD9-b_BalmFQ,3552
  bella_companion/simulations/scenarios/fbd_no_traits.py,sha256=R6CH0fVeQg-Iesl39pq2uY8ICVEO4VZbvUVUCGwauJU,2520
  bella_companion/simulations/scenarios/scenario.py,sha256=_FRWAyOFbw94lAzd3zCD-1ek4TrssoiXfXRQPShLiIA,620
- bella_companion/simulations/summarize_logs.py,sha256=TXaO9cjzl5O1u0fPZpRl-9txzoN-p-fkhoAHoRXTfm8,1433
- bella_companion/utils.py,sha256=26cF3oVBbsahYPO9rcK69l43ybg5AjS12IyfucgyVIM,5666
- bella_companion-0.0.0.dist-info/METADATA,sha256=j55dzUiDk-NtHXDt3bAQ3MYH3fkMDKNmwZ4OD71TAm4,446
- bella_companion-0.0.0.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
- bella_companion-0.0.0.dist-info/RECORD,,
+ bella_companion/simulations/summarize_logs.py,sha256=N4W41IbTeJDbJyYZ5HCGyMPz6hKTkWdnbMfowqlD3J0,1264
+ bella_companion/utils/__init__.py,sha256=_5tLPH_3GHtimNcH0Yd9Z6yIM3WkWkNApNGLzFnF6nY,222
+ bella_companion/utils/beast.py,sha256=RG-iSEFuL92K6yxUV2nxdmcVqfrEiPhaYTmReW4ZoWk,2189
+ bella_companion/utils/slurm.py,sha256=v5DaG7YHVyK8KRFptgGDC6I8jxEhyJuMVK9N08pZSAI,1812
+ bella_companion-0.0.2.dist-info/METADATA,sha256=3jBu7TyB8P3S1YO9CAnKHKdpD4IcFE4loKhz52xmZeQ,534
+ bella_companion-0.0.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+ bella_companion-0.0.2.dist-info/entry_points.txt,sha256=rSeKoAhmjnQqAYFcXBv0gAM2ViJfJe0D8_dD-fWrXeg,50
+ bella_companion-0.0.2.dist-info/RECORD,,
bella_companion-0.0.2.dist-info/entry_points.txt ADDED
@@ -0,0 +1,3 @@
+ [console_scripts]
+ bella=bella_companion.cli:main
+
bella_companion/utils.py DELETED
@@ -1,164 +0,0 @@
- import os
- import re
- import subprocess
- from glob import glob
- from pathlib import Path
- from typing import Any
-
- import arviz as az
- import matplotlib.pyplot as plt
- import numpy as np
- import polars as pl
- from joblib import Parallel, delayed
- from lumiere.backend.typings import Weights
- from tqdm import tqdm
-
-
- def run_sbatch(
-     command: str,
-     log_dir: Path,
-     time: str = "240:00:00",
-     mem_per_cpu: str = "2000",
-     overwrite: bool = False,
- ) -> str | None:
-     if not overwrite and log_dir.exists():
-         print(f"Log directory {log_dir} already exists. Skipping.")
-         return
-     cmd = " ".join(
-         [
-             "sbatch",
-             f"-J {log_dir}",
-             f"-o {log_dir / 'output.out'}",
-             f"-e {log_dir / 'error.err'}",
-             f"--time {time}",
-             f"--mem-per-cpu={mem_per_cpu}",
-             f"--wrap='{command}'",
-         ]
-     )
-     output = subprocess.run(cmd, shell=True, capture_output=True, text=True)
-     job_id = re.search(r"Submitted batch job (\d+)", output.stdout)
-     if job_id is None:
-         raise RuntimeError(
-             f"Failed to submit job.\nCommand: {cmd}\nOutput: {output.stdout}\nError: {output.stderr}"
-         )
-     return job_id.group(1)
-
-
- def get_job_metadata(job_id: str):
-     output = subprocess.run(
-         f"myjobs -j {job_id}", shell=True, capture_output=True, text=True
-     ).stdout
-
-     status = re.search(r"Status\s+:\s+(\w+)", output)
-     if status is None:
-         raise RuntimeError(f"Failed to get job status for job {job_id}")
-     status = status.group(1)
-
-     wall_clock = re.search(r"Wall-clock\s+:\s+([\d\-:]+)", output)
-     if wall_clock is None:
-         raise RuntimeError(f"Failed to get wall-clock time for job {job_id}")
-     wall_clock = wall_clock.group(1)
-
-     if "-" in wall_clock:
-         days, wall_clock = wall_clock.split("-")
-         days = int(days)
-     else:
-         days = 0
-     hours, minutes, seconds = map(int, wall_clock.split(":"))
-     total_hours = days * 24 + hours + minutes / 60 + seconds / 3600
-
-     return {"status": status, "total_hours": total_hours}
-
-
- def summarize_log(
-     log_file: str,
-     target_columns: list[str],
-     burn_in: float = 0.1,
-     hdi_prob: float = 0.95,
-     hidden_nodes: list[int] | None = None,
-     n_weights_samples: int = 100,
-     n_features: dict[str, int] | None = None,
-     job_id: str | None = None,
- ) -> tuple[dict[str, Any], dict[str, list[Weights]] | None]:
-     df = pl.read_csv(log_file, separator="\t", comment_prefix="#")
-     df = df.filter(pl.col("Sample") > burn_in * len(df))
-     targets_df = df.select(target_columns)
-     summary: dict[str, Any] = {"n_samples": len(df)}
-     for column in targets_df.columns:
-         summary[f"{column}_median"] = targets_df[column].median()
-         summary[f"{column}_ess"] = az.ess(  # pyright: ignore[reportUnknownMemberType]
-             np.array(targets_df[column])
-         )
-         lower, upper = az.hdi(  # pyright: ignore[reportUnknownMemberType]
-             np.array(targets_df[column]), hdi_prob=hdi_prob
-         )
-         summary[f"{column}_lower"] = lower
-         summary[f"{column}_upper"] = upper
-     if job_id is not None:
-         summary.update(get_job_metadata(job_id))
-     if hidden_nodes is not None:
-         if n_features is None:
-             raise ValueError("`n_features` must be provided to summarize log weights.")
-         weights: dict[str, list[Weights]] = {}
-         for target, n in n_features.items():
-             nodes = [n, *hidden_nodes, 1]
-             layer_weights = [
-                 np.array(
-                     df.tail(n_weights_samples).select(
-                         c for c in df.columns if c.startswith(f"{target}W.{i}")
-                     )
-                 ).reshape(-1, n_inputs + 1, n_outputs)
-                 for i, (n_inputs, n_outputs) in enumerate(zip(nodes[:-1], nodes[1:]))
-             ]
-             weights[target] = [
-                 list(sample_weights) for sample_weights in zip(*layer_weights)
-             ]
-         return summary, weights
-     return summary, None
-
-
- def summarize_logs(
-     logs_dir: Path,
-     target_columns: list[str],
-     burn_in: float = 0.1,
-     hdi_prob: float = 0.95,
-     hidden_nodes: list[int] | None = None,
-     n_weights_samples: int = 100,
-     n_features: dict[str, int] | None = None,
-     job_ids: dict[str, str] | None = None,
- ) -> tuple[pl.DataFrame, dict[str, list[list[Weights]]] | None]:
-     def _get_log_summary(
-         log_file: str,
-     ) -> tuple[dict[str, Any], dict[str, list[Weights]] | None]:
-         log_id = Path(log_file).stem
-         summary, weights = summarize_log(
-             log_file=log_file,
-             target_columns=target_columns,
-             burn_in=burn_in,
-             hdi_prob=hdi_prob,
-             hidden_nodes=hidden_nodes,
-             n_weights_samples=n_weights_samples,
-             n_features=n_features,
-             job_id=job_ids[log_id] if job_ids is not None else None,
-         )
-         return {"id": log_id, **summary}, weights
-
-     os.environ["POLARS_MAX_THREADS"] = "1"
-     summaries = Parallel(n_jobs=-1)(
-         delayed(_get_log_summary)(log_file)
-         for log_file in tqdm(glob(str(logs_dir / "*.log")))
-     )
-     data, weights = zip(*summaries)
-     if any(w is not None for w in weights):
-         assert n_features is not None
-         return pl.DataFrame(data), {t: [w[t] for w in weights] for t in n_features}
-     return pl.DataFrame(data), None
-
-
- def set_plt_rcparams():
-     plt.rcParams["pdf.fonttype"] = 42
-     plt.rcParams["xtick.labelsize"] = 14
-     plt.rcParams["ytick.labelsize"] = 14
-     plt.rcParams["font.size"] = 14
-     plt.rcParams["figure.constrained_layout.use"] = True
-     plt.rcParams["lines.linewidth"] = 3