PyEvoMotion 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,12 +4,25 @@ import json
4
4
  import zipfile
5
5
  import warnings
6
6
  import urllib.request
7
+ import subprocess
7
8
 
8
9
  import numpy as np
9
10
  import pandas as pd
10
11
  import matplotlib as mpl
11
12
  import matplotlib.pyplot as plt
13
+ from matplotlib.colors import LinearSegmentedColormap
12
14
 
15
+ #´:°•.°+.*•´.*:˚.°*.˚•´.°:°•.°•.*•´.*:˚.°*.˚•´.°:°•.°+.*•´.*:#
16
+ # CONSTANTS #
17
+ #.•°:°.´+˚.*°.˚:*.´•*.+°.•°:´*.´•*.•°.•°:°.´:•˚°.*°.˚:*.´+°.•#
18
+
19
# Hex colour assigned to each country; looked up by the plotting functions.
COLORS = dict(
    UK="#76d6ff",
    USA="#FF6346",
)

# Master switch: when False, the plotting functions skip the shaded
# confidence-interval bands even if the models carry interval data.
PLOT_CONFIDENCE_INTERVALS = False
13
26
 
14
27
  #´:°•.°+.*•´.*:˚.°*.˚•´.°:°•.°•.*•´.*:˚.°*.˚•´.°:°•.°+.*•´.*:#
15
28
  # FUNCTIONS #
@@ -160,36 +173,12 @@ def load_models() -> dict[str, dict[str, callable]]:
160
173
 
161
174
  return {
162
175
  "USA": {
163
- "mean": [
164
- lambda x: (
165
- _contents["USA"]["mean number of mutations per 7D model"]["parameters"]["m"]*x
166
- + _contents["USA"]["mean number of mutations per 7D model"]["parameters"]["b"]
167
- ),
168
- _contents["USA"]["mean number of mutations per 7D model"]["r2"],
169
- ],
170
- "var": [
171
- lambda x: (
172
- _contents["USA"]["scaled var number of mutations per 7D model"]["parameters"]["d"]
173
- *(x**_contents["USA"]["scaled var number of mutations per 7D model"]["parameters"]["alpha"])
174
- ),
175
- _contents["USA"]["scaled var number of mutations per 7D model"]["r2"],
176
- ]
176
+ "mean": list(_get_mean_model(_contents["USA"], "USA")),
177
+ "var": list(_get_var_model(_contents["USA"], "USA"))
177
178
  },
178
179
  "UK": {
179
- "mean": [
180
- lambda x: (
181
- _contents["UK"]["mean number of mutations per 7D model"]["parameters"]["m"]*x
182
- + _contents["UK"]["mean number of mutations per 7D model"]["parameters"]["b"]
183
- ),
184
- _contents["UK"]["mean number of mutations per 7D model"]["r2"],
185
- ],
186
- "var": [
187
- lambda x: (
188
- _contents["UK"]["scaled var number of mutations per 7D model"]["parameters"]["d"]
189
- *(x**_contents["UK"]["scaled var number of mutations per 7D model"]["parameters"]["alpha"])
190
- ),
191
- _contents["UK"]["scaled var number of mutations per 7D model"]["r2"],
192
- ]
180
+ "mean": list(_get_mean_model(_contents["UK"], "UK")),
181
+ "var": list(_get_var_model(_contents["UK"], "UK"))
193
182
  },
194
183
  }
195
184
 
@@ -202,15 +191,66 @@ def safe_map(f: callable, x: list[int | float]) -> list[int | float]:
202
191
  _results.append(None)
203
192
  return _results
204
193
 
