PyEvoMotion 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- PyEvoMotion/cli.py +88 -11
- PyEvoMotion/core/base.py +373 -34
- PyEvoMotion/core/core.py +136 -43
- PyEvoMotion/core/parser.py +4 -1
- {pyevomotion-0.1.0.dist-info → pyevomotion-0.1.2.dist-info}/METADATA +72 -4
- pyevomotion-0.1.2.dist-info/RECORD +35 -0
- share/analyze_model_selection_accuracy.py +316 -0
- share/analyze_test_runs.py +436 -0
- share/anomalous_diffusion.pdf +0 -0
- share/confusion_matrix_heatmap.pdf +0 -0
- share/figUK.tsv +9949 -0
- share/figUK_plots.pdf +0 -0
- share/figUK_regression_results.json +65 -0
- share/figUK_run_args.json +14 -0
- share/figUK_stats.tsv +41 -0
- share/figUSA.tsv +9470 -0
- share/figUSA_plots.pdf +0 -0
- share/figUSA_regression_results.json +65 -0
- share/figUSA_run_args.json +14 -0
- share/figUSA_stats.tsv +34 -0
- share/figdataUK.tsv +10001 -0
- share/figdataUSA.tsv +10001 -0
- share/generate_sequences_from_synthdata.py +85 -0
- share/generate_sequences_from_test5_data.py +107 -0
- share/manuscript_figure.py +858 -43
- share/run_parallel_analysis.py +196 -0
- share/synth_figure.pdf +0 -0
- share/uk_time_windows.pdf +0 -0
- share/weekly_size.pdf +0 -0
- pyevomotion-0.1.0.dist-info/RECORD +0 -13
- {pyevomotion-0.1.0.dist-info → pyevomotion-0.1.2.dist-info}/WHEEL +0 -0
- {pyevomotion-0.1.0.dist-info → pyevomotion-0.1.2.dist-info}/entry_points.txt +0 -0
share/manuscript_figure.py
CHANGED
|
@@ -4,12 +4,25 @@ import json
|
|
|
4
4
|
import zipfile
|
|
5
5
|
import warnings
|
|
6
6
|
import urllib.request
|
|
7
|
+
import subprocess
|
|
7
8
|
|
|
8
9
|
import numpy as np
|
|
9
10
|
import pandas as pd
|
|
10
11
|
import matplotlib as mpl
|
|
11
12
|
import matplotlib.pyplot as plt
|
|
13
|
+
from matplotlib.colors import LinearSegmentedColormap
|
|
12
14
|
|
|
15
|
+
#´:°•.°+.*•´.*:˚.°*.˚•´.°:°•.°•.*•´.*:˚.°*.˚•´.°:°•.°+.*•´.*:#
|
|
16
|
+
# CONSTANTS #
|
|
17
|
+
#.•°:°.´+˚.*°.˚:*.´•*.+°.•°:´*.´•*.•°.•°:°.´:•˚°.*°.˚:*.´+°.•#
|
|
18
|
+
|
|
19
|
+
# Hex colour used for each country's markers, lines and shaded bands.
COLORS = dict(
    UK="#76d6ff",
    USA="#FF6346",
)

