bella-companion 0.0.4__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of bella-companion has been flagged as potentially problematic; see the package's page on the registry for details.

Files changed (29):
  1. bella_companion/cli.py +13 -4
  2. bella_companion/fbd_empirical/run_beast.py +1 -1
  3. bella_companion/simulations/__init__.py +2 -1
  4. bella_companion/simulations/figures/__init__.py +21 -0
  5. bella_companion/simulations/figures/epi_multitype_results.py +81 -0
  6. bella_companion/simulations/figures/epi_skyline_results.py +46 -0
  7. bella_companion/simulations/figures/explain/__init__.py +6 -0
  8. bella_companion/simulations/figures/explain/pdp.py +101 -0
  9. bella_companion/simulations/figures/explain/shap.py +56 -0
  10. bella_companion/simulations/figures/fbd_2traits_results.py +83 -0
  11. bella_companion/simulations/figures/fbd_no_traits_results.py +58 -0
  12. bella_companion/simulations/figures/scenarios.py +38 -32
  13. bella_companion/simulations/generate_figures.py +24 -0
  14. bella_companion/simulations/run_beast.py +1 -0
  15. bella_companion/simulations/scenarios/fbd_2traits.py +2 -2
  16. bella_companion/utils/__init__.py +20 -1
  17. bella_companion/utils/beast.py +1 -1
  18. bella_companion/utils/explain.py +45 -0
  19. bella_companion/utils/plots.py +98 -0
  20. {bella_companion-0.0.4.dist-info → bella_companion-0.0.6.dist-info}/METADATA +4 -3
  21. bella_companion-0.0.6.dist-info/RECORD +42 -0
  22. bella_companion/simulations/figures/epi_explainations.py +0 -109
  23. bella_companion/simulations/figures/epi_predictions.py +0 -58
  24. bella_companion/simulations/figures/fbd_explainations.py +0 -99
  25. bella_companion/simulations/figures/fbd_predictions.py +0 -66
  26. bella_companion/simulations/figures/utils.py +0 -250
  27. bella_companion-0.0.4.dist-info/RECORD +0 -37
  28. {bella_companion-0.0.4.dist-info → bella_companion-0.0.6.dist-info}/WHEEL +0 -0
  29. {bella_companion-0.0.4.dist-info → bella_companion-0.0.6.dist-info}/entry_points.txt +0 -0
@@ -1,250 +0,0 @@
1
- import os
2
- from itertools import product
3
- from typing import Any
4
-
5
- import matplotlib.pyplot as plt
6
- import numpy as np
7
- import polars as pl
8
- import seaborn as sns
9
- from joblib import Parallel, delayed
10
- from lumiere.backend import (
11
- ActivationFunction,
12
- get_partial_dependence_values,
13
- get_shap_features_importance,
14
- sigmoid,
15
- )
16
- from lumiere.backend.typings import Weights
17
- from tqdm import tqdm
18
-
19
- from src.simulations.features import Feature
20
-
21
-
22
def _set_xticks(n: int, reverse: bool = False):
    """Put one integer tick at each x position ``0..n-1`` of the current axes.

    Positions are always ascending. When ``reverse`` is true, the labels are
    written in descending order (``n-1`` down to ``0``) instead.
    """
    positions = range(n)
    ordered = positions[::-1] if reverse else positions
    plt.xticks(ticks=positions, labels=[str(label) for label in ordered])