205
- def plot(df: pd.DataFrame, models: dict[str, any], export: bool = False, show: bool = True) -> None:
194
+
195
+ def _calculate_confidence_bounds(x_values: np.ndarray, model_func: callable, confidence_intervals: dict, model_type: str) -> tuple[np.ndarray, np.ndarray]:
196
+ """Calculate confidence interval bounds for a model.
197
+
198
+ :param x_values: X values to calculate bounds for
199
+ :type x_values: np.ndarray
200
+ :param model_func: The model function
201
+ :type model_func: callable
202
+ :param confidence_intervals: Dictionary of confidence intervals for parameters
203
+ :type confidence_intervals: dict
204
+ :param model_type: Type of model ('linear_mean', 'linear_var', 'power_law')
205
+ :type model_type: str
206
+ :return: Tuple of (lower_bounds, upper_bounds)
207
+ :rtype: tuple[np.ndarray, np.ndarray]
208
+ """
209
+ if not confidence_intervals:
210
+ # No confidence intervals available, return None bounds
211
+ return None, None
212
+
213
+ if model_type == "linear_mean":
214
+ # For linear mean model: mx + b
215
+ if "m" in confidence_intervals and "b" in confidence_intervals:
216
+ m_lower, m_upper = confidence_intervals["m"]
217
+ b_lower, b_upper = confidence_intervals["b"]
218
+
219
+ # Calculate bounds for the linear function
220
+ lower_bounds = m_lower * x_values + b_lower
221
+ upper_bounds = m_upper * x_values + b_upper
222
+
223
+ return lower_bounds, upper_bounds
224
+
225
+ elif model_type == "linear_var":
226
+ # For linear variance model: mx
227
+ if "m" in confidence_intervals:
228
+ m_lower, m_upper = confidence_intervals["m"]
229
+
230
+ lower_bounds = m_lower * x_values
231
+ upper_bounds = m_upper * x_values
232
+
233
+ return lower_bounds, upper_bounds
234
+
235
+ elif model_type == "power_law":
236
+ # For power law model: d*x^alpha
237
+ if "d" in confidence_intervals and "alpha" in confidence_intervals:
238
+ d_lower, d_upper = confidence_intervals["d"]
239
+ alpha_lower, alpha_upper = confidence_intervals["alpha"]
240
+
241
+ # For power law, we need to be careful about the bounds
242
+ # We'll use the parameter bounds to create approximate confidence bounds
243
+ lower_bounds = d_lower * (x_values ** alpha_lower)
244
+ upper_bounds = d_upper * (x_values ** alpha_upper)
245
+
246
+ return lower_bounds, upper_bounds
247
+
248
+ return None, None
249
+
250
+ def plot_main_figure(df: pd.DataFrame, models: dict[str, any], export: bool = False, show: bool = True) -> None:
206
251
  set_matplotlib_global_params()
207
252
  fig, ax = plt.subplots(2, 1, figsize=(6, 10))
208
253
 
209
- colors = {
210
- "UK": "#76d6ff",
211
- "USA": "#FF6346",
212
- }
213
-
214
254
  for idx, case in enumerate(("mean", "var")):
215
255
  for col in (f"{case} number of mutations USA", f"{case} number of mutations UK"):
216
256
 
@@ -219,20 +259,54 @@ def plot(df: pd.DataFrame, models: dict[str, any], export: bool = False, show: b
219
259
  ax[idx].scatter(
220
260
  df.index,
221
261
  df[col] - (df[col].min() if idx == 1 else 0),
222
- color=colors[_country],
262
+ color=COLORS[_country],
223
263
  edgecolor="k",
224
264
  zorder=2,
225
265
  )
226
266
 
227
- _x = np.arange(-10, 50, 0.5)
267
+ _x = np.arange(-10, 60, 0.5)
268
+ _x_shifted = _x + (8 if _country == "USA" else 0)
269
+
270
+ # Plot the main model line
228
271
  ax[idx].plot(
229
- _x + (8 if _country == "USA" else 0),
272
+ _x_shifted,
230
273
  safe_map(models[_country][case][0], _x),
231
- color=colors[_country],
274
+ color=COLORS[_country],
232
275
  label=rf"{_country} ($R^2 = {round(models[_country][case][1], 2):.2f})$",
233
276
  linewidth=3,
234
277
  zorder=1,
235
278
  )
279
+
280
+ # Plot confidence intervals if available and enabled
281
+ if PLOT_CONFIDENCE_INTERVALS and len(models[_country][case]) > 2 and models[_country][case][2]:
282
+ confidence_intervals = models[_country][case][2]
283
+
284
+ # Determine model type for confidence interval calculation
285
+ if case == "mean":
286
+ model_type = "linear_mean"
287
+ else: # case == "var"
288
+ # Check if it's linear or power law based on the model function
289
+ # We'll determine this by checking if the model has 'alpha' parameter
290
+ if "alpha" in confidence_intervals:
291
+ model_type = "power_law"
292
+ else:
293
+ model_type = "linear_var"
294
+
295
+ lower_bounds, upper_bounds = _calculate_confidence_bounds(
296
+ _x, models[_country][case][0], confidence_intervals, model_type
297
+ )
298
+
299
+ if lower_bounds is not None and upper_bounds is not None:
300
+ # Plot confidence interval as filled area
301
+ # The x-axis shift is already applied to _x_shifted, so we use the original bounds
302
+ ax[idx].fill_between(
303
+ _x_shifted,
304
+ lower_bounds,
305
+ upper_bounds,
306
+ color=COLORS[_country],
307
+ alpha=0.2,
308
+ zorder=0,
309
+ )
236
310
 
237
311
  # Styling
238
312
  ax[idx].set_xlim(-0.5, 40.5)
@@ -260,13 +334,736 @@ def plot(df: pd.DataFrame, models: dict[str, any], export: bool = False, show: b
260
334
 
261
335
  if export:
262
336
  fig.savefig(
263
- "share/figure.pdf",
337
+ "share/figure.eps",
338
+ dpi=400,
339
+ bbox_inches="tight",
340
+ )
341
+ print("Figure saved as share/figure.eps")
342
+
343
+ if show: plt.show()
344
+
345
def size_plot(df: pd.DataFrame, export: bool = False, show: bool = True) -> None:
    """Draw a stem plot of the ``size UK`` and ``size USA`` columns against the index.

    The UK series is drawn twice: once underneath the USA series (with a legend
    entry) and once more on top with half-transparent stems and no legend entry.

    :param df: Data frame providing the ``size UK`` and ``size USA`` columns
    :type df: pd.DataFrame
    :param export: When True, save the figure to ``share/weekly_size.pdf``
    :type export: bool
    :param show: When True, display the figure
    :type show: bool
    """
    set_matplotlib_global_params()
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))

    # (country, extra stem() kwargs, extra stem-line properties, marker edge colour),
    # listed in drawing order.
    layers = (
        ("UK", {"label": "UK"}, {}, "k"),
        ("USA", {"label": "USA"}, {}, "k"),
        ("UK", {}, {"alpha": 0.5}, "#000000"),
    )
    for country, stem_kwargs, line_props, edge_colour in layers:
        markers, stems, base = ax.stem(df.index, df[f"size {country}"], **stem_kwargs)
        plt.setp(stems, color=COLORS[country], **line_props)
        plt.setp(markers, color=COLORS[country], markeredgecolor=edge_colour)
        # White baseline so it does not show
        plt.setp(base, color="#ffffff")

    ax.set_ylim(0, 405)
    ax.set_xlim(-0.5, 40.5)

    ax.set_xlabel("time (wk)")
    ax.set_ylabel("Number of sequences")

    ax.legend(fontsize=16, loc="upper right", bbox_to_anchor=(1.08, 1.08))

    if export:
        fig.savefig("share/weekly_size.pdf", dpi=400, bbox_inches="tight")
        print("Figure saved as share/weekly_size.pdf")

    if show:
        plt.show()