# Master switch read by the plotting functions: confidence-interval bands are
# only drawn when this is True.
PLOT_CONFIDENCE_INTERVALS = False
|
|
13
26
|
|
|
14
27
|
#´:°•.°+.*•´.*:˚.°*.˚•´.°:°•.°•.*•´.*:˚.°*.˚•´.°:°•.°+.*•´.*:#
|
|
15
28
|
# FUNCTIONS #
|
|
@@ -160,36 +173,12 @@ def load_models() -> dict[str, dict[str, callable]]:
|
|
|
160
173
|
|
|
161
174
|
return {
|
|
162
175
|
"USA": {
|
|
163
|
-
"mean": [
|
|
164
|
-
|
|
165
|
-
_contents["USA"]["mean number of mutations per 7D model"]["parameters"]["m"]*x
|
|
166
|
-
+ _contents["USA"]["mean number of mutations per 7D model"]["parameters"]["b"]
|
|
167
|
-
),
|
|
168
|
-
_contents["USA"]["mean number of mutations per 7D model"]["r2"],
|
|
169
|
-
],
|
|
170
|
-
"var": [
|
|
171
|
-
lambda x: (
|
|
172
|
-
_contents["USA"]["scaled var number of mutations per 7D model"]["parameters"]["d"]
|
|
173
|
-
*(x**_contents["USA"]["scaled var number of mutations per 7D model"]["parameters"]["alpha"])
|
|
174
|
-
),
|
|
175
|
-
_contents["USA"]["scaled var number of mutations per 7D model"]["r2"],
|
|
176
|
-
]
|
|
176
|
+
"mean": list(_get_mean_model(_contents["USA"], "USA")),
|
|
177
|
+
"var": list(_get_var_model(_contents["USA"], "USA"))
|
|
177
178
|
},
|
|
178
179
|
"UK": {
|
|
179
|
-
"mean": [
|
|
180
|
-
|
|
181
|
-
_contents["UK"]["mean number of mutations per 7D model"]["parameters"]["m"]*x
|
|
182
|
-
+ _contents["UK"]["mean number of mutations per 7D model"]["parameters"]["b"]
|
|
183
|
-
),
|
|
184
|
-
_contents["UK"]["mean number of mutations per 7D model"]["r2"],
|
|
185
|
-
],
|
|
186
|
-
"var": [
|
|
187
|
-
lambda x: (
|
|
188
|
-
_contents["UK"]["scaled var number of mutations per 7D model"]["parameters"]["d"]
|
|
189
|
-
*(x**_contents["UK"]["scaled var number of mutations per 7D model"]["parameters"]["alpha"])
|
|
190
|
-
),
|
|
191
|
-
_contents["UK"]["scaled var number of mutations per 7D model"]["r2"],
|
|
192
|
-
]
|
|
180
|
+
"mean": list(_get_mean_model(_contents["UK"], "UK")),
|
|
181
|
+
"var": list(_get_var_model(_contents["UK"], "UK"))
|
|
193
182
|
},
|
|
194
183
|
}
|
|
195
184
|
|
|
@@ -202,15 +191,66 @@ def safe_map(f: callable, x: list[int | float]) -> list[int | float]:
|
|
|
202
191
|
_results.append(None)
|
|
203
192
|
return _results
|
|
204
193
|
|
|
205
|
-
|
|
194
|
+
|
|
195
|
+
def _calculate_confidence_bounds(x_values: np.ndarray, model_func: callable, confidence_intervals: dict, model_type: str) -> tuple[np.ndarray, np.ndarray]:
    """Calculate confidence interval bounds for a model.

    The band is the pointwise envelope of the model evaluated at every
    combination of the parameter interval end points. Pairing "all lower"
    with "all upper" parameters is not enough: for negative ``x`` (or
    ``x < 1`` in the power law) the curve built from the lower end points
    lies above the one built from the upper end points, which would make
    the band cross itself.

    :param x_values: X values to calculate bounds for
    :type x_values: np.ndarray
    :param model_func: The model function. Currently unused; kept so that
        existing callers keep working.
    :type model_func: callable
    :param confidence_intervals: Dictionary mapping a parameter name to its
        ``(lower, upper)`` pair
    :type confidence_intervals: dict
    :param model_type: Type of model ('linear_mean', 'linear_var', 'power_law')
    :type model_type: str
    :return: Tuple of (lower_bounds, upper_bounds), or ``(None, None)`` when
        the intervals are empty, a required parameter is missing, or the
        model type is not recognised
    :rtype: tuple[np.ndarray, np.ndarray]
    """
    if not confidence_intervals:
        # No confidence intervals available, return None bounds
        return None, None

    # Parameters whose intervals each model type needs.
    required_by_type = {
        "linear_mean": ("m", "b"),
        "linear_var": ("m",),
        "power_law": ("d", "alpha"),
    }
    required = required_by_type.get(model_type) if isinstance(model_type, str) else None
    if required is None or any(name not in confidence_intervals for name in required):
        return None, None

    x = np.asarray(x_values, dtype=float)

    if model_type == "linear_mean":
        # Linear mean model: m*x + b
        m_lower, m_upper = confidence_intervals["m"]
        b_lower, b_upper = confidence_intervals["b"]
        curves = [
            m * x + b
            for m in (m_lower, m_upper)
            for b in (b_lower, b_upper)
        ]
    elif model_type == "linear_var":
        # Linear variance model: m*x
        m_lower, m_upper = confidence_intervals["m"]
        curves = [m * x for m in (m_lower, m_upper)]
    else:
        # Power law model: d*x^alpha
        d_lower, d_upper = confidence_intervals["d"]
        alpha_lower, alpha_upper = confidence_intervals["alpha"]
        # Negative x with a fractional exponent (or 0 with a negative one)
        # has no finite real value; let it become nan/inf silently so the
        # caller's plot simply leaves a gap there.
        with np.errstate(invalid="ignore", divide="ignore"):
            curves = [
                d * (x ** alpha)
                for d in (d_lower, d_upper)
                for alpha in (alpha_lower, alpha_upper)
            ]

    lower_bounds = np.minimum.reduce(curves)
    upper_bounds = np.maximum.reduce(curves)
    return lower_bounds, upper_bounds