27
-
28
-
29
def step(
    x: list[float],
    reverse_xticks: bool = False,
    **kwargs: Any,
):
    """Draw ``x`` as a step function over integer time bins on the current axes.

    The first value is duplicated at position 0. With ``plt.step``'s default
    ``where="pre"``, every value of ``x`` is then drawn as a flat segment
    covering one whole bin: value ``x[i]`` spans ``(i, i + 1]``.

    Args:
        x: Per-bin values. Any non-empty indexable sequence works; it is not
            modified.
        reverse_xticks: If true, the tick labels are written in descending order.
        **kwargs: Extra keyword arguments forwarded unchanged to ``plt.step``.

    Raises:
        ValueError: If ``x`` is empty.
    """
    # len() rather than truthiness so numpy arrays are accepted as well.
    if len(x) == 0:
        raise ValueError("step() requires at least one value.")
    # Build a new list instead of x.copy()/insert so the caller's sequence is
    # untouched and non-list sequences work too.
    data = [x[0], *x]
    plt.step(list(range(len(data))), data, **kwargs)
    _set_xticks(len(data), reverse_xticks)
    plt.xlabel("Time bin")
39
-
40
-
41
- def _count_time_bins(true_values: dict[str, list[float]]) -> int:
42
- assert (
43
- len({len(true_value) for true_value in true_values.values()}) == 1
44
- ), "All targets must have the same number of change times."
45
- return len(next(iter((true_values.values()))))
46
-
47
-
48
def plot_maes_per_time_bin(
    logs_summaries: dict[str, pl.DataFrame],
    true_values: dict[str, list[float]],
    output_filepath: str,
    reverse_xticks: bool = False,
):
    """Save violin plots of per-row absolute errors for every time bin and model.

    For each row of a model's summary table, the value plotted for time bin
    ``i`` is the mean over targets of ``|<target>i<i>_median - true value|``.
    There is one violin per (time bin, model) pair, and no legend is drawn.

    Args:
        logs_summaries: Mapping from model name to its summary table. Each table
            must contain a ``<target>i<i>_median`` column for every target and
            time bin.
        true_values: Mapping from target name to its per-time-bin true values.
        output_filepath: Path the figure is saved to.
        reverse_xticks: If true, the time-bin tick labels are written in
            descending order.
    """
    n_time_bins = _count_time_bins(true_values)

    # Time bins form the outer loop and models the inner one, matching the row
    # order the violins are built from.
    frames = []
    for time_bin in range(n_time_bins):
        abs_errors = [
            (pl.col(f"{target}i{time_bin}_median") - true_values[target][time_bin]).abs()
            for target in true_values
        ]
        for model, summary in logs_summaries.items():
            mae = summary.select(pl.mean_horizontal(abs_errors).alias("MAE"))
            frames.append(
                mae.with_columns(
                    pl.lit(time_bin).alias("Time bin"),
                    pl.lit(model).alias("Model"),
                )
            )
    df = pl.concat(frames)

    sns.violinplot(
        data=df,
        x="Time bin",
        y="MAE",
        hue="Model",
        inner=None,
        cut=0,
        density_norm="width",
        legend=False,
    )
    _set_xticks(n_time_bins, reverse_xticks)
    plt.savefig(output_filepath)
    plt.close()
80
-
81
-
82
def plot_coverage_per_time_bin(
    logs_summaries: dict[str, pl.DataFrame],
    true_values: dict[str, list[float]],
    output_filepath: str,
    reverse_xticks: bool = False,
):
    """Save a line plot of interval coverage per time bin, one line per model.

    For a model, target and time bin ``i``, coverage is the fraction of summary
    rows for which ``<target>i<i>_lower <= true value <= <target>i<i>_upper``.
    Each plotted point is that fraction averaged over all targets.

    Args:
        logs_summaries: Mapping from model name to its summary table. Each table
            must contain ``<target>i<i>_lower`` and ``<target>i<i>_upper``
            columns for every target and time bin.
        true_values: Mapping from target name to its per-time-bin true values.
        output_filepath: Path the figure is saved to.
        reverse_xticks: If true, the time-bin tick labels are written in
            descending order.
    """

    def _coverage(model: str, target: str, i: int) -> float:
        # Fraction of rows whose [lower, upper] interval contains the true value.
        summary = logs_summaries[model]
        lower_bound = summary[f"{target}i{i}_lower"]
        upper_bound = summary[f"{target}i{i}_upper"]
        true_value = true_values[target][i]
        covered = (lower_bound <= true_value) & (true_value <= upper_bound)
        return covered.sum() / len(summary)

    # Validate and count once; the original recomputed this for every model.
    n_time_bins = _count_time_bins(true_values)
    for model in logs_summaries:
        avg_coverage_by_time_bin = [
            np.mean([_coverage(model, target, i) for target in true_values])
            for i in range(n_time_bins)
        ]
        plt.plot(avg_coverage_by_time_bin, marker="o")

    _set_xticks(n_time_bins, reverse_xticks)
    plt.xlabel("Time bin")
    plt.ylabel("Coverage")
    plt.ylim((0, 1.05))
    plt.savefig(output_filepath)
    plt.close()