388
+
389
def anomalous_diffusion_plot(export: bool = False, show: bool = True) -> None:
    """Plot three illustrative curves ``x**alpha`` for alpha = 0.8, 1 and 1.2.

    :param export: When True, save the figure to ``share/anomalous_diffusion.pdf``
    :type export: bool
    :param show: When True, display the figure
    :type show: bool
    """
    set_matplotlib_global_params()
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))

    x = np.linspace(0, 10, 100)

    # (exponent, legend label, line colour), in drawing order.
    curves = (
        (0.8, r"$\alpha = 0.8$" + "\n(subdiffusion)", COLORS["UK"]),
        (1, r"$\alpha = 1$" + "\n(normal diffusion)", "#000000"),
        (1.2, r"$\alpha = 1.2$" + "\n(superdiffusion)", COLORS["USA"]),
    )
    for exponent, label, colour in curves:
        ax.plot(x, x**exponent, label=label, color=colour, linewidth=3)

    ax.legend(
        fontsize=13,
        loc="upper left",
        title=r"variance $\propto \text{time}^\alpha$",
        title_fontsize=15,
    )

    ax.set_xlabel("time")
    ax.set_ylabel("variance")

    # Purely qualitative figure: no tick marks on either axis
    ax.set_xticks([])
    ax.set_yticks([])

    ax.set_xlim(0, 10)
    ax.set_ylim(0, 10)

    if export:
        fig.savefig("share/anomalous_diffusion.pdf", dpi=400, bbox_inches="tight")
        print("Figure saved as share/anomalous_diffusion.pdf")

    if show:
        plt.show()
424
+
425
def check_synthetic_data_exists() -> bool:
    """
    Tell whether all four synthetic data output files are present.

    :return: True only if every stats and regression-results file exists
    :rtype: bool
    """
    required = (
        "tests/data/test4/synthdata1_out_stats.tsv",
        "tests/data/test4/synthdata2_out_stats.tsv",
        "tests/data/test4/synthdata1_out_regression_results.json",
        "tests/data/test4/synthdata2_out_regression_results.json",
    )
    return all(os.path.exists(path) for path in required)
441
+
442
def run_synthetic_data_tests() -> None:
    """
    Run the synthetic data tests to generate the required files.

    Invokes the ``PyEvoMotion`` command line tool once for the S1 and once for
    the S2 dataset under ``tests/data/test4``.

    Success is judged by the process exit status. Output on stderr alone is not
    treated as a failure (it is echoed so it stays visible), and a non-zero
    exit status is reported even when stderr is empty.

    :raises RuntimeError: If processing of a dataset exits with a non-zero status
    """
    print("Running synthetic data tests to generate required files...")

    # Create output directory
    os.makedirs("tests/data/test4", exist_ok=True)

    for dataset, index in (("S1", 1), ("S2", 2)):
        result = subprocess.run(
            [
                "PyEvoMotion",
                f"tests/data/test4/{dataset}.fasta",
                f"tests/data/test4/{dataset}.tsv",
                f"tests/data/test4/synthdata{index}_out",
                "-ep",
            ],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )

        if result.returncode != 0:
            print(result.stdout)
            print(result.stderr)
            raise RuntimeError(f"Failed to process {dataset} dataset")

        if result.stderr:
            # The run succeeded; surface whatever was written to stderr
            # without aborting.
            print(result.stderr)