|
|
249
|
+
|
|
250
|
+
def plot_main_figure(df: pd.DataFrame, models: dict[str, any], export: bool = False, show: bool = True) -> None:
|
|
206
251
|
set_matplotlib_global_params()
|
|
207
252
|
fig, ax = plt.subplots(2, 1, figsize=(6, 10))
|
|
208
253
|
|
|
209
|
-
colors = {
|
|
210
|
-
"UK": "#76d6ff",
|
|
211
|
-
"USA": "#FF6346",
|
|
212
|
-
}
|
|
213
|
-
|
|
214
254
|
for idx, case in enumerate(("mean", "var")):
|
|
215
255
|
for col in (f"{case} number of mutations USA", f"{case} number of mutations UK"):
|
|
216
256
|
|
|
@@ -219,20 +259,54 @@ def plot(df: pd.DataFrame, models: dict[str, any], export: bool = False, show: b
|
|
|
219
259
|
ax[idx].scatter(
|
|
220
260
|
df.index,
|
|
221
261
|
df[col] - (df[col].min() if idx == 1 else 0),
|
|
222
|
-
color=
|
|
262
|
+
color=COLORS[_country],
|
|
223
263
|
edgecolor="k",
|
|
224
264
|
zorder=2,
|
|
225
265
|
)
|
|
226
266
|
|
|
227
|
-
_x = np.arange(-10,
|
|
267
|
+
_x = np.arange(-10, 60, 0.5)
|
|
268
|
+
_x_shifted = _x + (8 if _country == "USA" else 0)
|
|
269
|
+
|
|
270
|
+
# Plot the main model line
|
|
228
271
|
ax[idx].plot(
|
|
229
|
-
|
|
272
|
+
_x_shifted,
|
|
230
273
|
safe_map(models[_country][case][0], _x),
|
|
231
|
-
color=
|
|
274
|
+
color=COLORS[_country],
|
|
232
275
|
label=rf"{_country} ($R^2 = {round(models[_country][case][1], 2):.2f})$",
|
|
233
276
|
linewidth=3,
|
|
234
277
|
zorder=1,
|
|
235
278
|
)
|
|
279
|
+
|
|
280
|
+
# Plot confidence intervals if available and enabled
|
|
281
|
+
if PLOT_CONFIDENCE_INTERVALS and len(models[_country][case]) > 2 and models[_country][case][2]:
|
|
282
|
+
confidence_intervals = models[_country][case][2]
|
|
283
|
+
|
|
284
|
+
# Determine model type for confidence interval calculation
|
|
285
|
+
if case == "mean":
|
|
286
|
+
model_type = "linear_mean"
|
|
287
|
+
else: # case == "var"
|
|
288
|
+
# Check if it's linear or power law based on the model function
|
|
289
|
+
# We'll determine this by checking if the model has 'alpha' parameter
|
|
290
|
+
if "alpha" in confidence_intervals:
|
|
291
|
+
model_type = "power_law"
|
|
292
|
+
else:
|
|
293
|
+
model_type = "linear_var"
|
|
294
|
+
|
|
295
|
+
lower_bounds, upper_bounds = _calculate_confidence_bounds(
|
|
296
|
+
_x, models[_country][case][0], confidence_intervals, model_type
|
|
297
|
+
)
|
|
298
|
+
|
|
299
|
+
if lower_bounds is not None and upper_bounds is not None:
|
|
300
|
+
# Plot confidence interval as filled area
|
|
301
|
+
# The x-axis shift is already applied to _x_shifted, so we use the original bounds
|
|
302
|
+
ax[idx].fill_between(
|
|
303
|
+
_x_shifted,
|
|
304
|
+
lower_bounds,
|
|
305
|
+
upper_bounds,
|
|
306
|
+
color=COLORS[_country],
|
|
307
|
+
alpha=0.2,
|
|
308
|
+
zorder=0,
|
|
309
|
+
)
|
|
236
310
|
|
|
237
311
|
# Styling
|
|
238
312
|
ax[idx].set_xlim(-0.5, 40.5)
|
|
@@ -260,13 +334,736 @@ def plot(df: pd.DataFrame, models: dict[str, any], export: bool = False, show: b
|
|
|
260
334
|
|
|
261
335
|
if export:
|
|
262
336
|
fig.savefig(
|
|
263
|
-
"share/figure.
|
|
337
|
+
"share/figure.eps",
|
|
338
|
+
dpi=400,
|
|
339
|
+
bbox_inches="tight",
|
|
340
|
+
)
|
|
341
|
+
print("Figure saved as share/figure.eps")
|
|
342
|
+
|
|
343
|
+
if show: plt.show()
|
|
344
|
+
|
|
345
|
+
def size_plot(df: pd.DataFrame, export: bool = False, show: bool = True) -> None:
    """Draw the ``size UK`` and ``size USA`` columns of *df* as stem plots.

    :param df: Table whose index is used as the x axis and which holds the
        ``size UK`` and ``size USA`` columns
    :type df: pd.DataFrame
    :param export: When True, save the figure to ``share/weekly_size.pdf``
    :type export: bool
    :param show: When True, display the figure with ``plt.show()``
    :type show: bool
    """
    set_matplotlib_global_params()
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))

    # Layers in drawing order: (country, legend kwargs, extra stem-line
    # properties, marker edge colour). The UK is drawn a second time without
    # a legend entry and with half-transparent stems, after the USA layer.
    layers = (
        ("UK", {"label": "UK"}, {}, "k"),
        ("USA", {"label": "USA"}, {}, "k"),
        ("UK", {}, {"alpha": 0.5}, "#000000"),
    )
    for country, legend_kwargs, stem_kwargs, edge_colour in layers:
        marker, stems, base = ax.stem(df.index, df[f"size {country}"], **legend_kwargs)
        plt.setp(stems, color=COLORS[country], **stem_kwargs)
        plt.setp(marker, color=COLORS[country], markeredgecolor=edge_colour)
        # White baseline so it does not show against the background.
        plt.setp(base, color="#ffffff")

    ax.set_ylim(0, 405)
    ax.set_xlim(-0.5, 40.5)

    ax.set_xlabel("time (wk)")
    ax.set_ylabel("Number of sequences")

    ax.legend(
        fontsize=16,
        loc="upper right",
        bbox_to_anchor=(1.08, 1.08),
    )

    if export:
        fig.savefig(
            "share/weekly_size.pdf",
            dpi=400,
            bbox_inches="tight",
        )
        print("Figure saved as share/weekly_size.pdf")

    if show:
        plt.show()
|
|
388
|
+
|
|
389
|
+
def anomalous_diffusion_plot(export: bool = False, show: bool = True) -> None:
    """Draw three reference curves ``x**alpha`` for alpha = 0.8, 1 and 1.2.

    The axes carry no ticks; the figure only illustrates how the curve bends
    for exponents below, at and above 1.

    :param export: When True, save the figure to
        ``share/anomalous_diffusion.pdf``
    :type export: bool
    :param show: When True, display the figure with ``plt.show()``
    :type show: bool
    """
    set_matplotlib_global_params()
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))

    x = np.linspace(0, 10, 100)

    # (exponent, legend entry, line colour), in drawing order.
    regimes = (
        (0.8, r"$\alpha = 0.8$" + "\n(subdiffusion)", COLORS["UK"]),
        (1, r"$\alpha = 1$" + "\n(normal diffusion)", "#000000"),
        (1.2, r"$\alpha = 1.2$" + "\n(superdiffusion)", COLORS["USA"]),
    )
    for exponent, legend_entry, colour in regimes:
        plt.plot(x, x**exponent, label=legend_entry, color=colour, linewidth=3)

    plt.legend(
        fontsize=13,
        loc="upper left",
        title=r"variance $\propto \text{time}^\alpha$",
        title_fontsize=15,
    )

    plt.xlabel("time")
    plt.ylabel("variance")

    # Hide the tick marks on both axes.
    ax.set_xticks([])
    ax.set_yticks([])

    plt.xlim(0, 10)
    plt.ylim(0, 10)

    if export:
        fig.savefig(
            "share/anomalous_diffusion.pdf",
            dpi=400,
            bbox_inches="tight",
        )
        print("Figure saved as share/anomalous_diffusion.pdf")

    if show:
        plt.show()