109
-
110
-
111
def plot_partial_dependencies(
    weights: list[list[Weights]],  # shape: (n_mcmcs, n_weights_samples, ...)
    features: dict[str, Feature],
    output_dir: str,
    hidden_activation: ActivationFunction = sigmoid,
    output_activation: ActivationFunction = sigmoid,
):
    """Save partial-dependence plots (PDPs) for every feature into ``output_dir``.

    Partial-dependence values are computed per MCMC run:
    ``get_partial_dependence_values`` is evaluated on each feature's ``grid``
    for every weight sample of the run, then reduced to the element-wise
    median across samples. All runs are then plotted together:

    * Continuous features go to ``<output_dir>/PDPs-continuous.svg``. Each
      feature gets one faint line per MCMC run, the median across runs, and a
      2.5-97.5 percentile band.
    * Categorical features go to ``<output_dir>/PDPs-categorical.svg``, with
      one violin per (grid value, feature) pair.

    Each file is written only if at least one feature of that kind exists.
    Features with ``is_relevant`` set are drawn in red, the others in gray.

    Args:
        weights: Weight samples grouped per MCMC run (``weights[mcmc][sample]``).
        features: Mapping from feature name to its description. The iteration
            order is assumed to match the order of ``features_grid`` expected by
            ``get_partial_dependence_values`` — TODO confirm.
        output_dir: Directory for the SVG files; it is created if missing.
        hidden_activation: Forwarded to ``get_partial_dependence_values``.
        output_activation: Forwarded to ``get_partial_dependence_values``.

    Raises:
        ValueError: If ``weights`` is empty.
    """
    if not weights:
        raise ValueError("At least one MCMC run of weights is required.")
    os.makedirs(output_dir, exist_ok=True)
    features_grid = [feature.grid for feature in features.values()]

    def _get_median_partial_dependence_values(
        weights: list[Weights],
    ) -> list[list[float]]:
        # One PD evaluation per weight sample of a single MCMC run...
        pdvalues_distribution = [
            get_partial_dependence_values(
                weights=w,
                features_grid=features_grid,
                hidden_activation=hidden_activation,
                output_activation=output_activation,
            )
            for w in weights
        ]
        # ...reduced to the per-feature, per-grid-point median across samples.
        return [
            np.median(
                [pdvalues[feature_idx] for pdvalues in pdvalues_distribution], axis=0
            ).tolist()
            for feature_idx in range(len(features))
        ]

    # With return_as="generator", each result is yielded as soon as it (and
    # every earlier one) is ready, so the progress bar tracks real work. The
    # default "list" only returns after every job has finished.
    jobs = Parallel(n_jobs=-1, return_as="generator")(
        delayed(_get_median_partial_dependence_values)(w) for w in weights
    )
    mcmc_pdvalues = list(
        tqdm(jobs, total=len(weights), desc="Evaluating PDPs")
    )  # shape: (n_mcmcs, n_features, n_grid_points)
    pdvalues = [
        np.array(per_feature).T for per_feature in zip(*mcmc_pdvalues)
    ]  # shape: (n_features, n_grid_points, n_mcmcs)

    _plot_continuous_pdps(features, pdvalues, output_dir)
    _plot_categorical_pdps(features, pdvalues, output_dir)