488
+
489
def load_synthetic_data_df() -> pd.DataFrame:
    """Load both synthetic stats tables and outer-join them on ``date``.

    The synthetic outputs are generated first when any of them is missing.
    Columns present in both tables get the suffixes `` synt1`` and `` synt2``.

    :return: The merged stats table
    :rtype: pd.DataFrame
    """
    if not check_synthetic_data_exists():
        run_synthetic_data_tests()

    first = pd.read_csv("tests/data/test4/synthdata1_out_stats.tsv", sep="\t")
    second = pd.read_csv("tests/data/test4/synthdata2_out_stats.tsv", sep="\t")

    return first.merge(
        second,
        on="date",
        how="outer",
        suffixes=(" synt1", " synt2"),
    )
505
+
506
+ def _get_mean_model(data: dict, kind: str) -> tuple[callable, float, dict]:
507
+ """Extract mean model from data, handling both old and new formats.
508
+
509
+ :param data: The regression results data dictionary
510
+ :type data: dict
511
+ :param kind: The dataset kind identifier (for error messages)
512
+ :type kind: str
513
+ :return: Tuple of (lambda function, r2 value, confidence intervals)
514
+ :rtype: tuple[callable, float, dict]
515
+ """
516
+ # Try different possible key formats
517
+ possible_keys = [
518
+ "mean number of mutations model", # New format (current)
519
+ "mean number of mutations per 7D model", # Old format
520
+ "mean number of substitutions model", # Alternative format
521
+ "mean number of substitutions per 7D model" # Alternative old format
522
+ ]
523
+
524
+ for mean_key in possible_keys:
525
+ if mean_key in data:
526
+ params = data[mean_key]["parameters"]
527
+ r2 = data[mean_key]["r2"]
528
+ confidence_intervals = data[mean_key].get("confidence_intervals", {})
529
+ return lambda x: params["m"] * x + params["b"], r2, confidence_intervals
530
+
531
+ raise KeyError(f"Could not find mean model in {kind} data. Available keys: {list(data.keys())}")
532
+
533
+
534
+ def _get_var_model(data: dict, kind: str) -> tuple[callable, float, dict]:
535
+ """Extract variance model from data, handling both old and new formats.
536
+
537
+ :param data: The regression results data dictionary
538
+ :type data: dict
539
+ :param kind: The dataset kind identifier (for error messages)
540
+ :type kind: str
541
+ :return: Tuple of (lambda function, r2 value, confidence intervals)
542
+ :rtype: tuple[callable, float, dict]
543
+ """
544
+ # Try different possible key formats
545
+ possible_keys = [
546
+ "scaled var number of mutations model", # New format (current)
547
+ "scaled var number of mutations per 7D model", # Old format
548
+ "scaled var number of substitutions model", # Alternative format
549
+ "scaled var number of substitutions per 7D model" # Alternative old format
550
+ ]
551
+
552
+ for var_key in possible_keys:
553
+ if var_key in data:
554
+ # Check if it has model_selection (new format)
555
+ if "model_selection" in data[var_key]:
556
+ # New format with model selection
557
+ model_selection = data[var_key]["model_selection"]
558
+ selected = model_selection["selected"]
559
+
560
+ if selected == "linear" and "linear_model" in data[var_key]:
561
+ # Use linear model
562
+ linear_model = data[var_key]["linear_model"]
563
+ params = linear_model["parameters"]
564
+ r2 = linear_model["r2"]
565
+ confidence_intervals = linear_model.get("confidence_intervals", {})
566
+ return lambda x: params["m"] * x, r2, confidence_intervals
567
+ elif selected == "power_law" and "power_law_model" in data[var_key]:
568
+ # Use power law model
569
+ power_law_model = data[var_key]["power_law_model"]
570
+ params = power_law_model["parameters"]
571
+ r2 = power_law_model["r2"]
572
+ confidence_intervals = power_law_model.get("confidence_intervals", {})
573
+ return lambda x: params["d"] * (x ** params["alpha"]), r2, confidence_intervals
574
+ else:
575
+ # Old format or new format without model selection - direct parameters
576
+ params = data[var_key]["parameters"]
577
+ r2 = data[var_key]["r2"]
578
+ confidence_intervals = data[var_key].get("confidence_intervals", {})
579
+ if "m" in params:
580
+ # Linear model: mx
581
+ return lambda x: params["m"] * x, r2, confidence_intervals
582
+ elif "d" in params and "alpha" in params:
583
+ # Power law model: d*x^alpha
584
+ return lambda x: params["d"] * (x ** params["alpha"]), r2, confidence_intervals
585
+
586
+ raise KeyError(f"Could not find variance model in {kind} data. Available keys: {list(data.keys())}")
587
+
588
+
589
def load_synthetic_data_models() -> dict[str, dict[str, callable]]:
    """Load the mean and variance models of both synthetic datasets.

    The synthetic outputs are generated first when any of them is missing.

    :return: Mapping ``"synt1"``/``"synt2"`` to ``{"mean": [...], "var": [...]}``,
        where each list holds (model function, r2 value, confidence intervals)
    :rtype: dict[str, dict[str, callable]]
    """
    if not check_synthetic_data_exists():
        run_synthetic_data_tests()

    template = "tests/data/test4/synthdata{}_out_regression_results.json"

    loaded = {}
    for kind in ("synt1", "synt2"):
        # The trailing digit of the kind selects the file
        with open(template.format(kind[-1])) as handle:
            loaded[kind] = json.load(handle)

    return {
        kind: {
            "mean": list(_get_mean_model(content, kind)),
            "var": list(_get_var_model(content, kind)),
        }
        for kind, content in loaded.items()
    }