|
|
424
|
+
|
|
425
|
+
def check_synthetic_data_exists() -> bool:
    """
    Tell whether all four synthetic data output files are present.

    Paths are relative to the current working directory.
    """
    required_outputs = (
        "tests/data/test4/synthdata1_out_stats.tsv",
        "tests/data/test4/synthdata2_out_stats.tsv",
        "tests/data/test4/synthdata1_out_regression_results.json",
        "tests/data/test4/synthdata2_out_regression_results.json",
    )
    # A single missing file is enough to answer False.
    return all(os.path.exists(path) for path in required_outputs)
|
|
441
|
+
|
|
442
|
+
def run_synthetic_data_tests() -> None:
    """
    Run the synthetic data tests to generate the required files.

    Invokes the ``PyEvoMotion`` command once for the S1 dataset and once for
    the S2 dataset, with outputs prefixed ``tests/data/test4/synthdata1_out``
    and ``tests/data/test4/synthdata2_out`` respectively.

    :raises RuntimeError: If a run exits with a non-zero status or writes
        anything to stderr. The captured stdout and stderr are printed first.
    """
    print("Running synthetic data tests to generate required files...")

    # Create output directory
    os.makedirs("tests/data/test4", exist_ok=True)

    # (input dataset name, number used in the output prefix)
    for dataset, out_number in (("S1", 1), ("S2", 2)):
        result = subprocess.run(
            [
                "PyEvoMotion",
                f"tests/data/test4/{dataset}.fasta",
                f"tests/data/test4/{dataset}.tsv",
                f"tests/data/test4/synthdata{out_number}_out",
                "-ep",
            ],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )

        # stderr output stays fatal, as before. The exit status is checked as
        # well, so a run that fails without printing to stderr is not
        # mistaken for a success.
        if result.returncode != 0 or result.stderr:
            print(result.stdout)
            print(result.stderr)
            raise RuntimeError(f"Failed to process {dataset} dataset")
|
|
488
|
+
|
|
489
|
+
def load_synthetic_data_df() -> pd.DataFrame:
    """Load both synthetic stats tables and join them on ``date``.

    The output files are generated first if any of them is missing. Columns
    present in both tables get the suffixes ``" synt1"`` and ``" synt2"``.

    :return: Outer join of the two stats tables
    :rtype: pd.DataFrame
    """
    if not check_synthetic_data_exists():
        run_synthetic_data_tests()

    first = pd.read_csv("tests/data/test4/synthdata1_out_stats.tsv", sep="\t")
    second = pd.read_csv("tests/data/test4/synthdata2_out_stats.tsv", sep="\t")

    return first.merge(
        second,
        on="date",
        how="outer",
        suffixes=(" synt1", " synt2"),
    )
|
|
505
|
+
|
|
506
|
+
def _get_mean_model(data: dict, kind: str) -> tuple[callable, float, dict]:
    """Build the linear mean model stored in a regression results dictionary.

    Several key spellings are accepted; the first one present in *data* wins.

    :param data: The regression results data dictionary
    :type data: dict
    :param kind: The dataset kind identifier (only used in the error message)
    :type kind: str
    :return: Tuple of (function computing ``m*x + b``, r2 value, confidence
        intervals or ``{}`` when none are stored)
    :rtype: tuple[callable, float, dict]
    :raises KeyError: If none of the accepted keys is present in *data*
    """
    # Accepted spellings, in order of preference.
    candidate_keys = (
        "mean number of mutations model",  # New format (current)
        "mean number of mutations per 7D model",  # Old format
        "mean number of substitutions model",  # Alternative format
        "mean number of substitutions per 7D model",  # Alternative old format
    )

    found = next((key for key in candidate_keys if key in data), None)
    if found is None:
        raise KeyError(f"Could not find mean model in {kind} data. Available keys: {list(data.keys())}")

    entry = data[found]
    params = entry["parameters"]
    r2 = entry["r2"]
    confidence_intervals = entry.get("confidence_intervals", {})

    def mean_model(x):
        # Parameters are looked up at call time, as in a closure over params.
        return params["m"] * x + params["b"]

    return mean_model, r2, confidence_intervals