def _plot_continuous_pdps(
    features: dict[str, Feature],
    pdvalues: list[np.ndarray],
    output_dir: str,
):
    """Draw the PDPs of all continuous features into ``PDPs-continuous.svg``."""
    if all(feature.is_categorical for feature in features.values()):
        # No continuous feature: skip the file rather than save an empty figure
        # whose plt.legend() call would find no labelled artists.
        return
    for feature_idx, (feature_name, feature) in enumerate(features.items()):
        if feature.is_categorical:
            continue
        color = "red" if feature.is_relevant else "gray"
        feature_pdvalues = pdvalues[feature_idx]  # shape: (n_grid_points, n_mcmcs)
        median = np.median(feature_pdvalues, axis=1)
        lower = np.percentile(feature_pdvalues, 2.5, axis=1)
        high = np.percentile(feature_pdvalues, 100 - 2.5, axis=1)
        plt.fill_between(feature.grid, lower, high, alpha=0.25, color=color)
        for mcmc_pds in feature_pdvalues.T:
            plt.plot(feature.grid, mcmc_pds, color=color, alpha=0.2, linewidth=1)
        plt.plot(feature.grid, median, color=color, label=feature_name)
    plt.xlabel("Feature value")
    plt.ylabel("MLP Output")
    plt.legend()
    plt.savefig(os.path.join(output_dir, "PDPs-continuous.svg"))
    plt.close()


def _plot_categorical_pdps(
    features: dict[str, Feature],
    pdvalues: list[np.ndarray],
    output_dir: str,
):
    """Draw the PDPs of all categorical features as violins into ``PDPs-categorical.svg``."""
    if not any(feature.is_categorical for feature in features.values()):
        return
    plot_data = []
    grid_labels = []
    list_labels = []
    for feature_idx, (feature_name, feature) in enumerate(features.items()):
        if not feature.is_categorical:
            continue
        feature_pdvalues = pdvalues[feature_idx]  # shape: (n_grid_points, n_mcmcs)
        # Long format: one entry per (grid value, MCMC run) pair.
        for grid_point, grid_values in zip(feature.grid, feature_pdvalues):
            plot_data.extend(grid_values)
            grid_labels.extend([grid_point] * len(grid_values))
            list_labels.extend([feature_name] * len(grid_values))
    sns.violinplot(
        x=grid_labels,
        y=plot_data,
        hue=list_labels,
        split=False,
        cut=0,
        palette={
            feature_name: "red" if feature.is_relevant else "gray"
            for feature_name, feature in features.items()
            if feature.is_categorical
        },
    )
    plt.xlabel("Feature value")
    plt.ylabel("MLP Output")
    plt.legend()
    plt.savefig(os.path.join(output_dir, "PDPs-categorical.svg"))
    plt.close()