612
+
613
def synthetic_data_plot(df: pd.DataFrame, models: dict[str, any], export: bool = False, show: bool = True) -> None:
    """Plot the synthetic datasets on a 2x2 grid: mean on top, variance below.

    Each panel shows the observed values as a scatter plot with the fitted
    model drawn as a line; the columns are ``synt1`` (left) and ``synt2`` (right).

    :param df: Data frame with the ``{mean,var} number of mutations synt{1,2}`` columns
    :type df: pd.DataFrame
    :param models: Mapping ``"synt1"``/``"synt2"`` -> ``"mean"``/``"var"`` ->
        sequence of (model function, r2 value[, confidence intervals])
    :type models: dict[str, any]
    :param export: When True, save the figure to ``share/synth_figure.pdf``
    :type export: bool
    :param show: When True, display the figure
    :type show: bool
    """
    set_matplotlib_global_params()
    fig, grid = plt.subplots(2, 2, figsize=(12, 10))

    colour = "#76d6ff"
    grid_x = np.arange(-10, 50, 0.5)

    # Panel order matches the flattened axes: mean/synt1, mean/synt2, var/synt1, var/synt2
    combos = [(case, kind) for case in ("mean", "var") for kind in ("synt1", "synt2")]

    for panel, (case, kind) in zip(grid.flatten(), combos):
        # Scatter plot
        panel.scatter(
            df.index,
            df[f"{case} number of mutations {kind}"],
            color=colour,
            edgecolor="k",
            zorder=2,
        )

        # Line plot
        fitted = models[kind][case]
        panel.plot(
            grid_x,
            safe_map(fitted[0], grid_x),
            color=colour,
            label=rf"$R^2 = {round(fitted[1], 2):.2f}$",
            linewidth=3,
            zorder=1,
        )

        # Plot confidence intervals if available and enabled
        if PLOT_CONFIDENCE_INTERVALS and len(fitted) > 2 and fitted[2]:
            confidence_intervals = fitted[2]

            # The mean is always linear; the variance is a power law when an
            # interval for 'alpha' is present, linear otherwise.
            if case == "mean":
                model_type = "linear_mean"
            elif "alpha" in confidence_intervals:
                model_type = "power_law"
            else:
                model_type = "linear_var"

            lower_bounds, upper_bounds = _calculate_confidence_bounds(
                grid_x, fitted[0], confidence_intervals, model_type
            )

            if lower_bounds is not None and upper_bounds is not None:
                # Plot confidence interval as filled area
                panel.fill_between(
                    grid_x,
                    lower_bounds,
                    upper_bounds,
                    color=colour,
                    alpha=0.2,
                    zorder=0,
                )

        # Styling
        panel.set_xlim(-0.5, 40.5)
        if case == "mean":
            panel.set_ylim(-0.25, 20.25)
            panel.set_ylabel(f"{case} (# mutations)")
        else:
            if kind == "synt1":
                panel.set_ylim(-0.5, 40.5)
            else:
                panel.set_ylim(-0.1, 10.1)
            panel.set_ylabel(f"{case}iance (# mutations)")

        panel.set_xlabel("time (wk)")
        panel.legend(fontsize=16, loc="upper left")

    fig.suptitle(" ", fontsize=1)  # To get some space on top
    fig.tight_layout()

    # Add subplot annotations
    for letter, position in (
        ("a", (0.02, 0.935)),
        ("b", (0.505, 0.935)),
        ("c", (0.02, 0.465)),
        ("d", (0.505, 0.465)),
    ):
        plt.annotate(letter, position, xycoords="figure fraction", fontsize=28, fontweight="bold")

    if export:
        fig.savefig("share/synth_figure.pdf", dpi=400, bbox_inches="tight")
        print("Figure saved as share/synth_figure.pdf")

    if show:
        plt.show()
714
+
715
+
716
def load_additional_uk_stats() -> dict[str, pd.DataFrame]:
    """
    Read the UK stats table of every time window (5D, 10D, 14D and 7D).

    :return: Mapping of time-window label to the tab-separated table read from disk
    :rtype: dict[str, pd.DataFrame]
    """
    sources = (
        ("5D", "tests/data/test3/output/20250517164757/UKout_5D_stats.tsv"),
        ("10D", "tests/data/test3/output/20250517173133/UKout_10D_stats.tsv"),
        ("14D", "tests/data/test3/output/20250517181004/UKout_14D_stats.tsv"),
        ("7D", "share/figUK_stats.tsv"),
    )

    tables = {}
    for window, path in sources:
        tables[window] = pd.read_csv(path, sep="\t")
    return tables