|
|
532
|
+
|
|
533
|
+
|
|
534
|
+
def _get_var_model(data: dict, kind: str) -> tuple[callable, float, dict]:
    """Build the variance model stored in a regression results dictionary.

    Two layouts are understood for an entry:

    * with a ``model_selection`` section: the model named by its
      ``selected`` field (``"linear"`` or ``"power_law"``) is read from the
      ``linear_model`` / ``power_law_model`` sub-entry;
    * without it: the entry's own ``parameters`` decide, ``m`` meaning the
      linear model ``m*x`` and ``d`` plus ``alpha`` the power law
      ``d*x**alpha``.

    An entry that matches neither case is passed over and the next accepted
    key is tried.

    :param data: The regression results data dictionary
    :type data: dict
    :param kind: The dataset kind identifier (only used in the error message)
    :type kind: str
    :return: Tuple of (model function, r2 value, confidence intervals or
        ``{}`` when none are stored)
    :rtype: tuple[callable, float, dict]
    :raises KeyError: If no usable variance model is found in *data*
    """
    def linear(params):
        return lambda x: params["m"] * x

    def power_law(params):
        return lambda x: params["d"] * (x ** params["alpha"])

    # Accepted spellings, in order of preference.
    candidate_keys = (
        "scaled var number of mutations model",  # New format (current)
        "scaled var number of mutations per 7D model",  # Old format
        "scaled var number of substitutions model",  # Alternative format
        "scaled var number of substitutions per 7D model",  # Alternative old format
    )
    # (value of "selected", sub-entry holding that fit, model builder)
    selectable = (
        ("linear", "linear_model", linear),
        ("power_law", "power_law_model", power_law),
    )

    for key in candidate_keys:
        if key not in data:
            continue
        entry = data[key]

        if "model_selection" not in entry:
            # Old format or new format without model selection - direct parameters
            params = entry["parameters"]
            r2 = entry["r2"]
            confidence_intervals = entry.get("confidence_intervals", {})
            if "m" in params:
                return linear(params), r2, confidence_intervals
            if "d" in params and "alpha" in params:
                return power_law(params), r2, confidence_intervals
            continue

        # New format with model selection
        selected = entry["model_selection"]["selected"]
        for name, sub_key, build in selectable:
            if selected == name and sub_key in entry:
                fit = entry[sub_key]
                params = fit["parameters"]
                r2 = fit["r2"]
                confidence_intervals = fit.get("confidence_intervals", {})
                return build(params), r2, confidence_intervals

    raise KeyError(f"Could not find variance model in {kind} data. Available keys: {list(data.keys())}")
|
|
587
|
+
|
|
588
|
+
|
|
589
|
+
def load_synthetic_data_models() -> dict[str, dict[str, callable]]:
    """Load the fitted mean and variance models of both synthetic datasets.

    The output files are generated first if any of them is missing.

    :return: Mapping ``"synt1"`` / ``"synt2"`` to a dict whose ``"mean"`` and
        ``"var"`` entries are the lists ``[model function, r2, confidence
        intervals]``
    :rtype: dict[str, dict[str, callable]]
    """
    if not check_synthetic_data_exists():
        run_synthetic_data_tests()

    path_template = "tests/data/test4/synthdata{}_out_regression_results.json"

    # Read every file before building any model.
    loaded = {}
    for kind in ("synt1", "synt2"):
        # The number in the file name is the last character of the kind.
        with open(path_template.format(kind[-1])) as handle:
            loaded[kind] = json.load(handle)

    return {
        kind: {
            "mean": list(_get_mean_model(content, kind)),
            "var": list(_get_var_model(content, kind)),
        }
        for kind, content in loaded.items()
    }
|
|
612
|
+
|
|
613
|
+
def synthetic_data_plot(df: pd.DataFrame, models: dict[str, any], export: bool = False, show: bool = True) -> None:
    """Draw the 2x2 figure of the synthetic datasets with their fitted models.

    The top row shows the mean and the bottom row the variance; the left
    column is ``synt1`` and the right column ``synt2``.

    :param df: Table whose index is the x axis and which holds the columns
        ``"<case> number of mutations <kind>"`` for case in mean/var and kind
        in synt1/synt2
    :type df: pd.DataFrame
    :param models: ``models[kind][case]`` is a sequence of (model function,
        r2[, confidence intervals])
    :type models: dict[str, any]
    :param export: When True, save the figure to ``share/synth_figure.pdf``
    :type export: bool
    :param show: When True, display the figure with ``plt.show()``
    :type show: bool
    """
    set_matplotlib_global_params()
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))

    # Panels in row-major order, matching the flattened axes array.
    panels = [
        (case, kind)
        for case in ("mean", "var")
        for kind in ("synt1", "synt2")
    ]

    for panel, (case, kind) in zip(axes.flatten(), panels):
        # Scatter plot
        panel.scatter(
            df.index,
            df[f"{case} number of mutations {kind}"],
            color="#76d6ff",
            edgecolor="k",
            zorder=2,
        )

        # Line plot
        grid = np.arange(-10, 50, 0.5)
        fit = models[kind][case]
        panel.plot(
            grid,
            safe_map(fit[0], grid),
            color="#76d6ff",
            label=rf"$R^2 = {round(fit[1], 2):.2f}$",
            linewidth=3,
            zorder=1,
        )

        # Plot confidence intervals if available and enabled
        if PLOT_CONFIDENCE_INTERVALS and len(fit) > 2 and fit[2]:
            confidence_intervals = fit[2]

            # The mean is always linear; for the variance an "alpha" interval
            # marks the power law.
            if case == "mean":
                model_type = "linear_mean"
            elif "alpha" in confidence_intervals:
                model_type = "power_law"
            else:
                model_type = "linear_var"

            lower_bounds, upper_bounds = _calculate_confidence_bounds(
                grid, fit[0], confidence_intervals, model_type
            )

            if lower_bounds is not None and upper_bounds is not None:
                # Plot confidence interval as filled area
                panel.fill_between(
                    grid,
                    lower_bounds,
                    upper_bounds,
                    color="#76d6ff",
                    alpha=0.2,
                    zorder=0,
                )

        # Styling
        panel.set_xlim(-0.5, 40.5)
        if case == "mean":
            panel.set_ylim(-0.25, 20.25)
            panel.set_ylabel(f"{case} (# mutations)")
        else:
            if kind == "synt1":
                panel.set_ylim(-0.5, 40.5)
            else:
                panel.set_ylim(-0.1, 10.1)
            panel.set_ylabel(f"{case}iance (# mutations)")

        panel.set_xlabel("time (wk)")
        panel.legend(
            fontsize=16,
            loc="upper left",
        )

    fig.suptitle(" ", fontsize=1)  # To get some space on top
    fig.tight_layout()

    # Panel letters, positioned in figure-fraction coordinates.
    for letter, position in (
        ("a", (0.02, 0.935)),
        ("b", (0.505, 0.935)),
        ("c", (0.02, 0.465)),
        ("d", (0.505, 0.465)),
    ):
        plt.annotate(letter, position, xycoords="figure fraction", fontsize=28, fontweight="bold")

    if export:
        fig.savefig(
            "share/synth_figure.pdf",
            dpi=400,
            bbox_inches="tight",
        )
        print("Figure saved as share/synth_figure.pdf")

    if show:
        plt.show()