202
-
203
-
204
def plot_shap_features_importance(
    weights: list[list[Weights]],  # shape: (n_mcmcs, n_weights_samples, ...)
    features: dict[str, Feature],
    output_file: str,
    hidden_activation: ActivationFunction = sigmoid,
    output_activation: ActivationFunction = sigmoid,
):
    """Save violin plots of normalized SHAP feature importances to ``output_file``.

    ``get_shap_features_importance`` is evaluated on every point of the
    Cartesian product of the feature grids. The procedure is:

    1. For each MCMC run, compute importances for every weight sample.
    2. Reduce them to the median across samples.
    3. Rescale each run's importances so they sum to 1.
    4. Draw the distribution across runs as one violin per feature: red if
       ``feature.is_relevant``, gray otherwise.

    Args:
        weights: Weight samples grouped per MCMC run (``weights[mcmc][sample]``).
        features: Mapping from feature name to its description; the grids are
            taken in this mapping's iteration order.
        output_file: Path the figure is saved to.
        hidden_activation: Forwarded to ``get_shap_features_importance``.
        output_activation: Forwarded to ``get_shap_features_importance``.

    Raises:
        ValueError: If ``weights`` is empty.
    """
    if not weights:
        raise ValueError("At least one MCMC run of weights is required.")
    features_grid = [feature.grid for feature in features.values()]
    # Every combination of grid values across all features.
    inputs = list(product(*features_grid))

    def _get_median_shap_features_importance(
        weights: list[Weights],
    ) -> list[float]:
        features_importance = np.array(
            [
                get_shap_features_importance(
                    weights=w,
                    inputs=inputs,
                    hidden_activation=hidden_activation,
                    output_activation=output_activation,
                )
                for w in weights
            ]
        )  # shape: (n_weights_samples, n_features)
        return np.median(features_importance, axis=0).tolist()  # shape: (n_features,)

    # Completion order is fine here: only the distribution across runs is plotted.
    jobs = Parallel(n_jobs=-1, return_as="generator_unordered")(
        delayed(_get_median_shap_features_importance)(w) for w in weights
    )
    features_importance_distribution = np.array(
        list(tqdm(jobs, total=len(weights), desc="Evaluating SHAPs"))
    )  # shape: (n_mcmcs, n_features)
    # Normalize each MCMC run so its importances sum to 1. A run whose
    # importances are all zero would produce NaNs here.
    features_importance_distribution /= features_importance_distribution.sum(
        axis=1, keepdims=True
    )

    for i, (feature_name, feature) in enumerate(features.items()):
        sns.violinplot(
            y=features_importance_distribution[:, i],
            x=[feature_name] * features_importance_distribution.shape[0],
            cut=0,
            color="red" if feature.is_relevant else "gray",
        )
    plt.xlabel("Feature")
    plt.ylabel("Importance")
    plt.savefig(output_file)
    plt.close()