731
+
732
def load_additional_uk_models() -> dict[str, dict[str, callable]]:
    """
    Read the UK regression results of every time window (5D, 10D, 14D and 7D).

    Unlike the other model loaders in this file, the first element of each
    model list is a dictionary of fitted parameters, not a function.

    :return: Mapping of time-window label to ``{"mean": [...], "var": [...]}``,
        where each list holds (parameter dict, r2 value, confidence intervals).
        The mean entry carries ``m`` and ``b``; the variance entry carries the
        power-law ``d`` and ``alpha``.
    :raises KeyError: If a results file lacks one of the expected keys
    """
    sources = {
        "5D": "tests/data/test3/output/20250517164757/UKout_5D_regression_results.json",
        "10D": "tests/data/test3/output/20250517173133/UKout_10D_regression_results.json",
        "14D": "tests/data/test3/output/20250517181004/UKout_14D_regression_results.json",
        "7D": "share/figUK_regression_results.json",
    }

    # Read every file before building any model
    loaded = {}
    for window, path in sources.items():
        with open(path) as handle:
            loaded[window] = json.load(handle)

    models = {}
    for window, content in loaded.items():
        mean_fit = content["mean number of mutations model"]
        mean_entry = [
            {
                "m": mean_fit["parameters"]["m"],
                "b": mean_fit["parameters"]["b"],
            },
            mean_fit["r2"],
            mean_fit["confidence_intervals"],
        ]

        power_fit = content["scaled var number of mutations model"]["power_law_model"]
        var_entry = [
            {
                "d": power_fit["parameters"]["d"],
                "alpha": power_fit["parameters"]["alpha"],
            },
            power_fit["r2"],
            power_fit["confidence_intervals"],
        ]

        models[window] = {"mean": mean_entry, "var": var_entry}

    return models
768
+
769
def plot_uk_time_windows(stats: dict[str, pd.DataFrame], models: dict[str, dict[str, callable]], export: bool = False, show: bool = True) -> None:
    """
    Plot a 2x4 grid of UK data: one column per time window (5D, 7D, 10D, 14D),
    with the mean on the top row and the variance on the bottom row.

    Args:
        stats: Dictionary of dataframes containing the stats for each time window
        models: Dictionary of models for each time window; each of "mean" and
            "var" holds (parameter dict, r2 value[, confidence intervals])
        export: Whether to export the figure to share/uk_time_windows.pdf
        show: Whether to show the figure
    """
    set_matplotlib_global_params()
    fig, ax = plt.subplots(2, 4, figsize=(24, 12))

    # For the models to be comparable to the 7D model, we need to scale the
    # x-axis by the square of the time window ratio
    rescale = {
        "5D": (5/7)**2,
        "7D": 1,
        "10D": (10/7)**2,
        "14D": (14/7)**2,
    }
    grid_x = np.arange(-0.5, 51, 0.5)
    colour = COLORS["UK"]

    # Order of time windows to plot
    for column, window in enumerate(("5D", "7D", "10D", "14D")):
        table = stats[window]
        fits = models[window]

        for row, case in enumerate(("mean", "var")):
            panel = ax[row, column]
            is_mean = case == "mean"

            # Observed values; the variance is shifted so its minimum sits at zero
            panel.scatter(
                table["dt_idx"],
                (
                    table["mean number of mutations"]
                    if is_mean
                    else table["var number of mutations"] - table["var number of mutations"].min()
                ),
                color=colour,
                edgecolor="k",
                zorder=2,
            )

            fit = fits[case]
            if is_mean:
                # Mean is always the linear model m*x + b
                model_x = grid_x
                curve = fit[0]["m"] * model_x + fit[0]["b"]
                tag, model_type = "Mean", "linear_mean"
            else:
                # Variance is always the power law d*x^alpha, on the rescaled x
                model_x = grid_x * rescale[window]
                curve = fit[0]["d"] * model_x ** fit[0]["alpha"]
                tag, model_type = "Var", "power_law"

            panel.plot(
                grid_x,
                curve,
                color=colour,
                label=rf"{tag} ($R^2 = {round(fit[1], 2):.2f})$",
                linewidth=3,
                zorder=1,
            )

            # Plot confidence intervals if available and enabled
            if PLOT_CONFIDENCE_INTERVALS and len(fit) > 2 and fit[2]:
                lower_bounds, upper_bounds = _calculate_confidence_bounds(
                    model_x, fit[0], fit[2], model_type
                )

                if lower_bounds is not None and upper_bounds is not None:
                    # Bounds are evaluated on model_x but drawn against the unscaled grid
                    panel.fill_between(
                        grid_x,
                        lower_bounds,
                        upper_bounds,
                        color=colour,
                        alpha=0.2,
                        zorder=0,
                    )

            # Styling
            panel.set_xlim(-0.5, 40.5)
            if is_mean:
                panel.set_ylim(29.5, 45.5)
            else:
                panel.set_ylim(-0.5, 10.5)

            panel.set_xlabel("time (wk)")
            if column == 0:
                # Only the leftmost column carries a y label
                panel.set_ylabel(f"{case} (# mutations)")

            panel.legend(fontsize=12, loc="upper left")

    if export:
        fig.savefig("share/uk_time_windows.pdf", dpi=400, bbox_inches="tight")

    if show:
        plt.show()