|
|
714
|
+
|
|
715
|
+
|
|
716
|
+
def load_additional_uk_stats() -> dict[str, pd.DataFrame]:
    """
    Read the UK stats table of each time window (5D, 10D, 14D and 7D).

    Returns a dictionary mapping the window name to its tab-separated table.
    Paths are relative to the current working directory.
    """
    sources = {
        "5D": "tests/data/test3/output/20250517164757/UKout_5D_stats.tsv",
        "10D": "tests/data/test3/output/20250517173133/UKout_10D_stats.tsv",
        "14D": "tests/data/test3/output/20250517181004/UKout_14D_stats.tsv",
        "7D": "share/figUK_stats.tsv",
    }

    tables = {}
    for window, path in sources.items():
        tables[window] = pd.read_csv(path, sep="\t")
    return tables
|
|
731
|
+
|
|
732
|
+
def load_additional_uk_models() -> dict[str, dict[str, callable]]:
    """
    Read the UK regression results of each time window (5D, 10D, 14D and 7D).

    For every window the result holds two lists:

    * ``"mean"``: ``[{"m", "b"} parameters, r2, confidence intervals]`` of the
      ``mean number of mutations model`` entry;
    * ``"var"``: ``[{"d", "alpha"} parameters, r2, confidence intervals]`` of
      the ``power_law_model`` inside the ``scaled var number of mutations
      model`` entry.

    The first element of each list is the parameter dictionary, not a
    function. Paths are relative to the current working directory.
    """
    sources = {
        "5D": "tests/data/test3/output/20250517164757/UKout_5D_regression_results.json",
        "10D": "tests/data/test3/output/20250517173133/UKout_10D_regression_results.json",
        "14D": "tests/data/test3/output/20250517181004/UKout_14D_regression_results.json",
        "7D": "share/figUK_regression_results.json",
    }

    # Read every file before extracting anything from them.
    loaded = {}
    for window, path in sources.items():
        with open(path) as handle:
            loaded[window] = json.load(handle)

    models = {}
    for window in sources:
        content = loaded[window]

        mean_fit = content["mean number of mutations model"]
        mean_entry = [
            {
                "m": mean_fit["parameters"]["m"],
                "b": mean_fit["parameters"]["b"],
            },
            mean_fit["r2"],
            mean_fit["confidence_intervals"],
        ]

        var_fit = content["scaled var number of mutations model"]["power_law_model"]
        var_entry = [
            {
                "d": var_fit["parameters"]["d"],
                "alpha": var_fit["parameters"]["alpha"],
            },
            var_fit["r2"],
            var_fit["confidence_intervals"],
        ]

        models[window] = {"mean": mean_entry, "var": var_entry}

    return models
|
|
768
|
+
|
|
769
|
+
def plot_uk_time_windows(stats: dict[str, pd.DataFrame], models: dict[str, dict[str, callable]], export: bool = False, show: bool = True) -> None:
    """
    Plot a 2x4 grid of UK data with different time windows.

    Each column is one time window (5D, 7D, 10D, 14D); the top row shows the
    mean number of mutations with its linear fit, the bottom row the variance
    (shifted so that its minimum is 0) with its power-law fit.

    Args:
        stats: Dictionary of dataframes containing the stats for each time
            window; each needs the columns "dt_idx", "mean number of
            mutations" and "var number of mutations"
        models: Dictionary of models for each time window; models[window][case]
            is [parameter dict, r2, confidence intervals]
        export: Whether to export the figure to share/uk_time_windows.pdf
        show: Whether to show the figure
    """
    set_matplotlib_global_params()
    fig, ax = plt.subplots(2, 4, figsize=(24, 12))

    # Order of time windows to plot
    windows = ["5D", "7D", "10D", "14D"]

    # For the models to be comparable to the 7D model, we need to scale the
    # x-axis by the square of the time window ratio
    scaling = {
        "5D": (5/7)**2,
        "7D": 1,
        "10D": (10/7)**2,
        "14D": (14/7)**2,
    }

    # x grid shared by every panel; defined once so that the variance panels
    # do not depend on the mean panel having been drawn first.
    _x = np.arange(-0.5, 51, 0.5)

    for idx, window in enumerate(windows):
        df = stats[window]
        model = models[window]

        for idx2, case in enumerate(("mean", "var")):
            panel = ax[idx2, idx]
            params = model[case][0]

            if case == "mean":
                observed = df["mean number of mutations"]
                # The mean model is evaluated on the unscaled grid.
                model_x = _x
                fitted = params["m"]*model_x + params["b"]
                label_name = "Mean"
                # For mean, it's always linear model
                model_type = "linear_mean"
                y_limits = (29.5, 45.5)
            else:
                observed = df["var number of mutations"] - df["var number of mutations"].min()
                # Only the variance model is evaluated on the scaled grid.
                model_x = _x * scaling[window]
                fitted = params["d"]*model_x**params["alpha"]
                label_name = "Var"
                # For variance, it's always power law model
                model_type = "power_law"
                y_limits = (-0.5, 10.5)

            panel.scatter(
                df["dt_idx"],
                observed,
                color=COLORS["UK"],
                edgecolor="k",
                zorder=2,
            )

            panel.plot(
                _x,
                fitted,
                color=COLORS["UK"],
                label=rf"{label_name} ($R^2 = {round(model[case][1], 2):.2f})$",
                linewidth=3,
                zorder=1,
            )

            # Plot confidence intervals if available and enabled
            if PLOT_CONFIDENCE_INTERVALS and len(model[case]) > 2 and model[case][2]:
                confidence_intervals = model[case][2]

                # Bounds are computed on the same x values the model used,
                # but drawn against the unscaled grid like the model line.
                lower_bounds, upper_bounds = _calculate_confidence_bounds(
                    model_x, params, confidence_intervals, model_type
                )

                if lower_bounds is not None and upper_bounds is not None:
                    # Plot confidence interval as filled area
                    panel.fill_between(
                        _x,
                        lower_bounds,
                        upper_bounds,
                        color=COLORS["UK"],
                        alpha=0.2,
                        zorder=0,
                    )

            # Styling
            panel.set_xlim(-0.5, 40.5)
            panel.set_ylim(*y_limits)

            panel.set_xlabel("time (wk)")
            # Only the leftmost column carries a y label.
            if idx == 0:
                panel.set_ylabel(f"{case} (# mutations)")

            panel.legend(
                fontsize=12,
                loc="upper left",
            )

    if export:
        fig.savefig(
            "share/uk_time_windows.pdf",
            dpi=400,
            bbox_inches="tight",
        )

    if show:
        plt.show()
