bella_companion-0.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of bella-companion might be problematic.
- bella_companion/__init__.py +0 -0
- bella_companion/cli.py +24 -0
- bella_companion/fbd_empirical/data/body_mass.csv +1378 -0
- bella_companion/fbd_empirical/data/change_times.csv +22 -0
- bella_companion/fbd_empirical/data/sampling_change_times.csv +6 -0
- bella_companion/fbd_empirical/data/trees.nwk +100 -0
- bella_companion/fbd_empirical/figure.py +37 -0
- bella_companion/fbd_empirical/notbooks.ipynb +359 -0
- bella_companion/fbd_empirical/params.json +11 -0
- bella_companion/fbd_empirical/run_beast.py +54 -0
- bella_companion/fbd_empirical/summarize_logs.py +50 -0
- bella_companion/simulations/__init__.py +0 -0
- bella_companion/simulations/features.py +7 -0
- bella_companion/simulations/figures/__init__.py +0 -0
- bella_companion/simulations/figures/epi_explainations.py +101 -0
- bella_companion/simulations/figures/epi_predictions.py +58 -0
- bella_companion/simulations/figures/fbd_explainations.py +99 -0
- bella_companion/simulations/figures/fbd_predictions.py +66 -0
- bella_companion/simulations/figures/scenarios.py +87 -0
- bella_companion/simulations/figures/utils.py +250 -0
- bella_companion/simulations/generate_data.py +25 -0
- bella_companion/simulations/run_beast.py +92 -0
- bella_companion/simulations/scenarios/__init__.py +20 -0
- bella_companion/simulations/scenarios/common.py +29 -0
- bella_companion/simulations/scenarios/epi_multitype.py +68 -0
- bella_companion/simulations/scenarios/epi_skyline.py +65 -0
- bella_companion/simulations/scenarios/fbd_2traits.py +101 -0
- bella_companion/simulations/scenarios/fbd_no_traits.py +71 -0
- bella_companion/simulations/scenarios/scenario.py +26 -0
- bella_companion/simulations/summarize_logs.py +39 -0
- bella_companion/utils.py +164 -0
- bella_companion-0.0.0.dist-info/METADATA +13 -0
- bella_companion-0.0.0.dist-info/RECORD +34 -0
- bella_companion-0.0.0.dist-info/WHEEL +4 -0
bella_companion/simulations/figures/utils.py
@@ -0,0 +1,250 @@
import os
from itertools import product
from typing import Any

import matplotlib.pyplot as plt
import numpy as np
import polars as pl
import seaborn as sns
from joblib import Parallel, delayed
from lumiere.backend import (
    ActivationFunction,
    get_partial_dependence_values,
    get_shap_features_importance,
    sigmoid,
)
from lumiere.backend.typings import Weights
from tqdm import tqdm

from bella_companion.simulations.features import Feature


def _set_xticks(n: int, reverse: bool = False):
    xticks_labels = range(n)
    if reverse:
        xticks_labels = reversed(xticks_labels)
    plt.xticks(ticks=range(n), labels=list(map(str, xticks_labels)))


def step(
    x: list[float],
    reverse_xticks: bool = False,
    **kwargs: Any,
):
    # Duplicate the first value so plt.step draws a segment over the first bin.
    data = x.copy()
    data.insert(0, data[0])
    plt.step(list(range(len(data))), data, **kwargs)
    _set_xticks(len(data), reverse_xticks)
    plt.xlabel("Time bin")


def _count_time_bins(true_values: dict[str, list[float]]) -> int:
    assert (
        len({len(true_value) for true_value in true_values.values()}) == 1
    ), "All targets must have the same number of change times."
    return len(next(iter(true_values.values())))


def plot_maes_per_time_bin(
    logs_summaries: dict[str, pl.DataFrame],
    true_values: dict[str, list[float]],
    output_filepath: str,
    reverse_xticks: bool = False,
):
    def _mae(target: str, i: int) -> pl.Expr:
        return (pl.col(f"{target}i{i}_median") - true_values[target][i]).abs()

    n_time_bins = _count_time_bins(true_values)
    df = pl.concat(
        logs_summaries[model]
        .select(
            pl.mean_horizontal([_mae(target, i) for target in true_values]).alias("MAE")
        )
        .with_columns(pl.lit(i).alias("Time bin"), pl.lit(model).alias("Model"))
        for i in range(n_time_bins)
        for model in logs_summaries
    )
    sns.violinplot(
        x="Time bin",
        y="MAE",
        hue="Model",
        data=df,
        inner=None,
        cut=0,
        density_norm="width",
        legend=False,
    )
    _set_xticks(n_time_bins, reverse_xticks)
    plt.savefig(output_filepath)
    plt.close()


def plot_coverage_per_time_bin(
    logs_summaries: dict[str, pl.DataFrame],
    true_values: dict[str, list[float]],
    output_filepath: str,
    reverse_xticks: bool = False,
):
    def _coverage(model: str, target: str, i: int) -> float:
        lower_bound = logs_summaries[model][f"{target}i{i}_lower"]
        upper_bound = logs_summaries[model][f"{target}i{i}_upper"]
        true_value = true_values[target][i]
        N = len(logs_summaries[model])
        return ((lower_bound <= true_value) & (true_value <= upper_bound)).sum() / N

    n_time_bins = _count_time_bins(true_values)
    for model in logs_summaries:
        avg_coverage_by_time_bin = [
            np.mean([_coverage(model, target, i) for target in true_values])
            for i in range(n_time_bins)
        ]
        plt.plot(avg_coverage_by_time_bin, marker="o")

    _set_xticks(n_time_bins, reverse_xticks)
    plt.xlabel("Time bin")
    plt.ylabel("Coverage")
    plt.ylim((0, 1.05))
    plt.savefig(output_filepath)
    plt.close()


def plot_partial_dependencies(
    weights: list[list[Weights]],  # shape: (n_mcmcs, n_weights_samples, ...)
    features: dict[str, Feature],
    output_dir: str,
    hidden_activation: ActivationFunction = sigmoid,
    output_activation: ActivationFunction = sigmoid,
):
    os.makedirs(output_dir, exist_ok=True)
    features_grid = [feature.grid for feature in features.values()]

    def _get_median_partial_dependence_values(
        weights: list[Weights],
    ) -> list[list[float]]:
        pdvalues_distribution = [
            get_partial_dependence_values(
                weights=w,
                features_grid=features_grid,
                hidden_activation=hidden_activation,
                output_activation=output_activation,
            )
            for w in weights
        ]
        return [
            np.median(
                [pdvalues[feature_idx] for pdvalues in pdvalues_distribution], axis=0
            ).tolist()
            for feature_idx in range(len(features))
        ]

    # Stream results so the tqdm progress bar below advances as jobs finish.
    jobs = Parallel(n_jobs=-1, return_as="generator")(
        delayed(_get_median_partial_dependence_values)(w) for w in weights
    )
    pdvalues = [
        job for job in tqdm(jobs, total=len(weights), desc="Evaluating PDPs")
    ]  # shape: (n_mcmcs, n_features, n_grid_points)
    pdvalues = [
        np.array(mcmc_pds).T for mcmc_pds in zip(*pdvalues)
    ]  # shape: (n_features, n_grid_points, n_mcmcs)

    for feature_idx, (feature_name, feature) in enumerate(features.items()):
        color = "red" if feature.is_relevant else "gray"
        feature_pdvalues = pdvalues[feature_idx]  # shape: (n_grid_points, n_mcmcs)
        if not feature.is_categorical:
            median = np.median(feature_pdvalues, axis=1)
            lower = np.percentile(feature_pdvalues, 2.5, axis=1)
            upper = np.percentile(feature_pdvalues, 97.5, axis=1)
            plt.fill_between(feature.grid, lower, upper, alpha=0.25, color=color)
            for mcmc_pds in feature_pdvalues.T:
                plt.plot(
                    feature.grid,
                    mcmc_pds,
                    color=color,
                    alpha=0.2,
                    linewidth=1,
                )
            plt.plot(feature.grid, median, color=color, label=feature_name)
    plt.xlabel("Feature value")
    plt.ylabel("MLP Output")
    plt.legend()
    plt.savefig(os.path.join(output_dir, "PDPs-continuous.svg"))
    plt.close()

    plot_data = []
    grid_labels = []
    list_labels = []
    for feature_idx, (feature_name, feature) in enumerate(features.items()):
        feature_pdvalues = pdvalues[feature_idx]  # shape: (n_grid_points, n_mcmcs)
        if feature.is_categorical:
            for i, grid_point in enumerate(feature.grid):
                plot_data.extend(feature_pdvalues[i])
                grid_labels.extend([grid_point] * len(feature_pdvalues[i]))
                list_labels.extend([feature_name] * len(feature_pdvalues[i]))
    if not any(feature.is_categorical for feature in features.values()):
        return
    sns.violinplot(
        x=grid_labels,
        y=plot_data,
        hue=list_labels,
        split=False,
        cut=0,
        palette={
            feature_name: "red" if feature.is_relevant else "gray"
            for feature_name, feature in features.items()
            if feature.is_categorical
        },
    )
    plt.xlabel("Feature value")
    plt.ylabel("MLP Output")
    plt.legend()
    plt.savefig(os.path.join(output_dir, "PDPs-categorical.svg"))
    plt.close()


def plot_shap_features_importance(
    weights: list[list[Weights]],  # shape: (n_mcmcs, n_weights_samples, ...)
    features: dict[str, Feature],
    output_file: str,
    hidden_activation: ActivationFunction = sigmoid,
    output_activation: ActivationFunction = sigmoid,
):
    features_grid = [feature.grid for feature in features.values()]
    inputs = list(product(*features_grid))

    def _get_median_shap_features_importance(
        weights: list[Weights],
    ) -> list[float]:
        features_importance = np.array(
            [
                get_shap_features_importance(
                    weights=w,
                    inputs=inputs,
                    hidden_activation=hidden_activation,
                    output_activation=output_activation,
                )
                for w in weights
            ]
        )  # shape: (n_weights_samples, n_features)
        return np.median(features_importance, axis=0).tolist()  # shape: (n_features,)

    jobs = Parallel(n_jobs=-1, return_as="generator_unordered")(
        delayed(_get_median_shap_features_importance)(w) for w in weights
    )
    features_importance_distribution = np.array(
        [job for job in tqdm(jobs, total=len(weights), desc="Evaluating SHAPs")]
    )  # shape: (n_mcmcs, n_features)
    # Normalize so each MCMC run's importances sum to one.
    features_importance_distribution /= features_importance_distribution.sum(
        axis=1, keepdims=True
    )

    for i, (feature_name, feature) in enumerate(features.items()):
        sns.violinplot(
            y=features_importance_distribution[:, i],
            x=[feature_name] * features_importance_distribution.shape[0],
            cut=0,
            color="red" if feature.is_relevant else "gray",
        )
    plt.xlabel("Feature")
    plt.ylabel("Importance")
    plt.savefig(output_file)
    plt.close()
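For orientation, a minimal usage sketch for the plotting helpers above (not part of the package; the `{target}i{i}_median` column naming and seaborn's acceptance of polars frames are assumptions inferred from the code):

import polars as pl

from bella_companion.simulations.figures.utils import plot_maes_per_time_bin

# Two mock per-model log summaries, one row per simulated tree.
true_values = {"reproductionNumberSP": [1.2, 1.1]}
logs_summaries = {
    model: pl.DataFrame(
        {
            "reproductionNumberSPi0_median": [1.15, 1.30, 1.22],
            "reproductionNumberSPi1_median": [1.00, 1.18, 1.09],
        }
    )
    for model in ["GLM", "MLP-3_2"]
}
plot_maes_per_time_bin(logs_summaries, true_values, "maes.svg")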
bella_companion/simulations/generate_data.py
@@ -0,0 +1,25 @@
import os

from phylogenie import generate_trees

from bella_companion.simulations.scenarios import SCENARIOS, ScenarioType

N_TREES = 100
MIN_TIPS = 200
MAX_TIPS = 500


def generate_data():
    base_output_dir = os.environ["BELLA_SIMULATIONS_DATA_DIR"]
    for scenario_name, scenario in SCENARIOS.items():
        generate_trees(
            output_dir=os.path.join(base_output_dir, scenario_name),
            n_trees=N_TREES,
            events=scenario.events,
            init_state=scenario.init_state,
            # Extant lineages are sampled at present only in FBD scenarios.
            sampling_probability_at_present=int(scenario.type == ScenarioType.FBD),
            max_time=scenario.max_time,
            min_tips=MIN_TIPS,
            max_tips=MAX_TIPS,
            seed=42,
        )
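A quick invocation sketch (the path is hypothetical; generate_data() reads its output root from the BELLA_SIMULATIONS_DATA_DIR environment variable):

import os

os.environ["BELLA_SIMULATIONS_DATA_DIR"] = "/tmp/bella-simulations"  # assumed path

from bella_companion.simulations.generate_data import generate_data

generate_data()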
bella_companion/simulations/run_beast.py
@@ -0,0 +1,92 @@
import json
import os
from collections import defaultdict
from glob import glob
from pathlib import Path

from numpy.random import default_rng
from phylogenie import Tree, load_newick
from phylogenie.utils import get_node_depths
from tqdm import tqdm

import config as cfg
from bella_companion.simulations.scenarios import SCENARIOS, ScenarioType
from bella_companion.utils import run_sbatch


def main():
    rng = default_rng(42)
    job_ids = {}
    for scenario_name, scenario in SCENARIOS.items():
        job_ids[scenario_name] = defaultdict(dict)
        data_dir = cfg.SIMULATED_DATA_DIR / scenario_name
        inference_configs_dir = (
            scenario_name.split("_")[0] if "_" in scenario_name else scenario_name
        )
        for tree_file in tqdm(
            glob(str(data_dir / "*.nwk")),
            desc=f"Submitting BEAST2 jobs for {scenario_name}",
        ):
            tree_id = Path(tree_file).stem
            for model in ["Nonparametric", "GLM"] + [
                f"MLP-{hidden_nodes}" for hidden_nodes in ["3_2", "16_8", "32_16"]
            ]:
                outputs_dir = cfg.BEAST_OUTPUTS_DIR / scenario_name / model
                os.makedirs(outputs_dir, exist_ok=True)
                beast_args = [
                    f"-D treeFile={tree_file},treeID={tree_id}",
                    f"-prefix {outputs_dir}{os.sep}",
                ]
                beast_args.extend(
                    [
                        f'-D {key}="{value}"'
                        for key, value in scenario.beast_args.items()
                    ]
                )
                beast_args.append(
                    f'-D randomPredictor="{" ".join(map(str, scenario.get_random_predictor(rng)))}"'
                )
                if scenario.type == ScenarioType.EPI:
                    tree = load_newick(tree_file)
                    assert isinstance(tree, Tree)
                    beast_args.append(
                        f"-D lastSampleTime={max(get_node_depths(tree).values())}"
                    )

                if model in ["Nonparametric", "GLM"]:
                    command = " ".join(
                        [
                            cfg.RUN_BEAST,
                            *beast_args,
                            str(
                                cfg.BEAST_CONFIGS_DIR
                                / inference_configs_dir
                                / f"{model}.xml"
                            ),
                        ]
                    )
                else:
                    nodes = model.split("-")[1].split("_")
                    command = " ".join(
                        [
                            cfg.RUN_BEAST,
                            *beast_args,
                            f'-D nodes="{" ".join(nodes)}"',
                            str(
                                cfg.BEAST_CONFIGS_DIR
                                / inference_configs_dir
                                / "MLP.xml"
                            ),
                        ]
                    )

                job_ids[scenario_name][model][tree_id] = run_sbatch(
                    command, cfg.SBATCH_LOGS_DIR / scenario_name / model / tree_id
                )

    with open(cfg.BEAST_OUTPUTS_DIR / "simulations_job_ids.json", "w") as f:
        json.dump(job_ids, f)


if __name__ == "__main__":
    main()
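run_sbatch comes from bella_companion/utils.py, whose hunk is not shown in this section; here is a minimal sketch consistent with its call sites above (the Slurm flags and output parsing are assumptions, not the package's actual implementation):

import os
import subprocess
from pathlib import Path


def run_sbatch(command: str, logs_dir: Path) -> str:
    """Submit `command` via Slurm's sbatch and return the job id."""
    os.makedirs(logs_dir, exist_ok=True)
    result = subprocess.run(
        [
            "sbatch",
            f"--output={logs_dir}/%j.out",
            f"--error={logs_dir}/%j.err",
            f"--wrap={command}",
        ],
        capture_output=True,
        text=True,
        check=True,
    )
    # sbatch prints "Submitted batch job <id>"; keep the trailing id.
    return result.stdout.strip().split()[-1]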
bella_companion/simulations/scenarios/__init__.py
@@ -0,0 +1,20 @@
from bella_companion.simulations.scenarios.epi_multitype import SCENARIO as EPI_MULTITYPE_SCENARIO
from bella_companion.simulations.scenarios.epi_skyline import SCENARIOS as EPI_SKYLINE_SCENARIOS
from bella_companion.simulations.scenarios.fbd_2traits import SCENARIO as FBD_2TRAITS_SCENARIO
from bella_companion.simulations.scenarios.fbd_no_traits import SCENARIOS as FBD_NO_TRAITS_SCENARIOS
from bella_companion.simulations.scenarios.scenario import Scenario, ScenarioType

SCENARIOS = {
    **{
        f"epi-skyline_{i}": scenario
        for i, scenario in enumerate(EPI_SKYLINE_SCENARIOS, start=1)
    },
    "epi-multitype": EPI_MULTITYPE_SCENARIO,
    **{
        f"fbd-no-traits_{i}": scenario
        for i, scenario in enumerate(FBD_NO_TRAITS_SCENARIOS, start=1)
    },
    "fbd-2traits": FBD_2TRAITS_SCENARIO,
}

__all__ = ["SCENARIOS", "Scenario", "ScenarioType"]
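For reference, this registry resolves to the keys epi-skyline_1 through epi-skyline_3 (epi_skyline.py, shown further below, defines three scenarios), epi-multitype, one fbd-no-traits_{i} key per scenario in fbd_no_traits.py (whose hunk is not shown in this section), and fbd-2traits.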
bella_companion/simulations/scenarios/common.py
@@ -0,0 +1,29 @@
from typing import Any

import numpy as np
from numpy.random import Generator

EPI_MAX_TIME = 250
EPI_SAMPLING_PROPORTION = 0.15
BECOME_UNINFECTIOUS_RATE = 0.07

FBD_MAX_TIME = 35
FBD_SAMPLING_RATE = 0.2
FBD_RATE_UPPER = 2


def get_start_type_prior_probabilities(types: list[str], init_type: str):
    start_type_prior_probabilities = ["0"] * len(types)
    start_type_prior_probabilities[types.index(init_type)] = "1"
    return " ".join(start_type_prior_probabilities)


def get_random_time_series_predictor(rng: Generator, n_time_bins: int) -> list[float]:
    # A Gaussian random walk: autocorrelated but uninformative by construction.
    return np.cumsum(rng.normal(size=n_time_bins)).tolist()


def get_prior_params(target: str, upper: float, n: int) -> dict[str, Any]:
    return {
        f"{target}Upper": upper,
        f"{target}Init": " ".join([str(upper / 2)] * n),
    }
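For concreteness, the two string-building helpers behave like this (a quick sketch):

>>> get_start_type_prior_probabilities(["A", "B", "C"], "B")
'0 1 0'
>>> get_prior_params("reproductionNumber", 5, 3)
{'reproductionNumberUpper': 5, 'reproductionNumberInit': '2.5 2.5 2.5'}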
bella_companion/simulations/scenarios/epi_multitype.py
@@ -0,0 +1,68 @@
import numpy as np
from numpy.random import Generator
from phylogenie import get_epidemiological_events

from bella_companion.simulations.features import Feature
from bella_companion.simulations.scenarios.common import (
    BECOME_UNINFECTIOUS_RATE,
    EPI_MAX_TIME,
    EPI_SAMPLING_PROPORTION,
    get_prior_params,
    get_start_type_prior_probabilities,
)
from bella_companion.simulations.scenarios.scenario import Scenario, ScenarioType


def _get_random_predictor(rng: Generator) -> list[float]:
    return rng.uniform(-1, 1, N_TYPE_PAIRS).tolist()


TYPES = ["A", "B", "C", "D", "E"]
_REPRODUCTION_NUMBERS = [0.8, 1.0, 1.2, 1.4, 1.6]
_INIT_TYPE = "C"
N_TYPES = len(TYPES)
N_TYPE_PAIRS = N_TYPES * (N_TYPES - 1)
MIGRATION_PREDICTOR = np.random.default_rng(42).uniform(-1, 1, (N_TYPES, N_TYPES - 1))
_MIGRATION_SIGMOID_AMPLITUDE = 0.04
_MIGRATION_SIGMOID_SCALE = -8
# Map the predictor through a logistic curve so rates lie in (0, 0.04),
# increasing with the predictor value.
MIGRATION_RATES = _MIGRATION_SIGMOID_AMPLITUDE / (
    1 + np.exp(_MIGRATION_SIGMOID_SCALE * MIGRATION_PREDICTOR)
)
MIGRATION_RATE_UPPER = 0.05

SCENARIO = Scenario(
    type=ScenarioType.EPI,
    max_time=EPI_MAX_TIME,
    init_state=_INIT_TYPE,
    events=get_epidemiological_events(
        states=TYPES,
        sampling_proportions=EPI_SAMPLING_PROPORTION,
        reproduction_numbers=_REPRODUCTION_NUMBERS,
        become_uninfectious_rates=BECOME_UNINFECTIOUS_RATE,
        migration_rates=MIGRATION_RATES.tolist(),
    ),
    get_random_predictor=_get_random_predictor,
    beast_args={
        "types": ",".join(TYPES),
        "startTypePriorProbs": get_start_type_prior_probabilities(TYPES, _INIT_TYPE),
        "processLength": EPI_MAX_TIME,
        **get_prior_params("migrationRate", MIGRATION_RATE_UPPER, N_TYPE_PAIRS),
        "reproductionNumber": " ".join(map(str, _REPRODUCTION_NUMBERS)),
        "becomeUninfectiousRate": BECOME_UNINFECTIOUS_RATE,
        "samplingProportion": EPI_SAMPLING_PROPORTION,
        "migrationPredictor": " ".join(map(str, MIGRATION_PREDICTOR.flatten())),
    },
    targets={
        "migrationRate": {
            f"migrationRateSP{t1}_to_{t2}": MIGRATION_RATES[i, j]
            for i, t1 in enumerate(TYPES)
            for j, t2 in enumerate([t for t in TYPES if t != t1])
        }
    },
    features={
        "migrationRate": {
            "migrationPredictor": Feature(is_binary=False, is_relevant=True),
            "randomPredictor": Feature(is_binary=False, is_relevant=False),
        }
    },
)
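A small numeric check (not part of the package) of the logistic mapping above: since the scale is negative, rates grow with the predictor and saturate at the 0.04 amplitude, safely below the 0.05 prior upper bound passed to BEAST:

import numpy as np

x = np.array([-1.0, 0.0, 1.0])
print(0.04 / (1 + np.exp(-8 * x)))
# approx [1.3e-05, 2.0e-02, 4.0e-02]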
bella_companion/simulations/scenarios/epi_skyline.py
@@ -0,0 +1,65 @@
from functools import partial

import numpy as np
from phylogenie import SkylineParameter, get_epidemiological_events

from bella_companion.simulations.features import Feature
from bella_companion.simulations.scenarios.common import (
    BECOME_UNINFECTIOUS_RATE,
    EPI_MAX_TIME,
    EPI_SAMPLING_PROPORTION,
    get_prior_params,
    get_random_time_series_predictor,
)
from bella_companion.simulations.scenarios.scenario import Scenario, ScenarioType


def _get_scenario(reproduction_number: list[float]) -> Scenario:
    n_time_bins = len(reproduction_number)
    change_times = np.linspace(0, EPI_MAX_TIME, n_time_bins + 1)[1:-1].tolist()
    return Scenario(
        type=ScenarioType.EPI,
        max_time=EPI_MAX_TIME,
        events=get_epidemiological_events(
            states=["X"],
            sampling_proportions=EPI_SAMPLING_PROPORTION,
            reproduction_numbers=SkylineParameter(reproduction_number, change_times),
            become_uninfectious_rates=BECOME_UNINFECTIOUS_RATE,
        ),
        get_random_predictor=partial(
            get_random_time_series_predictor, n_time_bins=n_time_bins
        ),
        beast_args={
            "processLength": EPI_MAX_TIME,
            "changeTimes": " ".join(map(str, change_times)),
            **get_prior_params(
                "reproductionNumber", REPRODUCTION_NUMBER_UPPER, n_time_bins
            ),
            "becomeUninfectiousRate": BECOME_UNINFECTIOUS_RATE,
            "samplingProportion": EPI_SAMPLING_PROPORTION,
            "timePredictor": " ".join(map(str, np.linspace(0, 1, n_time_bins))),
        },
        targets={
            "reproductionNumber": {
                f"reproductionNumberSPi{i}": r
                for i, r in enumerate(reproduction_number)
            }
        },
        features={
            "reproductionNumber": {
                "timePredictor": Feature(
                    is_binary=False, is_relevant=len(set(reproduction_number)) > 1
                ),
                "randomPredictor": Feature(is_binary=False, is_relevant=False),
            }
        },
    )


REPRODUCTION_NUMBERS: list[list[float]] = [
    [1.2] * 10,
    np.linspace(1.5, 1.0, 10).tolist(),
    np.linspace(1.2, 1.5, 5).tolist() + np.linspace(1.5, 1.0, 5).tolist(),
]
REPRODUCTION_NUMBER_UPPER = 5
SCENARIOS = [_get_scenario(r) for r in REPRODUCTION_NUMBERS]
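The three skylines above are: constant (R = 1.2, where the time predictor is marked irrelevant since len(set(...)) == 1 is false... i.e. the feature carries no signal), a linear decline from 1.5 to 1.0, and a rise from 1.2 to 1.5 followed by a fall to 1.0. Each uses ten equal bins over [0, 250], which means nine interior change times; a quick check:

import numpy as np

print(np.linspace(0, 250, 11)[1:-1].tolist())
# [25.0, 50.0, 75.0, 100.0, 125.0, 150.0, 175.0, 200.0, 225.0]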
bella_companion/simulations/scenarios/fbd_2traits.py
@@ -0,0 +1,101 @@
import numpy as np
from numpy.random import Generator
from phylogenie import SkylineVector, get_canonical_events

from bella_companion.simulations.features import Feature
from bella_companion.simulations.scenarios.common import (
    FBD_MAX_TIME,
    FBD_RATE_UPPER,
    FBD_SAMPLING_RATE,
    get_prior_params,
    get_random_time_series_predictor,
    get_start_type_prior_probabilities,
)
from bella_companion.simulations.scenarios.scenario import Scenario, ScenarioType


def _get_random_predictor(rng: Generator) -> list[float]:
    return np.repeat(
        get_random_time_series_predictor(rng, N_TIME_BINS), N_STATES
    ).tolist()


STATES = ["00", "01", "10", "11"]
_INIT_STATE = "00"
N_STATES = len(STATES)
N_TIME_BINS = 10
_CHANGE_TIMES = np.linspace(0, FBD_MAX_TIME, N_TIME_BINS + 1)[1:-1].tolist()
BIRTH_RATE_TRAIT1_UNSET = np.linspace(0.6, 0.1, N_TIME_BINS).tolist()
BIRTH_RATE_TRAIT1_SET = np.linspace(0.3, 0.05, N_TIME_BINS).tolist()
DEATH_RATE_TRAIT1_UNSET = np.linspace(0.1, 0.4, N_TIME_BINS).tolist()
DEATH_RATE_TRAIT1_SET = np.linspace(0.1, 0.2, N_TIME_BINS).tolist()
BIRTH_RATES = {
    "00": BIRTH_RATE_TRAIT1_UNSET,
    "01": BIRTH_RATE_TRAIT1_UNSET,
    "10": BIRTH_RATE_TRAIT1_SET,
    "11": BIRTH_RATE_TRAIT1_SET,
}
DEATH_RATES = {
    "00": DEATH_RATE_TRAIT1_UNSET,
    "01": DEATH_RATE_TRAIT1_UNSET,
    "10": DEATH_RATE_TRAIT1_SET,
    "11": DEATH_RATE_TRAIT1_SET,
}
RATES = {
    "birth": BIRTH_RATES,
    "death": DEATH_RATES,
}
_MIGRATION_RATES = (
    np.array([[1, 1, 0], [1, 0, 1], [1, 0, 1], [0, 1, 1]]) * 0.1
).tolist()

SCENARIO = Scenario(
    type=ScenarioType.FBD,
    max_time=FBD_MAX_TIME,
    init_state=_INIT_STATE,
    events=get_canonical_events(
        states=STATES,
        sampling_rates=FBD_SAMPLING_RATE,
        remove_after_sampling=False,
        birth_rates=SkylineVector(
            value=list(zip(*BIRTH_RATES.values())), change_times=_CHANGE_TIMES
        ),
        death_rates=SkylineVector(
            value=list(zip(*DEATH_RATES.values())), change_times=_CHANGE_TIMES
        ),
        migration_rates=_MIGRATION_RATES,
    ),
    get_random_predictor=_get_random_predictor,
    beast_args={
        "types": ",".join(STATES),
        "startTypePriorProbs": get_start_type_prior_probabilities(STATES, _INIT_STATE),
        "processLength": FBD_MAX_TIME,
        "changeTimes": " ".join(map(str, _CHANGE_TIMES)),
        **get_prior_params("birthRate", FBD_RATE_UPPER, N_TIME_BINS * N_STATES),
        **get_prior_params("deathRate", FBD_RATE_UPPER, N_TIME_BINS * N_STATES),
        "samplingRate": FBD_SAMPLING_RATE,
        "migrationRate": " ".join(map(str, np.array(_MIGRATION_RATES).flatten())),
        "timePredictor": " ".join(
            list(map(str, np.repeat(np.linspace(0, 1, N_TIME_BINS), N_STATES)))
        ),
        # Indicator of trait 1 being unset / trait 2 being set, per (bin, state).
        "trait1Predictor": " ".join(map(str, [1, 1, 0, 0] * N_TIME_BINS)),
        "trait2Predictor": " ".join(map(str, [0, 1, 0, 1] * N_TIME_BINS)),
    },
    targets={
        f"{rate}Rate": {
            f"{rate}RateSPi{i}_{s}": values[s][i]
            for i in range(N_TIME_BINS)
            for s in STATES
        }
        for rate, values in RATES.items()
    },
    features={
        f"{rate}Rate": {
            "timePredictor": Feature(is_binary=False, is_relevant=True),
            "trait1Predictor": Feature(is_binary=True, is_relevant=True),
            "trait2Predictor": Feature(is_binary=True, is_relevant=False),
            "randomPredictor": Feature(is_binary=False, is_relevant=False),
        }
        for rate in RATES
    },
)
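A small sketch (not part of the package) making the migration structure above explicit: each row of the 4x3 mask lists, per source state, which of its three destination states are reachable, and only single-trait flips are allowed:

STATES = ["00", "01", "10", "11"]
MASK = [[1, 1, 0], [1, 0, 1], [1, 0, 1], [0, 1, 1]]
for row, s in zip(MASK, STATES):
    others = [t for t in STATES if t != s]
    print(s, "->", [t for t, m in zip(others, row) if m])
# 00 -> ['01', '10']
# 01 -> ['00', '11']
# 10 -> ['00', '11']
# 11 -> ['01', '10']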