908
+
909
def load_model_selection_results(directory: str) -> list[dict]:
    """Load all regression results from a directory for model selection analysis.

    This function recursively walks through the directory tree to find all
    files whose name ends in ``out_regression_results.json``, supporting both
    flat and nested directory structures.

    Expected structure:
        directory/
        ├── {timestamp}/
        │   ├── {dataset_01}/
        │   │   └── *_out_regression_results.json
        │   ├── {dataset_02}/
        │   │   └── *_out_regression_results.json
        │   └── ...
        └── (or any nested structure)

    The model selection section is read from the variance model entry. The
    ``substitutions`` spelling is preferred; the ``mutations`` and ``per 7D``
    spellings used elsewhere in this file are accepted as fallbacks, so result
    files written with those keys are not reported as ``"unknown"``.

    Files that cannot be read or parsed, or that do not contain a JSON object,
    are reported on stdout and skipped.

    :param directory: Root directory to search for regression results
    :type directory: str
    :return: List of dictionaries containing model selection information
        (``file``, ``selected_model`` and the AIC / Akaike-weight fields;
        missing fields are ``None`` and a missing selection is ``"unknown"``)
    :rtype: list[dict]
    """
    # First key present in a file wins
    variance_keys = (
        "scaled var number of substitutions model",
        "scaled var number of mutations model",
        "scaled var number of substitutions per 7D model",
        "scaled var number of mutations per 7D model",
    )
    # Copied verbatim from the model selection section into each record
    fields = (
        "linear_AIC",
        "power_law_AIC",
        "delta_AIC_linear",
        "delta_AIC_power_law",
        "akaike_weight_linear",
        "akaike_weight_power_law",
    )

    results = []

    for root, _dirs, files in os.walk(directory):
        for file in files:
            if not file.endswith("out_regression_results.json"):
                continue
            file_path = os.path.join(root, file)

            try:
                with open(file_path, "r") as f:
                    data = json.load(f)
            except (OSError, ValueError) as e:
                # ValueError covers malformed JSON and undecodable bytes
                print(f"Error loading {file_path}: {e}")
                continue

            if not isinstance(data, dict):
                print(f"Error loading {file_path}: expected a JSON object, got {type(data).__name__}")
                continue

            # Extract the model selection info
            var_key = next((key for key in variance_keys if key in data), None)
            var_model = data[var_key] if var_key is not None else None
            model_selection = var_model.get("model_selection") if isinstance(var_model, dict) else None
            if not isinstance(model_selection, dict):
                model_selection = {}

            record = {
                "file": file_path,
                "selected_model": model_selection.get("selected", "unknown"),
            }
            for field in fields:
                record[field] = model_selection.get(field, None)
            results.append(record)

    return results