@@ -1,37 +0,0 @@
1
- bella_companion/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- bella_companion/cli.py,sha256=0sPnzGyUGo2OBZ0rj17ZGzMdwNH0o-BXKsYtCJjzGvQ,968
3
- bella_companion/fbd_empirical/data/body_mass.csv,sha256=-UkKNtm9m3g4PjY3BcfdP6z5nL_I6p9cq6cgZ-bWKI8,30360
4
- bella_companion/fbd_empirical/data/change_times.csv,sha256=zmc9_z91-XMwKyIoP9v9dVlLcf4MeIHkQiHLjoMriOo,120
5
- bella_companion/fbd_empirical/data/sampling_change_times.csv,sha256=Gwi9RcMFy89RyvfxKVZ_MoKVRHOZLuwB_3LEaq8asMQ,32
6
- bella_companion/fbd_empirical/data/trees.nwk,sha256=zhvLvPLZelhMThVmvOENkmi3p2aPAARb8KMdHTm6mss,4645318
7
- bella_companion/fbd_empirical/figure.py,sha256=4paOXCB1EcxuHzLPxDSleQU2AQ_ndTedtzS1ugiKICs,1018
8
- bella_companion/fbd_empirical/notbooks.ipynb,sha256=O45kmz0lZENRDFbKXEWPsIKATfF5GVeS5tCYmrGLnqk,83326
9
- bella_companion/fbd_empirical/params.json,sha256=hU23LniClZL_GSBAxIEJUJgMa93AM8zdtFOq6mt3vkI,311
10
- bella_companion/fbd_empirical/run_beast.py,sha256=2sV2UmxOfWmbueiU6D0p3lueMYiZyIkSKYoblTMrYuA,1935
11
- bella_companion/fbd_empirical/summarize_logs.py,sha256=O6rhE606Wa98a8b1KKlLPjUOro1pfyqVTLdQksQMG0g,1439
12
- bella_companion/simulations/__init__.py,sha256=i6Fe7l5sUJY9hPxdg6L_FVhwbSPhNxQNMb-m33JlfxI,258
13
- bella_companion/simulations/features.py,sha256=DZOBpJGlQ0UinqUZYbEtoemZ2eQGVLV_i-DfpW31qJI,104
14
- bella_companion/simulations/figures/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
- bella_companion/simulations/figures/epi_explainations.py,sha256=omiJgyIY-I6zcJAcyOF7GJ2pba6pMZySLkWy7OFrjFY,3093
16
- bella_companion/simulations/figures/epi_predictions.py,sha256=4yXwOBKxUv4kgZdI9zAMEhZ0QCNKZdkAafRQ1RTeaWg,1835
17
- bella_companion/simulations/figures/fbd_explainations.py,sha256=9Uj7yttpn_TH5HqycW8R-Nlky9A9aFXDXRpXQuT1L4s,3037
18
- bella_companion/simulations/figures/fbd_predictions.py,sha256=jdXYCLledZEWoPCIuTLhHEPMdeG6YXvf5xZnEOslv-U,2119
19
- bella_companion/simulations/figures/scenarios.py,sha256=vyybn3Qhfq96N8tvW0wSzpFoHHP8EIc8dkOz63o_Atw,2492
20
- bella_companion/simulations/figures/utils.py,sha256=sY8wFBg02fv5ugpJ80EqQishD_HEdLwhqsw2LfM7wEo,8539
21
- bella_companion/simulations/generate_data.py,sha256=H8OV4ZlTGZB-jXaROTPmOsK3UxRiU-GrX40l-shliw8,728
22
- bella_companion/simulations/run_beast.py,sha256=xOuwE0w4IbOqqCSym6kHsAEhfGT2mWdA-jmUZuviMbc,3121
23
- bella_companion/simulations/scenarios/__init__.py,sha256=3Kl1lKcFpfb3vLX64DmSW4XCF5kXU1ZoHtstFH-ZIzU,876
24
- bella_companion/simulations/scenarios/common.py,sha256=_ddaSuTvEVdttGkXB4HPc2B7IB1F_GBOCW3cVOPZ-ZM,807
25
- bella_companion/simulations/scenarios/epi_multitype.py,sha256=GWGIiqvYwX_FrT_3RXkZKYGDht9nZ7ceHRBKUvXDPnA,2432
26
- bella_companion/simulations/scenarios/epi_skyline.py,sha256=JqnOVATECxBUqEbkR5lBlMI2O8k4hO6ipR8k9cHUsm0,2365
27
- bella_companion/simulations/scenarios/fbd_2traits.py,sha256=sCtdWyV6GQQOIhnL9Dd8NIbAR-StTwUTD9-b_BalmFQ,3552
28
- bella_companion/simulations/scenarios/fbd_no_traits.py,sha256=R6CH0fVeQg-Iesl39pq2uY8ICVEO4VZbvUVUCGwauJU,2520
29
- bella_companion/simulations/scenarios/scenario.py,sha256=_FRWAyOFbw94lAzd3zCD-1ek4TrssoiXfXRQPShLiIA,620
30
- bella_companion/simulations/summarize_logs.py,sha256=5IdzR9IwjeF2LgZmzpuK0rPfYMct2OUgEp0QyUbUS7g,1263
31
- bella_companion/utils/__init__.py,sha256=_5tLPH_3GHtimNcH0Yd9Z6yIM3WkWkNApNGLzFnF6nY,222
32
- bella_companion/utils/beast.py,sha256=RG-iSEFuL92K6yxUV2nxdmcVqfrEiPhaYTmReW4ZoWk,2189
33
- bella_companion/utils/slurm.py,sha256=v5DaG7YHVyK8KRFptgGDC6I8jxEhyJuMVK9N08pZSAI,1812
34
- bella_companion-0.0.4.dist-info/METADATA,sha256=UKU-LZpRje6oxM1GDM6Qxa82sb8eyFKjOEfHu8Xb0fw,534
35
- bella_companion-0.0.4.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
36
- bella_companion-0.0.4.dist-info/entry_points.txt,sha256=rSeKoAhmjnQqAYFcXBv0gAM2ViJfJe0D8_dD-fWrXeg,50
37
- bella_companion-0.0.4.dist-info/RECORD,,