|
|
908
|
+
|
|
909
|
+
def load_model_selection_results(directory: str) -> list[dict]:
    """Load all regression results from a directory for model selection analysis.

    The directory tree is walked recursively, so both flat and nested layouts
    are supported. Every file whose name ends with
    ``out_regression_results.json`` is parsed and summarised.

    Expected structure:
    directory/
    ├── {timestamp}/
    │   ├── {dataset_01}/
    │   │   └── *_out_regression_results.json
    │   ├── {dataset_02}/
    │   │   └── *_out_regression_results.json
    │   └── ...
    └── (or any nested structure)

    Files that cannot be read, are not valid JSON, or do not have the
    expected nested-object structure are reported on stdout and skipped, so
    one broken result never aborts the whole scan. A missing
    ``model_selection`` section is not an error: the entry is returned with
    ``selected_model == "unknown"`` and ``None`` for every metric.

    :param directory: Root directory to search for regression results
    :type directory: str
    :return: One dictionary per loaded file with the keys ``file``,
        ``selected_model``, ``linear_AIC``, ``power_law_AIC``,
        ``delta_AIC_linear``, ``delta_AIC_power_law``,
        ``akaike_weight_linear`` and ``akaike_weight_power_law``, ordered by
        sorted directory and file name
    :rtype: list[dict]
    """
    metric_keys = (
        "linear_AIC",
        "power_law_AIC",
        "delta_AIC_linear",
        "delta_AIC_power_law",
        "akaike_weight_linear",
        "akaike_weight_power_law",
    )
    results = []

    for root, dirs, files in os.walk(directory):
        # Sorting ``dirs`` in place steers the top-down walk; together with
        # the sorted file names this makes the output order reproducible
        # regardless of the underlying file system.
        dirs.sort()
        for file in sorted(files):
            if not file.endswith("out_regression_results.json"):
                continue

            file_path = os.path.join(root, file)

            # Only I/O and parsing problems are expected here; ValueError
            # covers both json.JSONDecodeError and UnicodeDecodeError.
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    data = json.load(f)
            except (OSError, ValueError) as e:
                print(f"Error loading {file_path}: {e}")
                continue

            # Drill down defensively: any level that is not a JSON object
            # means the file does not follow the expected schema.
            section = (
                data.get("scaled var number of substitutions model", {})
                if isinstance(data, dict) else None
            )
            model_selection = (
                section.get("model_selection", {})
                if isinstance(section, dict) else None
            )
            if not isinstance(model_selection, dict):
                print(f"Error loading {file_path}: unexpected JSON structure")
                continue

            record = {
                'file': file_path,
                'selected_model': model_selection.get("selected", "unknown"),
            }
            record.update({key: model_selection.get(key) for key in metric_keys})
            results.append(record)

    return results