957
+
958
def create_confusion_matrix_plot(export: bool = False, show: bool = True) -> None:
    """
    Create a confusion matrix plot for model selection accuracy analysis.

    Analyzes regression results from test5 synthetic datasets to assess
    model selection accuracy. Works with organized directory structure:

    tests/data/test5/
    ├── linear/output/{timestamp}/
    │   ├── linear_01/
    │   │   └── linear_01_out_regression_results.json
    │   ├── linear_02/
    │   └── ...
    └── powerlaw/output/{timestamp}/
        ├── powerlaw_01/
        └── ...

    The heatmap shows the actual model on the x-axis and the predicted model
    on the y-axis. Any result whose selected model is not the expected one
    (including "unknown") is counted as a misclassification into the other model.

    Args:
        export: Whether to export the figure to share/confusion_matrix_heatmap.pdf
        show: Whether to display the figure
    """
    set_matplotlib_global_params()

    # Define paths - searched recursively through all subdirectories
    base_path = "tests/data/test5"
    linear_dir = os.path.join(base_path, "linear", "output")
    powerlaw_dir = os.path.join(base_path, "powerlaw", "output")

    print("Loading model selection results...")

    # Load results from both directories
    linear_results = load_model_selection_results(linear_dir)
    powerlaw_results = load_model_selection_results(powerlaw_dir)

    print(f"Loaded {len(linear_results)} linear results")
    print(f"Loaded {len(powerlaw_results)} powerlaw results")

    # Analyze results
    linear_success = sum(1 for r in linear_results if r['selected_model'] == 'linear')
    linear_failure = len(linear_results) - linear_success

    powerlaw_success = sum(1 for r in powerlaw_results if r['selected_model'] == 'power_law')
    powerlaw_failure = len(powerlaw_results) - powerlaw_success

    # Counts indexed as [actual, predicted]
    counts_by_actual = np.array([
        [linear_success, linear_failure],      # Actual Linear: [predicted Linear, predicted Power Law]
        [powerlaw_failure, powerlaw_success],  # Actual Power Law: [predicted Linear, predicted Power Law]
    ])
    # imshow draws the first index along the y-axis and the second along the
    # x-axis. The axes below are labelled x = actual, y = predicted, so the
    # matrix must be indexed as [predicted, actual]: transpose it.
    confusion_matrix = counts_by_actual.T

    # Create the plot
    fig, ax = plt.subplots(figsize=(8, 6))

    # Create custom colormap from white to UK blue
    colors = ['white', '#76d6ff']
    n_bins = 30
    cmap = LinearSegmentedColormap.from_list('custom', colors, N=n_bins)

    # Create heatmap
    im = ax.imshow(confusion_matrix, interpolation='nearest', cmap=cmap)

    # Add colorbar
    cbar = ax.figure.colorbar(im, ax=ax)
    cbar.set_label('Count', rotation=270, labelpad=20)

    # Set ticks and labels
    ax.set_xticks([0, 1])
    ax.set_yticks([0, 1])
    ax.set_xticklabels(['Linear', 'Power Law'])  # Actual model (x-axis)
    ax.set_yticklabels(['Linear', 'Power Law'])  # Predicted model (y-axis)

    # Add text annotations; text position is (x=column, y=row)
    thresh = confusion_matrix.max() / 2.
    for i in range(confusion_matrix.shape[0]):
        for j in range(confusion_matrix.shape[1]):
            ax.text(j, i, format(confusion_matrix[i, j], 'd'),
                    ha="center", va="center",
                    color="white" if confusion_matrix[i, j] > thresh else "black",
                    fontsize=16, fontweight='bold')

    # Labels and title
    ax.set_xlabel('Actual Model', fontsize=16)
    ax.set_ylabel('Predicted Model', fontsize=16)
    ax.set_title('Model Selection Confusion Matrix', fontsize=18, fontweight='bold')

    # Calculate and display accuracy (0 when no results were found)
    total_tests = len(linear_results) + len(powerlaw_results)
    total_successes = linear_success + powerlaw_success
    overall_accuracy = total_successes / total_tests if total_tests > 0 else 0

    # Add accuracy text below the axes
    ax.text(0.5, -0.25, f'Overall Accuracy: {overall_accuracy:.3f} ({total_successes}/{total_tests})',
            transform=ax.transAxes, ha='center', fontsize=14, fontweight='bold')

    plt.tight_layout()

    if export:
        fig.savefig(
            "share/confusion_matrix_heatmap.pdf",
            dpi=400,
            bbox_inches="tight",
        )
        print("Confusion matrix saved as share/confusion_matrix_heatmap.pdf")

    if show:
        plt.show()
270
1067
 
271
1068
  #´:°•.°+.*•´.*:˚.°*.˚•´.°:°•.°•.*•´.*:˚.°*.˚•´.°:°•.°+.*•´.*:#
272
1069
  # MAIN #
@@ -290,7 +1087,6 @@ def main(export: bool = False) -> None:
290
1087
  f"share/figdata{country}.tsv",
291
1088
  f"share/fig{country}",
292
1089
  "-k", "total",
293
- "-n", "5",
294
1090
  "-dt", "7D",
295
1091
  "-dr", "2020-10-01..2021-08-01",
296
1092
  "-ep",
@@ -301,8 +1097,27 @@ def main(export: bool = False) -> None:
301
1097
  df = load_final_data_df()
302
1098
  models = load_models()
303
1099
 
304
- # Plot
305
- plot(df, models, export=export)
1100
+ # Main plot
1101
+ plot_main_figure(df, models, export=export)
1102
+
1103
+ # # Size plot
1104
+ size_plot(df, export=export)
1105
+
1106
+ # # Anomalous diffusion plot
1107
+ anomalous_diffusion_plot(export=export)
1108
+
1109
+ # # Synthetic data plot
1110
+ synth_df = load_synthetic_data_df()
1111
+ synth_models = load_synthetic_data_models()
1112
+ synthetic_data_plot(synth_df, synth_models, export=export)
1113
+
1114
+ # # UK time windows plot
1115
+ additional_uk_stats = load_additional_uk_stats()
1116
+ additional_uk_models = load_additional_uk_models()
1117
+ plot_uk_time_windows(additional_uk_stats, additional_uk_models, export=export)
1118
+
1119
+ # # Confusion matrix plot
1120
+ create_confusion_matrix_plot(export=export)
306
1121
 
307
1122
 
308
1123
  if __name__ == "__main__":