PyEvoMotion 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- PyEvoMotion/cli.py +87 -3
- PyEvoMotion/core/base.py +296 -20
- PyEvoMotion/core/core.py +73 -24
- {pyevomotion-0.1.1.dist-info → pyevomotion-0.1.2.dist-info}/METADATA +1 -1
- pyevomotion-0.1.2.dist-info/RECORD +35 -0
- share/analyze_model_selection_accuracy.py +316 -0
- share/analyze_test_runs.py +436 -0
- share/anomalous_diffusion.pdf +0 -0
- share/confusion_matrix_heatmap.pdf +0 -0
- share/figUK_plots.pdf +0 -0
- share/figUK_regression_results.json +54 -7
- share/figUK_run_args.json +1 -0
- share/figUK_stats.tsv +41 -41
- share/figUSA_plots.pdf +0 -0
- share/figUSA_regression_results.json +54 -7
- share/figUSA_run_args.json +1 -0
- share/figUSA_stats.tsv +34 -34
- share/generate_sequences_from_test5_data.py +107 -0
- share/manuscript_figure.py +450 -80
- share/run_parallel_analysis.py +196 -0
- share/synth_figure.pdf +0 -0
- share/uk_time_windows.pdf +0 -0
- share/weekly_size.pdf +0 -0
- pyevomotion-0.1.1.dist-info/RECORD +0 -31
- share/figure.pdf +0 -0
- {pyevomotion-0.1.1.dist-info → pyevomotion-0.1.2.dist-info}/WHEEL +0 -0
- {pyevomotion-0.1.1.dist-info → pyevomotion-0.1.2.dist-info}/entry_points.txt +0 -0
share/manuscript_figure.py
CHANGED
|
@@ -10,6 +10,7 @@ import numpy as np
|
|
|
10
10
|
import pandas as pd
|
|
11
11
|
import matplotlib as mpl
|
|
12
12
|
import matplotlib.pyplot as plt
|
|
13
|
+
from matplotlib.colors import LinearSegmentedColormap
|
|
13
14
|
|
|
14
15
|
#´:°•.°+.*•´.*:˚.°*.˚•´.°:°•.°•.*•´.*:˚.°*.˚•´.°:°•.°+.*•´.*:#
|
|
15
16
|
# CONSTANTS #
|
|
@@ -20,6 +21,9 @@ COLORS = {
|
|
|
20
21
|
"USA": "#FF6346",
|
|
21
22
|
}
|
|
22
23
|
|
|
24
|
+
# Control confidence interval plotting
|
|
25
|
+
PLOT_CONFIDENCE_INTERVALS = False
|
|
26
|
+
|
|
23
27
|
#´:°•.°+.*•´.*:˚.°*.˚•´.°:°•.°•.*•´.*:˚.°*.˚•´.°:°•.°+.*•´.*:#
|
|
24
28
|
# FUNCTIONS #
|
|
25
29
|
#.•°:°.´+˚.*°.˚:*.´•*.+°.•°:´*.´•*.•°.•°:°.´:•˚°.*°.˚:*.´+°.•#
|
|
@@ -169,36 +173,12 @@ def load_models() -> dict[str, dict[str, callable]]:
|
|
|
169
173
|
|
|
170
174
|
return {
|
|
171
175
|
"USA": {
|
|
172
|
-
"mean": [
|
|
173
|
-
|
|
174
|
-
_contents["USA"]["mean number of mutations per 7D model"]["parameters"]["m"]*x
|
|
175
|
-
+ _contents["USA"]["mean number of mutations per 7D model"]["parameters"]["b"]
|
|
176
|
-
),
|
|
177
|
-
_contents["USA"]["mean number of mutations per 7D model"]["r2"],
|
|
178
|
-
],
|
|
179
|
-
"var": [
|
|
180
|
-
lambda x: (
|
|
181
|
-
_contents["USA"]["scaled var number of mutations per 7D model"]["parameters"]["d"]
|
|
182
|
-
*(x**_contents["USA"]["scaled var number of mutations per 7D model"]["parameters"]["alpha"])
|
|
183
|
-
),
|
|
184
|
-
_contents["USA"]["scaled var number of mutations per 7D model"]["r2"],
|
|
185
|
-
]
|
|
176
|
+
"mean": list(_get_mean_model(_contents["USA"], "USA")),
|
|
177
|
+
"var": list(_get_var_model(_contents["USA"], "USA"))
|
|
186
178
|
},
|
|
187
179
|
"UK": {
|
|
188
|
-
"mean": [
|
|
189
|
-
|
|
190
|
-
_contents["UK"]["mean number of mutations per 7D model"]["parameters"]["m"]*x
|
|
191
|
-
+ _contents["UK"]["mean number of mutations per 7D model"]["parameters"]["b"]
|
|
192
|
-
),
|
|
193
|
-
_contents["UK"]["mean number of mutations per 7D model"]["r2"],
|
|
194
|
-
],
|
|
195
|
-
"var": [
|
|
196
|
-
lambda x: (
|
|
197
|
-
_contents["UK"]["scaled var number of mutations per 7D model"]["parameters"]["d"]
|
|
198
|
-
*(x**_contents["UK"]["scaled var number of mutations per 7D model"]["parameters"]["alpha"])
|
|
199
|
-
),
|
|
200
|
-
_contents["UK"]["scaled var number of mutations per 7D model"]["r2"],
|
|
201
|
-
]
|
|
180
|
+
"mean": list(_get_mean_model(_contents["UK"], "UK")),
|
|
181
|
+
"var": list(_get_var_model(_contents["UK"], "UK"))
|
|
202
182
|
},
|
|
203
183
|
}
|
|
204
184
|
|
|
@@ -211,6 +191,62 @@ def safe_map(f: callable, x: list[int | float]) -> list[int | float]:
|
|
|
211
191
|
_results.append(None)
|
|
212
192
|
return _results
|
|
213
193
|
|
|
194
|
+
|
|
195
|
+
def _calculate_confidence_bounds(x_values: np.ndarray, model_func: callable, confidence_intervals: dict, model_type: str) -> tuple[np.ndarray, np.ndarray]:
|
|
196
|
+
"""Calculate confidence interval bounds for a model.
|
|
197
|
+
|
|
198
|
+
:param x_values: X values to calculate bounds for
|
|
199
|
+
:type x_values: np.ndarray
|
|
200
|
+
:param model_func: The model function
|
|
201
|
+
:type model_func: callable
|
|
202
|
+
:param confidence_intervals: Dictionary of confidence intervals for parameters
|
|
203
|
+
:type confidence_intervals: dict
|
|
204
|
+
:param model_type: Type of model ('linear_mean', 'linear_var', 'power_law')
|
|
205
|
+
:type model_type: str
|
|
206
|
+
:return: Tuple of (lower_bounds, upper_bounds)
|
|
207
|
+
:rtype: tuple[np.ndarray, np.ndarray]
|
|
208
|
+
"""
|
|
209
|
+
if not confidence_intervals:
|
|
210
|
+
# No confidence intervals available, return None bounds
|
|
211
|
+
return None, None
|
|
212
|
+
|
|
213
|
+
if model_type == "linear_mean":
|
|
214
|
+
# For linear mean model: mx + b
|
|
215
|
+
if "m" in confidence_intervals and "b" in confidence_intervals:
|
|
216
|
+
m_lower, m_upper = confidence_intervals["m"]
|
|
217
|
+
b_lower, b_upper = confidence_intervals["b"]
|
|
218
|
+
|
|
219
|
+
# Calculate bounds for the linear function
|
|
220
|
+
lower_bounds = m_lower * x_values + b_lower
|
|
221
|
+
upper_bounds = m_upper * x_values + b_upper
|
|
222
|
+
|
|
223
|
+
return lower_bounds, upper_bounds
|
|
224
|
+
|
|
225
|
+
elif model_type == "linear_var":
|
|
226
|
+
# For linear variance model: mx
|
|
227
|
+
if "m" in confidence_intervals:
|
|
228
|
+
m_lower, m_upper = confidence_intervals["m"]
|
|
229
|
+
|
|
230
|
+
lower_bounds = m_lower * x_values
|
|
231
|
+
upper_bounds = m_upper * x_values
|
|
232
|
+
|
|
233
|
+
return lower_bounds, upper_bounds
|
|
234
|
+
|
|
235
|
+
elif model_type == "power_law":
|
|
236
|
+
# For power law model: d*x^alpha
|
|
237
|
+
if "d" in confidence_intervals and "alpha" in confidence_intervals:
|
|
238
|
+
d_lower, d_upper = confidence_intervals["d"]
|
|
239
|
+
alpha_lower, alpha_upper = confidence_intervals["alpha"]
|
|
240
|
+
|
|
241
|
+
# For power law, we need to be careful about the bounds
|
|
242
|
+
# We'll use the parameter bounds to create approximate confidence bounds
|
|
243
|
+
lower_bounds = d_lower * (x_values ** alpha_lower)
|
|
244
|
+
upper_bounds = d_upper * (x_values ** alpha_upper)
|
|
245
|
+
|
|
246
|
+
return lower_bounds, upper_bounds
|
|
247
|
+
|
|
248
|
+
return None, None
|
|
249
|
+
|
|
214
250
|
def plot_main_figure(df: pd.DataFrame, models: dict[str, any], export: bool = False, show: bool = True) -> None:
|
|
215
251
|
set_matplotlib_global_params()
|
|
216
252
|
fig, ax = plt.subplots(2, 1, figsize=(6, 10))
|
|
@@ -229,14 +265,48 @@ def plot_main_figure(df: pd.DataFrame, models: dict[str, any], export: bool = Fa
|
|
|
229
265
|
)
|
|
230
266
|
|
|
231
267
|
_x = np.arange(-10, 60, 0.5)
|
|
268
|
+
_x_shifted = _x + (8 if _country == "USA" else 0)
|
|
269
|
+
|
|
270
|
+
# Plot the main model line
|
|
232
271
|
ax[idx].plot(
|
|
233
|
-
|
|
272
|
+
_x_shifted,
|
|
234
273
|
safe_map(models[_country][case][0], _x),
|
|
235
274
|
color=COLORS[_country],
|
|
236
275
|
label=rf"{_country} ($R^2 = {round(models[_country][case][1], 2):.2f})$",
|
|
237
276
|
linewidth=3,
|
|
238
277
|
zorder=1,
|
|
239
278
|
)
|
|
279
|
+
|
|
280
|
+
# Plot confidence intervals if available and enabled
|
|
281
|
+
if PLOT_CONFIDENCE_INTERVALS and len(models[_country][case]) > 2 and models[_country][case][2]:
|
|
282
|
+
confidence_intervals = models[_country][case][2]
|
|
283
|
+
|
|
284
|
+
# Determine model type for confidence interval calculation
|
|
285
|
+
if case == "mean":
|
|
286
|
+
model_type = "linear_mean"
|
|
287
|
+
else: # case == "var"
|
|
288
|
+
# Check if it's linear or power law based on the model function
|
|
289
|
+
# We'll determine this by checking if the model has 'alpha' parameter
|
|
290
|
+
if "alpha" in confidence_intervals:
|
|
291
|
+
model_type = "power_law"
|
|
292
|
+
else:
|
|
293
|
+
model_type = "linear_var"
|
|
294
|
+
|
|
295
|
+
lower_bounds, upper_bounds = _calculate_confidence_bounds(
|
|
296
|
+
_x, models[_country][case][0], confidence_intervals, model_type
|
|
297
|
+
)
|
|
298
|
+
|
|
299
|
+
if lower_bounds is not None and upper_bounds is not None:
|
|
300
|
+
# Plot confidence interval as filled area
|
|
301
|
+
# The x-axis shift is already applied to _x_shifted, so we use the original bounds
|
|
302
|
+
ax[idx].fill_between(
|
|
303
|
+
_x_shifted,
|
|
304
|
+
lower_bounds,
|
|
305
|
+
upper_bounds,
|
|
306
|
+
color=COLORS[_country],
|
|
307
|
+
alpha=0.2,
|
|
308
|
+
zorder=0,
|
|
309
|
+
)
|
|
240
310
|
|
|
241
311
|
# Styling
|
|
242
312
|
ax[idx].set_xlim(-0.5, 40.5)
|
|
@@ -264,11 +334,11 @@ def plot_main_figure(df: pd.DataFrame, models: dict[str, any], export: bool = Fa
|
|
|
264
334
|
|
|
265
335
|
if export:
|
|
266
336
|
fig.savefig(
|
|
267
|
-
"share/figure.
|
|
337
|
+
"share/figure.eps",
|
|
268
338
|
dpi=400,
|
|
269
339
|
bbox_inches="tight",
|
|
270
340
|
)
|
|
271
|
-
print("Figure saved as share/figure.
|
|
341
|
+
print("Figure saved as share/figure.eps")
|
|
272
342
|
|
|
273
343
|
if show: plt.show()
|
|
274
344
|
|
|
@@ -382,8 +452,8 @@ def run_synthetic_data_tests() -> None:
|
|
|
382
452
|
result1 = subprocess.run(
|
|
383
453
|
[
|
|
384
454
|
"PyEvoMotion",
|
|
385
|
-
"S1.fasta",
|
|
386
|
-
"S1.tsv",
|
|
455
|
+
"tests/data/test4/S1.fasta",
|
|
456
|
+
"tests/data/test4/S1.tsv",
|
|
387
457
|
"tests/data/test4/synthdata1_out",
|
|
388
458
|
"-ep"
|
|
389
459
|
],
|
|
@@ -401,8 +471,8 @@ def run_synthetic_data_tests() -> None:
|
|
|
401
471
|
result2 = subprocess.run(
|
|
402
472
|
[
|
|
403
473
|
"PyEvoMotion",
|
|
404
|
-
"S2.fasta",
|
|
405
|
-
"S2.tsv",
|
|
474
|
+
"tests/data/test4/S2.fasta",
|
|
475
|
+
"tests/data/test4/S2.tsv",
|
|
406
476
|
"tests/data/test4/synthdata2_out",
|
|
407
477
|
"-ep"
|
|
408
478
|
],
|
|
@@ -433,6 +503,89 @@ def load_synthetic_data_df() -> pd.DataFrame:
|
|
|
433
503
|
suffixes=(" synt1", " synt2"),
|
|
434
504
|
)
|
|
435
505
|
|
|
506
|
+
def _get_mean_model(data: dict, kind: str) -> tuple[callable, float, dict]:
|
|
507
|
+
"""Extract mean model from data, handling both old and new formats.
|
|
508
|
+
|
|
509
|
+
:param data: The regression results data dictionary
|
|
510
|
+
:type data: dict
|
|
511
|
+
:param kind: The dataset kind identifier (for error messages)
|
|
512
|
+
:type kind: str
|
|
513
|
+
:return: Tuple of (lambda function, r2 value, confidence intervals)
|
|
514
|
+
:rtype: tuple[callable, float, dict]
|
|
515
|
+
"""
|
|
516
|
+
# Try different possible key formats
|
|
517
|
+
possible_keys = [
|
|
518
|
+
"mean number of mutations model", # New format (current)
|
|
519
|
+
"mean number of mutations per 7D model", # Old format
|
|
520
|
+
"mean number of substitutions model", # Alternative format
|
|
521
|
+
"mean number of substitutions per 7D model" # Alternative old format
|
|
522
|
+
]
|
|
523
|
+
|
|
524
|
+
for mean_key in possible_keys:
|
|
525
|
+
if mean_key in data:
|
|
526
|
+
params = data[mean_key]["parameters"]
|
|
527
|
+
r2 = data[mean_key]["r2"]
|
|
528
|
+
confidence_intervals = data[mean_key].get("confidence_intervals", {})
|
|
529
|
+
return lambda x: params["m"] * x + params["b"], r2, confidence_intervals
|
|
530
|
+
|
|
531
|
+
raise KeyError(f"Could not find mean model in {kind} data. Available keys: {list(data.keys())}")
|
|
532
|
+
|
|
533
|
+
|
|
534
|
+
def _get_var_model(data: dict, kind: str) -> tuple[callable, float, dict]:
|
|
535
|
+
"""Extract variance model from data, handling both old and new formats.
|
|
536
|
+
|
|
537
|
+
:param data: The regression results data dictionary
|
|
538
|
+
:type data: dict
|
|
539
|
+
:param kind: The dataset kind identifier (for error messages)
|
|
540
|
+
:type kind: str
|
|
541
|
+
:return: Tuple of (lambda function, r2 value, confidence intervals)
|
|
542
|
+
:rtype: tuple[callable, float, dict]
|
|
543
|
+
"""
|
|
544
|
+
# Try different possible key formats
|
|
545
|
+
possible_keys = [
|
|
546
|
+
"scaled var number of mutations model", # New format (current)
|
|
547
|
+
"scaled var number of mutations per 7D model", # Old format
|
|
548
|
+
"scaled var number of substitutions model", # Alternative format
|
|
549
|
+
"scaled var number of substitutions per 7D model" # Alternative old format
|
|
550
|
+
]
|
|
551
|
+
|
|
552
|
+
for var_key in possible_keys:
|
|
553
|
+
if var_key in data:
|
|
554
|
+
# Check if it has model_selection (new format)
|
|
555
|
+
if "model_selection" in data[var_key]:
|
|
556
|
+
# New format with model selection
|
|
557
|
+
model_selection = data[var_key]["model_selection"]
|
|
558
|
+
selected = model_selection["selected"]
|
|
559
|
+
|
|
560
|
+
if selected == "linear" and "linear_model" in data[var_key]:
|
|
561
|
+
# Use linear model
|
|
562
|
+
linear_model = data[var_key]["linear_model"]
|
|
563
|
+
params = linear_model["parameters"]
|
|
564
|
+
r2 = linear_model["r2"]
|
|
565
|
+
confidence_intervals = linear_model.get("confidence_intervals", {})
|
|
566
|
+
return lambda x: params["m"] * x, r2, confidence_intervals
|
|
567
|
+
elif selected == "power_law" and "power_law_model" in data[var_key]:
|
|
568
|
+
# Use power law model
|
|
569
|
+
power_law_model = data[var_key]["power_law_model"]
|
|
570
|
+
params = power_law_model["parameters"]
|
|
571
|
+
r2 = power_law_model["r2"]
|
|
572
|
+
confidence_intervals = power_law_model.get("confidence_intervals", {})
|
|
573
|
+
return lambda x: params["d"] * (x ** params["alpha"]), r2, confidence_intervals
|
|
574
|
+
else:
|
|
575
|
+
# Old format or new format without model selection - direct parameters
|
|
576
|
+
params = data[var_key]["parameters"]
|
|
577
|
+
r2 = data[var_key]["r2"]
|
|
578
|
+
confidence_intervals = data[var_key].get("confidence_intervals", {})
|
|
579
|
+
if "m" in params:
|
|
580
|
+
# Linear model: mx
|
|
581
|
+
return lambda x: params["m"] * x, r2, confidence_intervals
|
|
582
|
+
elif "d" in params and "alpha" in params:
|
|
583
|
+
# Power law model: d*x^alpha
|
|
584
|
+
return lambda x: params["d"] * (x ** params["alpha"]), r2, confidence_intervals
|
|
585
|
+
|
|
586
|
+
raise KeyError(f"Could not find variance model in {kind} data. Available keys: {list(data.keys())}")
|
|
587
|
+
|
|
588
|
+
|
|
436
589
|
def load_synthetic_data_models() -> dict[str, dict[str, callable]]:
|
|
437
590
|
if not check_synthetic_data_exists():
|
|
438
591
|
run_synthetic_data_tests()
|
|
@@ -448,35 +601,12 @@ def load_synthetic_data_models() -> dict[str, dict[str, callable]]:
|
|
|
448
601
|
|
|
449
602
|
return {
|
|
450
603
|
"synt1": {
|
|
451
|
-
"mean": [
|
|
452
|
-
|
|
453
|
-
_contents["synt1"]["mean number of mutations per 7D model"]["parameters"]["m"]*x
|
|
454
|
-
+ _contents["synt1"]["mean number of mutations per 7D model"]["parameters"]["b"]
|
|
455
|
-
),
|
|
456
|
-
_contents["synt1"]["mean number of mutations per 7D model"]["r2"],
|
|
457
|
-
],
|
|
458
|
-
"var": [
|
|
459
|
-
lambda x: (
|
|
460
|
-
_contents["synt1"]["scaled var number of mutations per 7D model"]["parameters"]["m"]*x
|
|
461
|
-
),
|
|
462
|
-
_contents["synt1"]["scaled var number of mutations per 7D model"]["r2"],
|
|
463
|
-
]
|
|
604
|
+
"mean": list(_get_mean_model(_contents["synt1"], "synt1")),
|
|
605
|
+
"var": list(_get_var_model(_contents["synt1"], "synt1"))
|
|
464
606
|
},
|
|
465
607
|
"synt2": {
|
|
466
|
-
"mean": [
|
|
467
|
-
|
|
468
|
-
_contents["synt2"]["mean number of mutations per 7D model"]["parameters"]["m"]*x
|
|
469
|
-
+ _contents["synt2"]["mean number of mutations per 7D model"]["parameters"]["b"]
|
|
470
|
-
),
|
|
471
|
-
_contents["synt2"]["mean number of mutations per 7D model"]["r2"],
|
|
472
|
-
],
|
|
473
|
-
"var": [
|
|
474
|
-
lambda x: (
|
|
475
|
-
_contents["synt2"]["scaled var number of mutations per 7D model"]["parameters"]["d"]
|
|
476
|
-
*(x**_contents["synt2"]["scaled var number of mutations per 7D model"]["parameters"]["alpha"])
|
|
477
|
-
),
|
|
478
|
-
_contents["synt2"]["scaled var number of mutations per 7D model"]["r2"],
|
|
479
|
-
]
|
|
608
|
+
"mean": list(_get_mean_model(_contents["synt2"], "synt2")),
|
|
609
|
+
"var": list(_get_var_model(_contents["synt2"], "synt2"))
|
|
480
610
|
},
|
|
481
611
|
}
|
|
482
612
|
|
|
@@ -513,6 +643,35 @@ def synthetic_data_plot(df: pd.DataFrame, models: dict[str, any], export: bool =
|
|
|
513
643
|
linewidth=3,
|
|
514
644
|
zorder=1,
|
|
515
645
|
)
|
|
646
|
+
|
|
647
|
+
# Plot confidence intervals if available and enabled
|
|
648
|
+
if PLOT_CONFIDENCE_INTERVALS and len(models[_type.lower()][case]) > 2 and models[_type.lower()][case][2]:
|
|
649
|
+
confidence_intervals = models[_type.lower()][case][2]
|
|
650
|
+
|
|
651
|
+
# Determine model type for confidence interval calculation
|
|
652
|
+
if case == "mean":
|
|
653
|
+
model_type = "linear_mean"
|
|
654
|
+
else: # case == "var"
|
|
655
|
+
# Check if it's linear or power law based on the model function
|
|
656
|
+
if "alpha" in confidence_intervals:
|
|
657
|
+
model_type = "power_law"
|
|
658
|
+
else:
|
|
659
|
+
model_type = "linear_var"
|
|
660
|
+
|
|
661
|
+
lower_bounds, upper_bounds = _calculate_confidence_bounds(
|
|
662
|
+
_x, models[_type.lower()][case][0], confidence_intervals, model_type
|
|
663
|
+
)
|
|
664
|
+
|
|
665
|
+
if lower_bounds is not None and upper_bounds is not None:
|
|
666
|
+
# Plot confidence interval as filled area
|
|
667
|
+
ax[plot_idx].fill_between(
|
|
668
|
+
_x,
|
|
669
|
+
lower_bounds,
|
|
670
|
+
upper_bounds,
|
|
671
|
+
color="#76d6ff",
|
|
672
|
+
alpha=0.2,
|
|
673
|
+
zorder=0,
|
|
674
|
+
)
|
|
516
675
|
|
|
517
676
|
# Styling
|
|
518
677
|
ax[plot_idx].set_xlim(-0.5, 40.5)
|
|
@@ -589,17 +748,19 @@ def load_additional_uk_models() -> dict[str, dict[str, callable]]:
|
|
|
589
748
|
k: {
|
|
590
749
|
"mean": [
|
|
591
750
|
{
|
|
592
|
-
"m": _contents[k][
|
|
593
|
-
"b": _contents[k][
|
|
751
|
+
"m": _contents[k]["mean number of mutations model"]["parameters"]["m"],
|
|
752
|
+
"b": _contents[k]["mean number of mutations model"]["parameters"]["b"]
|
|
594
753
|
},
|
|
595
|
-
_contents[k][
|
|
754
|
+
_contents[k]["mean number of mutations model"]["r2"],
|
|
755
|
+
_contents[k]["mean number of mutations model"]["confidence_intervals"]
|
|
596
756
|
],
|
|
597
757
|
"var": [
|
|
598
758
|
{
|
|
599
|
-
"d": _contents[k][
|
|
600
|
-
"alpha": _contents[k][
|
|
759
|
+
"d": _contents[k]["scaled var number of mutations model"]["power_law_model"]["parameters"]["d"],
|
|
760
|
+
"alpha": _contents[k]["scaled var number of mutations model"]["power_law_model"]["parameters"]["alpha"]
|
|
601
761
|
},
|
|
602
|
-
_contents[k][
|
|
762
|
+
_contents[k]["scaled var number of mutations model"]["power_law_model"]["r2"],
|
|
763
|
+
_contents[k]["scaled var number of mutations model"]["power_law_model"]["confidence_intervals"]
|
|
603
764
|
]
|
|
604
765
|
}
|
|
605
766
|
for k in _files.keys()
|
|
@@ -624,18 +785,18 @@ def plot_uk_time_windows(stats: dict[str, pd.DataFrame], models: dict[str, dict[
|
|
|
624
785
|
for idx, window in enumerate(windows):
|
|
625
786
|
df = stats[window]
|
|
626
787
|
model = models[window]
|
|
627
|
-
scaling = {
|
|
628
|
-
"5D": 5/7,
|
|
788
|
+
scaling = { # For the models to be comparable to the 7D model, we need to scale the x-axis by the square of the time window ratio
|
|
789
|
+
"5D": (5/7)**2,
|
|
629
790
|
"7D": 1,
|
|
630
|
-
"10D": 10/7,
|
|
631
|
-
"14D": 14/7,
|
|
791
|
+
"10D": (10/7)**2,
|
|
792
|
+
"14D": (14/7)**2,
|
|
632
793
|
}
|
|
633
794
|
for idx2, case in enumerate(("mean", "var")):
|
|
634
795
|
|
|
635
796
|
if case == "mean":
|
|
636
797
|
# Plot mean
|
|
637
798
|
ax[idx2, idx].scatter(
|
|
638
|
-
df
|
|
799
|
+
df["dt_idx"],
|
|
639
800
|
df["mean number of mutations"],
|
|
640
801
|
color=COLORS["UK"],
|
|
641
802
|
edgecolor="k",
|
|
@@ -645,17 +806,40 @@ def plot_uk_time_windows(stats: dict[str, pd.DataFrame], models: dict[str, dict[
|
|
|
645
806
|
_x = np.arange(-0.5, 51, 0.5)
|
|
646
807
|
ax[idx2, idx].plot(
|
|
647
808
|
_x,
|
|
648
|
-
model["mean"][0]["m"]*
|
|
809
|
+
model["mean"][0]["m"]*_x + model["mean"][0]["b"],
|
|
649
810
|
color=COLORS["UK"],
|
|
650
811
|
label=rf"Mean ($R^2 = {round(model['mean'][1], 2):.2f})$",
|
|
651
812
|
linewidth=3,
|
|
652
813
|
zorder=1,
|
|
653
814
|
)
|
|
815
|
+
|
|
816
|
+
# Plot confidence intervals if available and enabled
|
|
817
|
+
if PLOT_CONFIDENCE_INTERVALS and len(model["mean"]) > 2 and model["mean"][2]:
|
|
818
|
+
confidence_intervals = model["mean"][2]
|
|
819
|
+
|
|
820
|
+
# For mean, it's always linear model
|
|
821
|
+
model_type = "linear_mean"
|
|
822
|
+
|
|
823
|
+
# Calculate confidence bounds
|
|
824
|
+
lower_bounds, upper_bounds = _calculate_confidence_bounds(
|
|
825
|
+
_x, model["mean"][0], confidence_intervals, model_type
|
|
826
|
+
)
|
|
827
|
+
|
|
828
|
+
if lower_bounds is not None and upper_bounds is not None:
|
|
829
|
+
# Plot confidence interval as filled area
|
|
830
|
+
ax[idx2, idx].fill_between(
|
|
831
|
+
_x,
|
|
832
|
+
lower_bounds,
|
|
833
|
+
upper_bounds,
|
|
834
|
+
color=COLORS["UK"],
|
|
835
|
+
alpha=0.2,
|
|
836
|
+
zorder=0,
|
|
837
|
+
)
|
|
654
838
|
|
|
655
839
|
elif case == "var":
|
|
656
840
|
# Plot variance
|
|
657
841
|
ax[idx2, idx].scatter(
|
|
658
|
-
df
|
|
842
|
+
df["dt_idx"],
|
|
659
843
|
df["var number of mutations"] - df["var number of mutations"].min(),
|
|
660
844
|
color=COLORS["UK"],
|
|
661
845
|
edgecolor="k",
|
|
@@ -664,12 +848,36 @@ def plot_uk_time_windows(stats: dict[str, pd.DataFrame], models: dict[str, dict[
|
|
|
664
848
|
|
|
665
849
|
ax[idx2, idx].plot(
|
|
666
850
|
_x,
|
|
667
|
-
model["var"][0]["d"]*(_x
|
|
851
|
+
model["var"][0]["d"]*(_x*scaling[window])**model["var"][0]["alpha"],
|
|
668
852
|
color=COLORS["UK"],
|
|
669
853
|
label=rf"Var ($R^2 = {round(model['var'][1], 2):.2f})$",
|
|
670
854
|
linewidth=3,
|
|
671
855
|
zorder=1,
|
|
672
856
|
)
|
|
857
|
+
|
|
858
|
+
# Plot confidence intervals if available and enabled
|
|
859
|
+
if PLOT_CONFIDENCE_INTERVALS and len(model["var"]) > 2 and model["var"][2]:
|
|
860
|
+
confidence_intervals = model["var"][2]
|
|
861
|
+
|
|
862
|
+
# For variance, it's always power law model
|
|
863
|
+
model_type = "power_law"
|
|
864
|
+
|
|
865
|
+
# Calculate confidence bounds for the scaled x values
|
|
866
|
+
scaled_x = _x * scaling[window]
|
|
867
|
+
lower_bounds, upper_bounds = _calculate_confidence_bounds(
|
|
868
|
+
scaled_x, model["var"][0], confidence_intervals, model_type
|
|
869
|
+
)
|
|
870
|
+
|
|
871
|
+
if lower_bounds is not None and upper_bounds is not None:
|
|
872
|
+
# Plot confidence interval as filled area
|
|
873
|
+
ax[idx2, idx].fill_between(
|
|
874
|
+
_x,
|
|
875
|
+
lower_bounds,
|
|
876
|
+
upper_bounds,
|
|
877
|
+
color=COLORS["UK"],
|
|
878
|
+
alpha=0.2,
|
|
879
|
+
zorder=0,
|
|
880
|
+
)
|
|
673
881
|
|
|
674
882
|
# Styling
|
|
675
883
|
ax[idx2, idx].set_xlim(-0.5, 40.5)
|
|
@@ -698,6 +906,165 @@ def plot_uk_time_windows(stats: dict[str, pd.DataFrame], models: dict[str, dict[
|
|
|
698
906
|
if show:
|
|
699
907
|
plt.show()
|
|
700
908
|
|
|
909
|
+
def load_model_selection_results(directory: str) -> list[dict]:
    """Load all regression results from a directory for model selection analysis.

    The directory tree is walked recursively, so both flat and nested layouts
    are supported. Every file whose name ends in
    ``out_regression_results.json`` is read, e.g.::

        directory/
        ├── {timestamp}/
        │   ├── {dataset_01}/
        │   │   └── *_out_regression_results.json
        │   ├── {dataset_02}/
        │   │   └── *_out_regression_results.json
        │   └── ...
        └── (or any nested structure)

    The model selection section is looked up under the "substitutions"
    spelling of the variance model key first and under the "mutations"
    spelling second, since both spellings occur in result files handled by
    this module. Files that cannot be read or parsed are reported on stdout
    and skipped. Directories and files are visited in sorted order so the
    result order does not depend on the filesystem.

    :param directory: Root directory to search for regression results
    :type directory: str
    :return: One dictionary per readable file, with the keys ``file``,
        ``selected_model`` ("unknown" when absent) and the AIC / Akaike
        weight fields (``None`` when absent)
    :rtype: list[dict]
    """
    var_model_keys = (
        "scaled var number of substitutions model",
        "scaled var number of mutations model",
    )
    optional_fields = (
        "linear_AIC",
        "power_law_AIC",
        "delta_AIC_linear",
        "delta_AIC_power_law",
        "akaike_weight_linear",
        "akaike_weight_power_law",
    )

    results = []

    for root, dirs, files in os.walk(directory):
        # Sorting in place steers os.walk (top-down) into a deterministic order
        dirs.sort()
        for file in sorted(files):
            if not file.endswith("out_regression_results.json"):
                continue

            file_path = os.path.join(root, file)
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    data = json.load(f)
                if not isinstance(data, dict):
                    raise ValueError("top-level JSON value is not an object")

                # Extract the model selection info from whichever key is present
                var_model = next(
                    (data[key] for key in var_model_keys if key in data),
                    {},
                )
                model_selection = var_model.get("model_selection", {})

                record = {
                    'file': file_path,
                    'selected_model': model_selection.get("selected", "unknown"),
                }
                for field in optional_fields:
                    record[field] = model_selection.get(field, None)
            except (OSError, ValueError, AttributeError) as e:
                # Best effort: unreadable, malformed or unexpectedly shaped
                # files are reported and skipped rather than aborting the scan
                print(f"Error loading {file_path}: {e}")
                continue

            results.append(record)

    return results
|
|
957
|
+
|
|
958
|
+
def create_confusion_matrix_plot(export: bool = False, show: bool = True) -> None:
    """Create a confusion matrix plot for model selection accuracy analysis.

    Regression results of the test5 synthetic datasets are read from the
    ``output`` directories below (searched recursively)::

        tests/data/test5/
        ├── linear/output/{timestamp}/
        │   ├── linear_01/
        │   │   └── linear_01_out_regression_results.json
        │   ├── linear_02/
        │   └── ...
        └── powerlaw/output/{timestamp}/
            ├── powerlaw_01/
            └── ...

    The heatmap shows the actual model on the x-axis and the selected
    (predicted) model on the y-axis. A dataset counts as a failure whenever
    the selected model is not the one it was generated with, which includes
    results whose selection is reported as "unknown".

    :param export: Whether to export the figure to share/confusion_matrix_heatmap.pdf
    :type export: bool
    :param show: Whether to display the figure
    :type show: bool
    """
    set_matplotlib_global_params()

    # Define paths - searches recursively through all subdirectories
    base_path = "tests/data/test5"
    linear_dir = os.path.join(base_path, "linear", "output")
    powerlaw_dir = os.path.join(base_path, "powerlaw", "output")

    print("Loading model selection results...")

    # Load results from both directories
    linear_results = load_model_selection_results(linear_dir)
    powerlaw_results = load_model_selection_results(powerlaw_dir)

    print(f"Loaded {len(linear_results)} linear results")
    print(f"Loaded {len(powerlaw_results)} powerlaw results")

    # Analyze results
    linear_success = sum(1 for r in linear_results if r['selected_model'] == 'linear')
    linear_failure = len(linear_results) - linear_success

    powerlaw_success = sum(1 for r in powerlaw_results if r['selected_model'] == 'power_law')
    powerlaw_failure = len(powerlaw_results) - powerlaw_success

    # Create confusion matrix.
    # imshow draws rows along the y-axis and columns along the x-axis, so
    # rows are the predicted model and columns the actual model:
    #   - a linear dataset that was not selected as linear belongs in
    #     (predicted power law, actual linear)
    #   - a power-law dataset that was not selected as power law belongs in
    #     (predicted linear, actual power law)
    confusion_matrix = np.array([
        [linear_success, powerlaw_failure],   # Predicted Linear: [actual linear, actual power law]
        [linear_failure, powerlaw_success],   # Predicted Power Law: [actual linear, actual power law]
    ])

    # Create the plot
    fig, ax = plt.subplots(figsize=(8, 6))

    # Create custom colormap from white to UK blue
    colors = ['white', '#76d6ff']
    n_bins = 30
    cmap = LinearSegmentedColormap.from_list('custom', colors, N=n_bins)

    # Create heatmap
    im = ax.imshow(confusion_matrix, interpolation='nearest', cmap=cmap)

    # Add colorbar
    cbar = ax.figure.colorbar(im, ax=ax)
    cbar.set_label('Count', rotation=270, labelpad=20)

    # Set ticks and labels
    ax.set_xticks([0, 1])
    ax.set_yticks([0, 1])
    ax.set_xticklabels(['Linear', 'Power Law'])  # Actual model (x-axis)
    ax.set_yticklabels(['Linear', 'Power Law'])  # Predicted model (y-axis)

    # Add text annotations; (j, i) because text takes (x, y) = (column, row)
    thresh = confusion_matrix.max() / 2.
    for i in range(confusion_matrix.shape[0]):
        for j in range(confusion_matrix.shape[1]):
            ax.text(j, i, format(confusion_matrix[i, j], 'd'),
                    ha="center", va="center",
                    color="white" if confusion_matrix[i, j] > thresh else "black",
                    fontsize=16, fontweight='bold')

    # Labels and title
    ax.set_xlabel('Actual Model', fontsize=16)
    ax.set_ylabel('Predicted Model', fontsize=16)
    ax.set_title('Model Selection Confusion Matrix', fontsize=18, fontweight='bold')

    # Calculate and display accuracy
    total_tests = len(linear_results) + len(powerlaw_results)
    total_successes = linear_success + powerlaw_success
    overall_accuracy = total_successes / total_tests if total_tests > 0 else 0

    # Add accuracy text
    ax.text(0.5, -0.25, f'Overall Accuracy: {overall_accuracy:.3f} ({total_successes}/{total_tests})',
            transform=ax.transAxes, ha='center', fontsize=14, fontweight='bold')

    plt.tight_layout()

    if export:
        fig.savefig(
            "share/confusion_matrix_heatmap.pdf",
            dpi=400,
            bbox_inches="tight",
        )
        print("Confusion matrix saved as share/confusion_matrix_heatmap.pdf")

    if show:
        plt.show()
|
|
1067
|
+
|
|
701
1068
|
#´:°•.°+.*•´.*:˚.°*.˚•´.°:°•.°•.*•´.*:˚.°*.˚•´.°:°•.°+.*•´.*:#
|
|
702
1069
|
# MAIN #
|
|
703
1070
|
#.•°:°.´+˚.*°.˚:*.´•*.+°.•°:´*.´•*.•°.•°:°.´:•˚°.*°.˚:*.´+°.•#
|
|
@@ -733,22 +1100,25 @@ def main(export: bool = False) -> None:
|
|
|
733
1100
|
# Main plot
|
|
734
1101
|
plot_main_figure(df, models, export=export)
|
|
735
1102
|
|
|
736
|
-
# Size plot
|
|
1103
|
+
# # Size plot
|
|
737
1104
|
size_plot(df, export=export)
|
|
738
1105
|
|
|
739
|
-
# Anomalous diffusion plot
|
|
1106
|
+
# # Anomalous diffusion plot
|
|
740
1107
|
anomalous_diffusion_plot(export=export)
|
|
741
1108
|
|
|
742
|
-
# Synthetic data plot
|
|
1109
|
+
# # Synthetic data plot
|
|
743
1110
|
synth_df = load_synthetic_data_df()
|
|
744
1111
|
synth_models = load_synthetic_data_models()
|
|
745
1112
|
synthetic_data_plot(synth_df, synth_models, export=export)
|
|
746
1113
|
|
|
747
|
-
# UK time windows plot
|
|
1114
|
+
# # UK time windows plot
|
|
748
1115
|
additional_uk_stats = load_additional_uk_stats()
|
|
749
1116
|
additional_uk_models = load_additional_uk_models()
|
|
750
1117
|
plot_uk_time_windows(additional_uk_stats, additional_uk_models, export=export)
|
|
751
1118
|
|
|
1119
|
+
# # Confusion matrix plot
|
|
1120
|
+
create_confusion_matrix_plot(export=export)
|
|
1121
|
+
|
|
752
1122
|
|
|
753
1123
|
if __name__ == "__main__":
|
|
754
1124
|
|