|
|
957
|
+
|
|
958
|
+
def create_confusion_matrix_plot(export: bool = False, show: bool = True) -> None:
    """Create a confusion matrix plot for model selection accuracy analysis.

    Loads the regression results found (recursively) under
    ``tests/data/test5/linear/output`` and ``tests/data/test5/powerlaw/output``,
    counts how often the selected model matches the directory the result came
    from, and draws the counts as a 2x2 heat map with the actual model on the
    x-axis and the predicted model on the y-axis.

    Expected directory structure:

    tests/data/test5/
    ├── linear/output/{timestamp}/
    │   ├── linear_01/
    │   │   └── linear_01_out_regression_results.json
    │   ├── linear_02/
    │   └── ...
    └── powerlaw/output/{timestamp}/
        ├── powerlaw_01/
        └── ...

    Any selection other than the expected one (including ``"unknown"``) is
    counted as a misclassification of that data set.

    :param export: Whether to export the figure to
        share/confusion_matrix_heatmap.pdf
    :type export: bool
    :param show: Whether to display the figure
    :type show: bool
    """
    set_matplotlib_global_params()

    # Define paths - searched recursively through all subdirectories
    base_path = "tests/data/test5"
    linear_dir = os.path.join(base_path, "linear", "output")
    powerlaw_dir = os.path.join(base_path, "powerlaw", "output")

    print("Loading model selection results...")

    # Load results from both directories
    linear_results = load_model_selection_results(linear_dir)
    powerlaw_results = load_model_selection_results(powerlaw_dir)

    print(f"Loaded {len(linear_results)} linear results")
    print(f"Loaded {len(powerlaw_results)} powerlaw results")

    # Count correct and incorrect selections per true model
    linear_success = sum(1 for r in linear_results if r['selected_model'] == 'linear')
    linear_failure = len(linear_results) - linear_success

    powerlaw_success = sum(1 for r in powerlaw_results if r['selected_model'] == 'power_law')
    powerlaw_failure = len(powerlaw_results) - powerlaw_success

    # imshow draws array rows along the y-axis and columns along the x-axis.
    # The axes are labelled y = predicted model, x = actual model, so rows
    # must be indexed by the prediction and columns by the true model:
    # a linear data set that was NOT selected as linear belongs in the
    # "predicted Power Law / actual Linear" cell, and vice versa.
    confusion_matrix = np.array([
        [linear_success, powerlaw_failure],  # Predicted Linear: [actual Linear, actual Power Law]
        [linear_failure, powerlaw_success],  # Predicted Power Law: [actual Linear, actual Power Law]
    ])

    # Create the plot
    fig, ax = plt.subplots(figsize=(8, 6))

    # Custom colormap from white to UK blue
    colors = ['white', '#76d6ff']
    n_bins = 30
    cmap = LinearSegmentedColormap.from_list('custom', colors, N=n_bins)

    # Create heatmap
    im = ax.imshow(confusion_matrix, interpolation='nearest', cmap=cmap)

    # Add colorbar
    cbar = ax.figure.colorbar(im, ax=ax)
    cbar.set_label('Count', rotation=270, labelpad=20)

    # Set ticks and labels
    ax.set_xticks([0, 1])
    ax.set_yticks([0, 1])
    ax.set_xticklabels(['Linear', 'Power Law'])  # Actual model (x-axis)
    ax.set_yticklabels(['Linear', 'Power Law'])  # Predicted model (y-axis)

    # Annotate every cell with its count; the text colour switches at half
    # of the largest count.
    thresh = confusion_matrix.max() / 2.
    for i in range(confusion_matrix.shape[0]):
        for j in range(confusion_matrix.shape[1]):
            ax.text(j, i, format(confusion_matrix[i, j], 'd'),
                    ha="center", va="center",
                    color="white" if confusion_matrix[i, j] > thresh else "black",
                    fontsize=16, fontweight='bold')

    # Labels and title
    ax.set_xlabel('Actual Model', fontsize=16)
    ax.set_ylabel('Predicted Model', fontsize=16)
    ax.set_title('Model Selection Confusion Matrix', fontsize=18, fontweight='bold')

    # Calculate overall accuracy (guarding against an empty result set)
    total_tests = len(linear_results) + len(powerlaw_results)
    total_successes = linear_success + powerlaw_success
    overall_accuracy = total_successes / total_tests if total_tests > 0 else 0

    # Add accuracy text below the axes
    ax.text(0.5, -0.25, f'Overall Accuracy: {overall_accuracy:.3f} ({total_successes}/{total_tests})',
            transform=ax.transAxes, ha='center', fontsize=14, fontweight='bold')

    plt.tight_layout()

    if export:
        fig.savefig(
            "share/confusion_matrix_heatmap.pdf",
            dpi=400,
            bbox_inches="tight",
        )
        print("Confusion matrix saved as share/confusion_matrix_heatmap.pdf")

    if show:
        plt.show()
|
|
270
1067
|
|
|
271
1068
|
#´:°•.°+.*•´.*:˚.°*.˚•´.°:°•.°•.*•´.*:˚.°*.˚•´.°:°•.°+.*•´.*:#
|
|
272
1069
|
# MAIN #
|
|
@@ -290,7 +1087,6 @@ def main(export: bool = False) -> None:
|
|
|
290
1087
|
f"share/figdata{country}.tsv",
|
|
291
1088
|
f"share/fig{country}",
|
|
292
1089
|
"-k", "total",
|
|
293
|
-
"-n", "5",
|
|
294
1090
|
"-dt", "7D",
|
|
295
1091
|
"-dr", "2020-10-01..2021-08-01",
|
|
296
1092
|
"-ep",
|
|
@@ -301,8 +1097,27 @@ def main(export: bool = False) -> None:
|
|
|
301
1097
|
df = load_final_data_df()
|
|
302
1098
|
models = load_models()
|
|
303
1099
|
|
|
304
|
-
#
|
|
305
|
-
|
|
1100
|
+
# Main plot
|
|
1101
|
+
plot_main_figure(df, models, export=export)
|
|
1102
|
+
|
|
1103
|
+
# # Size plot
|
|
1104
|
+
size_plot(df, export=export)
|
|
1105
|
+
|
|
1106
|
+
# # Anomalous diffusion plot
|
|
1107
|
+
anomalous_diffusion_plot(export=export)
|
|
1108
|
+
|
|
1109
|
+
# # Synthetic data plot
|
|
1110
|
+
synth_df = load_synthetic_data_df()
|
|
1111
|
+
synth_models = load_synthetic_data_models()
|
|
1112
|
+
synthetic_data_plot(synth_df, synth_models, export=export)
|
|
1113
|
+
|
|
1114
|
+
# # UK time windows plot
|
|
1115
|
+
additional_uk_stats = load_additional_uk_stats()
|
|
1116
|
+
additional_uk_models = load_additional_uk_models()
|
|
1117
|
+
plot_uk_time_windows(additional_uk_stats, additional_uk_models, export=export)
|
|
1118
|
+
|
|
1119
|
+
# # Confusion matrix plot
|
|
1120
|
+
create_confusion_matrix_plot(export=export)
|
|
306
1121
|
|
|
307
1122
|
|
|
308
1123
|
if __name__ == "__main__":
|