PyEvoMotion 0.1.1__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -10,6 +10,7 @@ import numpy as np
10
10
  import pandas as pd
11
11
  import matplotlib as mpl
12
12
  import matplotlib.pyplot as plt
13
+ from matplotlib.colors import LinearSegmentedColormap
13
14
 
14
15
  #´:°•.°+.*•´.*:˚.°*.˚•´.°:°•.°•.*•´.*:˚.°*.˚•´.°:°•.°+.*•´.*:#
15
16
  # CONSTANTS #
@@ -20,6 +21,9 @@ COLORS = {
20
21
  "USA": "#FF6346",
21
22
  }
22
23
 
24
+ # Control confidence interval plotting
25
+ PLOT_CONFIDENCE_INTERVALS = False
26
+
23
27
  #´:°•.°+.*•´.*:˚.°*.˚•´.°:°•.°•.*•´.*:˚.°*.˚•´.°:°•.°+.*•´.*:#
24
28
  # FUNCTIONS #
25
29
  #.•°:°.´+˚.*°.˚:*.´•*.+°.•°:´*.´•*.•°.•°:°.´:•˚°.*°.˚:*.´+°.•#
@@ -169,36 +173,12 @@ def load_models() -> dict[str, dict[str, callable]]:
169
173
 
170
174
  return {
171
175
  "USA": {
172
- "mean": [
173
- lambda x: (
174
- _contents["USA"]["mean number of mutations per 7D model"]["parameters"]["m"]*x
175
- + _contents["USA"]["mean number of mutations per 7D model"]["parameters"]["b"]
176
- ),
177
- _contents["USA"]["mean number of mutations per 7D model"]["r2"],
178
- ],
179
- "var": [
180
- lambda x: (
181
- _contents["USA"]["scaled var number of mutations per 7D model"]["parameters"]["d"]
182
- *(x**_contents["USA"]["scaled var number of mutations per 7D model"]["parameters"]["alpha"])
183
- ),
184
- _contents["USA"]["scaled var number of mutations per 7D model"]["r2"],
185
- ]
176
+ "mean": list(_get_mean_model(_contents["USA"], "USA")),
177
+ "var": list(_get_var_model(_contents["USA"], "USA"))
186
178
  },
187
179
  "UK": {
188
- "mean": [
189
- lambda x: (
190
- _contents["UK"]["mean number of mutations per 7D model"]["parameters"]["m"]*x
191
- + _contents["UK"]["mean number of mutations per 7D model"]["parameters"]["b"]
192
- ),
193
- _contents["UK"]["mean number of mutations per 7D model"]["r2"],
194
- ],
195
- "var": [
196
- lambda x: (
197
- _contents["UK"]["scaled var number of mutations per 7D model"]["parameters"]["d"]
198
- *(x**_contents["UK"]["scaled var number of mutations per 7D model"]["parameters"]["alpha"])
199
- ),
200
- _contents["UK"]["scaled var number of mutations per 7D model"]["r2"],
201
- ]
180
+ "mean": list(_get_mean_model(_contents["UK"], "UK")),
181
+ "var": list(_get_var_model(_contents["UK"], "UK"))
202
182
  },
203
183
  }
204
184
 
@@ -211,6 +191,62 @@ def safe_map(f: callable, x: list[int | float]) -> list[int | float]:
211
191
  _results.append(None)
212
192
  return _results
213
193
 
194
+
195
+ def _calculate_confidence_bounds(x_values: np.ndarray, model_func: callable, confidence_intervals: dict, model_type: str) -> tuple[np.ndarray, np.ndarray]:
196
+ """Calculate confidence interval bounds for a model.
197
+
198
+ :param x_values: X values to calculate bounds for
199
+ :type x_values: np.ndarray
200
+ :param model_func: The model function
201
+ :type model_func: callable
202
+ :param confidence_intervals: Dictionary of confidence intervals for parameters
203
+ :type confidence_intervals: dict
204
+ :param model_type: Type of model ('linear_mean', 'linear_var', 'power_law')
205
+ :type model_type: str
206
+ :return: Tuple of (lower_bounds, upper_bounds)
207
+ :rtype: tuple[np.ndarray, np.ndarray]
208
+ """
209
+ if not confidence_intervals:
210
+ # No confidence intervals available, return None bounds
211
+ return None, None
212
+
213
+ if model_type == "linear_mean":
214
+ # For linear mean model: mx + b
215
+ if "m" in confidence_intervals and "b" in confidence_intervals:
216
+ m_lower, m_upper = confidence_intervals["m"]
217
+ b_lower, b_upper = confidence_intervals["b"]
218
+
219
+ # Calculate bounds for the linear function
220
+ lower_bounds = m_lower * x_values + b_lower
221
+ upper_bounds = m_upper * x_values + b_upper
222
+
223
+ return lower_bounds, upper_bounds
224
+
225
+ elif model_type == "linear_var":
226
+ # For linear variance model: mx
227
+ if "m" in confidence_intervals:
228
+ m_lower, m_upper = confidence_intervals["m"]
229
+
230
+ lower_bounds = m_lower * x_values
231
+ upper_bounds = m_upper * x_values
232
+
233
+ return lower_bounds, upper_bounds
234
+
235
+ elif model_type == "power_law":
236
+ # For power law model: d*x^alpha
237
+ if "d" in confidence_intervals and "alpha" in confidence_intervals:
238
+ d_lower, d_upper = confidence_intervals["d"]
239
+ alpha_lower, alpha_upper = confidence_intervals["alpha"]
240
+
241
+ # For power law, we need to be careful about the bounds
242
+ # We'll use the parameter bounds to create approximate confidence bounds
243
+ lower_bounds = d_lower * (x_values ** alpha_lower)
244
+ upper_bounds = d_upper * (x_values ** alpha_upper)
245
+
246
+ return lower_bounds, upper_bounds
247
+
248
+ return None, None
249
+
214
250
  def plot_main_figure(df: pd.DataFrame, models: dict[str, any], export: bool = False, show: bool = True) -> None:
215
251
  set_matplotlib_global_params()
216
252
  fig, ax = plt.subplots(2, 1, figsize=(6, 10))
@@ -229,14 +265,48 @@ def plot_main_figure(df: pd.DataFrame, models: dict[str, any], export: bool = Fa
229
265
  )
230
266
 
231
267
  _x = np.arange(-10, 60, 0.5)
268
+ _x_shifted = _x + (8 if _country == "USA" else 0)
269
+
270
+ # Plot the main model line
232
271
  ax[idx].plot(
233
- _x + (8 if _country == "USA" else 0),
272
+ _x_shifted,
234
273
  safe_map(models[_country][case][0], _x),
235
274
  color=COLORS[_country],
236
275
  label=rf"{_country} ($R^2 = {round(models[_country][case][1], 2):.2f})$",
237
276
  linewidth=3,
238
277
  zorder=1,
239
278
  )
279
+
280
+ # Plot confidence intervals if available and enabled
281
+ if PLOT_CONFIDENCE_INTERVALS and len(models[_country][case]) > 2 and models[_country][case][2]:
282
+ confidence_intervals = models[_country][case][2]
283
+
284
+ # Determine model type for confidence interval calculation
285
+ if case == "mean":
286
+ model_type = "linear_mean"
287
+ else: # case == "var"
288
+ # Check if it's linear or power law based on the model function
289
+ # We'll determine this by checking if the model has 'alpha' parameter
290
+ if "alpha" in confidence_intervals:
291
+ model_type = "power_law"
292
+ else:
293
+ model_type = "linear_var"
294
+
295
+ lower_bounds, upper_bounds = _calculate_confidence_bounds(
296
+ _x, models[_country][case][0], confidence_intervals, model_type
297
+ )
298
+
299
+ if lower_bounds is not None and upper_bounds is not None:
300
+ # Plot confidence interval as filled area
301
+ # The x-axis shift is already applied to _x_shifted, so we use the original bounds
302
+ ax[idx].fill_between(
303
+ _x_shifted,
304
+ lower_bounds,
305
+ upper_bounds,
306
+ color=COLORS[_country],
307
+ alpha=0.2,
308
+ zorder=0,
309
+ )
240
310
 
241
311
  # Styling
242
312
  ax[idx].set_xlim(-0.5, 40.5)
@@ -264,11 +334,11 @@ def plot_main_figure(df: pd.DataFrame, models: dict[str, any], export: bool = Fa
264
334
 
265
335
  if export:
266
336
  fig.savefig(
267
- "share/figure.pdf",
337
+ "share/figure.eps",
268
338
  dpi=400,
269
339
  bbox_inches="tight",
270
340
  )
271
- print("Figure saved as share/figure.pdf")
341
+ print("Figure saved as share/figure.eps")
272
342
 
273
343
  if show: plt.show()
274
344
 
@@ -382,8 +452,8 @@ def run_synthetic_data_tests() -> None:
382
452
  result1 = subprocess.run(
383
453
  [
384
454
  "PyEvoMotion",
385
- "S1.fasta",
386
- "S1.tsv",
455
+ "tests/data/test4/S1.fasta",
456
+ "tests/data/test4/S1.tsv",
387
457
  "tests/data/test4/synthdata1_out",
388
458
  "-ep"
389
459
  ],
@@ -401,8 +471,8 @@ def run_synthetic_data_tests() -> None:
401
471
  result2 = subprocess.run(
402
472
  [
403
473
  "PyEvoMotion",
404
- "S2.fasta",
405
- "S2.tsv",
474
+ "tests/data/test4/S2.fasta",
475
+ "tests/data/test4/S2.tsv",
406
476
  "tests/data/test4/synthdata2_out",
407
477
  "-ep"
408
478
  ],
@@ -433,6 +503,89 @@ def load_synthetic_data_df() -> pd.DataFrame:
433
503
  suffixes=(" synt1", " synt2"),
434
504
  )
435
505
 
506
def _get_mean_model(data: dict, kind: str) -> tuple[callable, float, dict]:
    """Build the linear mean model stored in a regression-results dictionary.

    Several key spellings are accepted; the first one present in ``data``
    (in the order listed below) is used.

    :param data: The regression results data dictionary
    :type data: dict
    :param kind: The dataset kind identifier (only used in the error message)
    :type kind: str
    :return: Tuple of (model callable ``m*x + b``, r2 value, confidence intervals).
        The confidence intervals default to an empty dict when absent.
    :rtype: tuple[callable, float, dict]
    :raises KeyError: If none of the supported keys is present in ``data``
    """
    candidate_keys = (
        "mean number of mutations model",  # New format (current)
        "mean number of mutations per 7D model",  # Old format
        "mean number of substitutions model",  # Alternative format
        "mean number of substitutions per 7D model",  # Alternative old format
    )

    # First supported key that the results actually contain, if any
    mean_key = next((key for key in candidate_keys if key in data), None)
    if mean_key is None:
        raise KeyError(f"Could not find mean model in {kind} data. Available keys: {list(data.keys())}")

    entry = data[mean_key]
    params = entry["parameters"]
    r2 = entry["r2"]
    confidence_intervals = entry.get("confidence_intervals", {})

    def _mean_line(x):
        # Parameters are looked up at call time, as a closure over ``params``
        return params["m"] * x + params["b"]

    return _mean_line, r2, confidence_intervals
532
+
533
+
534
def _get_var_model(data: dict, kind: str) -> tuple[callable, float, dict]:
    """Extract variance model from data, handling both old and new formats.

    New-format entries carry a ``model_selection`` block whose ``selected``
    value ("linear" or "power_law") names the sub-model to use
    (``linear_model`` / ``power_law_model``). Old-format entries store
    ``parameters`` and ``r2`` directly; the model shape is then inferred from
    the parameter names (``m`` -> linear, ``d`` and ``alpha`` -> power law).

    An entry that cannot be interpreted (missing or unsupported ``selected``
    value, missing sub-model, unrecognised parameters) is skipped and the
    next candidate key is tried, so such files end in the descriptive
    ``KeyError`` below rather than a bare ``KeyError('selected')``.

    :param data: The regression results data dictionary
    :type data: dict
    :param kind: The dataset kind identifier (for error messages)
    :type kind: str
    :return: Tuple of (model callable, r2 value, confidence intervals). The
        confidence intervals default to an empty dict when absent.
    :rtype: tuple[callable, float, dict]
    :raises KeyError: If no usable variance model is found in ``data``
    """
    def _linear(params: dict) -> callable:
        # Linear model: m*x
        return lambda x: params["m"] * x

    def _power_law(params: dict) -> callable:
        # Power law model: d*x^alpha
        return lambda x: params["d"] * (x ** params["alpha"])

    # Try different possible key formats
    possible_keys = (
        "scaled var number of mutations model",  # New format (current)
        "scaled var number of mutations per 7D model",  # Old format
        "scaled var number of substitutions model",  # Alternative format
        "scaled var number of substitutions per 7D model",  # Alternative old format
    )

    for var_key in possible_keys:
        if var_key not in data:
            continue
        section = data[var_key]

        if "model_selection" in section:
            # New format: use the sub-model picked by the model selection step
            selected = section["model_selection"].get("selected")
            if selected == "linear" and "linear_model" in section:
                chosen, build = section["linear_model"], _linear
            elif selected == "power_law" and "power_law_model" in section:
                chosen, build = section["power_law_model"], _power_law
            else:
                continue
        else:
            # Old format or new format without model selection - direct parameters
            chosen = section
            params = section["parameters"]
            if "m" in params:
                build = _linear
            elif "d" in params and "alpha" in params:
                build = _power_law
            else:
                continue

        return (
            build(chosen["parameters"]),
            chosen["r2"],
            chosen.get("confidence_intervals", {}),
        )

    raise KeyError(f"Could not find variance model in {kind} data. Available keys: {list(data.keys())}")
587
+
588
+
436
589
  def load_synthetic_data_models() -> dict[str, dict[str, callable]]:
437
590
  if not check_synthetic_data_exists():
438
591
  run_synthetic_data_tests()
@@ -448,35 +601,12 @@ def load_synthetic_data_models() -> dict[str, dict[str, callable]]:
448
601
 
449
602
  return {
450
603
  "synt1": {
451
- "mean": [
452
- lambda x: (
453
- _contents["synt1"]["mean number of mutations per 7D model"]["parameters"]["m"]*x
454
- + _contents["synt1"]["mean number of mutations per 7D model"]["parameters"]["b"]
455
- ),
456
- _contents["synt1"]["mean number of mutations per 7D model"]["r2"],
457
- ],
458
- "var": [
459
- lambda x: (
460
- _contents["synt1"]["scaled var number of mutations per 7D model"]["parameters"]["m"]*x
461
- ),
462
- _contents["synt1"]["scaled var number of mutations per 7D model"]["r2"],
463
- ]
604
+ "mean": list(_get_mean_model(_contents["synt1"], "synt1")),
605
+ "var": list(_get_var_model(_contents["synt1"], "synt1"))
464
606
  },
465
607
  "synt2": {
466
- "mean": [
467
- lambda x: (
468
- _contents["synt2"]["mean number of mutations per 7D model"]["parameters"]["m"]*x
469
- + _contents["synt2"]["mean number of mutations per 7D model"]["parameters"]["b"]
470
- ),
471
- _contents["synt2"]["mean number of mutations per 7D model"]["r2"],
472
- ],
473
- "var": [
474
- lambda x: (
475
- _contents["synt2"]["scaled var number of mutations per 7D model"]["parameters"]["d"]
476
- *(x**_contents["synt2"]["scaled var number of mutations per 7D model"]["parameters"]["alpha"])
477
- ),
478
- _contents["synt2"]["scaled var number of mutations per 7D model"]["r2"],
479
- ]
608
+ "mean": list(_get_mean_model(_contents["synt2"], "synt2")),
609
+ "var": list(_get_var_model(_contents["synt2"], "synt2"))
480
610
  },
481
611
  }
482
612
 
@@ -513,6 +643,35 @@ def synthetic_data_plot(df: pd.DataFrame, models: dict[str, any], export: bool =
513
643
  linewidth=3,
514
644
  zorder=1,
515
645
  )
646
+
647
+ # Plot confidence intervals if available and enabled
648
+ if PLOT_CONFIDENCE_INTERVALS and len(models[_type.lower()][case]) > 2 and models[_type.lower()][case][2]:
649
+ confidence_intervals = models[_type.lower()][case][2]
650
+
651
+ # Determine model type for confidence interval calculation
652
+ if case == "mean":
653
+ model_type = "linear_mean"
654
+ else: # case == "var"
655
+ # Check if it's linear or power law based on the model function
656
+ if "alpha" in confidence_intervals:
657
+ model_type = "power_law"
658
+ else:
659
+ model_type = "linear_var"
660
+
661
+ lower_bounds, upper_bounds = _calculate_confidence_bounds(
662
+ _x, models[_type.lower()][case][0], confidence_intervals, model_type
663
+ )
664
+
665
+ if lower_bounds is not None and upper_bounds is not None:
666
+ # Plot confidence interval as filled area
667
+ ax[plot_idx].fill_between(
668
+ _x,
669
+ lower_bounds,
670
+ upper_bounds,
671
+ color="#76d6ff",
672
+ alpha=0.2,
673
+ zorder=0,
674
+ )
516
675
 
517
676
  # Styling
518
677
  ax[plot_idx].set_xlim(-0.5, 40.5)
@@ -589,17 +748,19 @@ def load_additional_uk_models() -> dict[str, dict[str, callable]]:
589
748
  k: {
590
749
  "mean": [
591
750
  {
592
- "m": _contents[k][f"mean number of mutations per {k} model"]["parameters"]["m"],
593
- "b": _contents[k][f"mean number of mutations per {k} model"]["parameters"]["b"]
751
+ "m": _contents[k]["mean number of mutations model"]["parameters"]["m"],
752
+ "b": _contents[k]["mean number of mutations model"]["parameters"]["b"]
594
753
  },
595
- _contents[k][f"mean number of mutations per {k} model"]["r2"]
754
+ _contents[k]["mean number of mutations model"]["r2"],
755
+ _contents[k]["mean number of mutations model"]["confidence_intervals"]
596
756
  ],
597
757
  "var": [
598
758
  {
599
- "d": _contents[k][f"scaled var number of mutations per {k} model"]["parameters"]["d"],
600
- "alpha": _contents[k][f"scaled var number of mutations per {k} model"]["parameters"]["alpha"]
759
+ "d": _contents[k]["scaled var number of mutations model"]["power_law_model"]["parameters"]["d"],
760
+ "alpha": _contents[k]["scaled var number of mutations model"]["power_law_model"]["parameters"]["alpha"]
601
761
  },
602
- _contents[k][f"scaled var number of mutations per {k} model"]["r2"],
762
+ _contents[k]["scaled var number of mutations model"]["power_law_model"]["r2"],
763
+ _contents[k]["scaled var number of mutations model"]["power_law_model"]["confidence_intervals"]
603
764
  ]
604
765
  }
605
766
  for k in _files.keys()
@@ -624,18 +785,18 @@ def plot_uk_time_windows(stats: dict[str, pd.DataFrame], models: dict[str, dict[
624
785
  for idx, window in enumerate(windows):
625
786
  df = stats[window]
626
787
  model = models[window]
627
- scaling = {
628
- "5D": 5/7,
788
+ scaling = { # For the models to be comparable to the 7D model, we need to scale the x-axis by the square of the time window ratio
789
+ "5D": (5/7)**2,
629
790
  "7D": 1,
630
- "10D": 10/7,
631
- "14D": 14/7,
791
+ "10D": (10/7)**2,
792
+ "14D": (14/7)**2,
632
793
  }
633
794
  for idx2, case in enumerate(("mean", "var")):
634
795
 
635
796
  if case == "mean":
636
797
  # Plot mean
637
798
  ax[idx2, idx].scatter(
638
- df.index.to_numpy()*scaling[window],
799
+ df["dt_idx"],
639
800
  df["mean number of mutations"],
640
801
  color=COLORS["UK"],
641
802
  edgecolor="k",
@@ -645,17 +806,40 @@ def plot_uk_time_windows(stats: dict[str, pd.DataFrame], models: dict[str, dict[
645
806
  _x = np.arange(-0.5, 51, 0.5)
646
807
  ax[idx2, idx].plot(
647
808
  _x,
648
- model["mean"][0]["m"]*(_x/scaling[window]) + model["mean"][0]["b"],
809
+ model["mean"][0]["m"]*_x + model["mean"][0]["b"],
649
810
  color=COLORS["UK"],
650
811
  label=rf"Mean ($R^2 = {round(model['mean'][1], 2):.2f})$",
651
812
  linewidth=3,
652
813
  zorder=1,
653
814
  )
815
+
816
+ # Plot confidence intervals if available and enabled
817
+ if PLOT_CONFIDENCE_INTERVALS and len(model["mean"]) > 2 and model["mean"][2]:
818
+ confidence_intervals = model["mean"][2]
819
+
820
+ # For mean, it's always linear model
821
+ model_type = "linear_mean"
822
+
823
+ # Calculate confidence bounds
824
+ lower_bounds, upper_bounds = _calculate_confidence_bounds(
825
+ _x, model["mean"][0], confidence_intervals, model_type
826
+ )
827
+
828
+ if lower_bounds is not None and upper_bounds is not None:
829
+ # Plot confidence interval as filled area
830
+ ax[idx2, idx].fill_between(
831
+ _x,
832
+ lower_bounds,
833
+ upper_bounds,
834
+ color=COLORS["UK"],
835
+ alpha=0.2,
836
+ zorder=0,
837
+ )
654
838
 
655
839
  elif case == "var":
656
840
  # Plot variance
657
841
  ax[idx2, idx].scatter(
658
- df.index.to_numpy()*scaling[window],
842
+ df["dt_idx"],
659
843
  df["var number of mutations"] - df["var number of mutations"].min(),
660
844
  color=COLORS["UK"],
661
845
  edgecolor="k",
@@ -664,12 +848,36 @@ def plot_uk_time_windows(stats: dict[str, pd.DataFrame], models: dict[str, dict[
664
848
 
665
849
  ax[idx2, idx].plot(
666
850
  _x,
667
- model["var"][0]["d"]*(_x/scaling[window])**model["var"][0]["alpha"],
851
+ model["var"][0]["d"]*(_x*scaling[window])**model["var"][0]["alpha"],
668
852
  color=COLORS["UK"],
669
853
  label=rf"Var ($R^2 = {round(model['var'][1], 2):.2f})$",
670
854
  linewidth=3,
671
855
  zorder=1,
672
856
  )
857
+
858
+ # Plot confidence intervals if available and enabled
859
+ if PLOT_CONFIDENCE_INTERVALS and len(model["var"]) > 2 and model["var"][2]:
860
+ confidence_intervals = model["var"][2]
861
+
862
+ # For variance, it's always power law model
863
+ model_type = "power_law"
864
+
865
+ # Calculate confidence bounds for the scaled x values
866
+ scaled_x = _x * scaling[window]
867
+ lower_bounds, upper_bounds = _calculate_confidence_bounds(
868
+ scaled_x, model["var"][0], confidence_intervals, model_type
869
+ )
870
+
871
+ if lower_bounds is not None and upper_bounds is not None:
872
+ # Plot confidence interval as filled area
873
+ ax[idx2, idx].fill_between(
874
+ _x,
875
+ lower_bounds,
876
+ upper_bounds,
877
+ color=COLORS["UK"],
878
+ alpha=0.2,
879
+ zorder=0,
880
+ )
673
881
 
674
882
  # Styling
675
883
  ax[idx2, idx].set_xlim(-0.5, 40.5)
@@ -698,6 +906,165 @@ def plot_uk_time_windows(stats: dict[str, pd.DataFrame], models: dict[str, dict[
698
906
  if show:
699
907
  plt.show()
700
908
 
909
def load_model_selection_results(directory: str) -> list[dict]:
    """Load all regression results from a directory for model selection analysis.

    This function recursively walks through the directory tree to find all
    files whose name ends in ``out_regression_results.json``, supporting both
    flat and nested directory structures. Directories and files are visited
    in sorted order so the result order is reproducible.

    Expected structure:
    directory/
    ├── {timestamp}/
    │   ├── {dataset_01}/
    │   │   └── *_out_regression_results.json
    │   ├── {dataset_02}/
    │   │   └── *_out_regression_results.json
    │   └── ...
    └── (or any nested structure)

    The model selection block is read from the scaled-variance model entry.
    Both the "substitutions" and the "mutations" key spellings are accepted;
    the first one that carries a ``model_selection`` dictionary wins. Files
    without any such block are still reported, with ``selected_model`` set to
    ``"unknown"`` and every metric set to ``None``.

    Files that cannot be read or parsed, or whose top-level JSON value is not
    an object, are reported on stdout and skipped.

    :param directory: Root directory to search for regression results
    :type directory: str
    :return: List of dictionaries containing model selection information
    :rtype: list[dict]
    """
    # Key spellings under which the variance model may be stored.
    candidate_keys = (
        "scaled var number of substitutions model",
        "scaled var number of mutations model",
        "scaled var number of substitutions per 7D model",
        "scaled var number of mutations per 7D model",
    )
    metric_names = (
        "linear_AIC",
        "power_law_AIC",
        "delta_AIC_linear",
        "delta_AIC_power_law",
        "akaike_weight_linear",
        "akaike_weight_power_law",
    )

    results = []

    for root, dirs, files in os.walk(directory):
        # Sorting in place makes os.walk descend in a deterministic order.
        dirs.sort()
        for file_name in sorted(files):
            if not file_name.endswith("out_regression_results.json"):
                continue

            file_path = os.path.join(root, file_name)
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    data = json.load(f)
            except (OSError, ValueError) as e:
                # ValueError covers json.JSONDecodeError and decoding errors.
                print(f"Error loading {file_path}: {e}")
                continue

            if not isinstance(data, dict):
                print(f"Error loading {file_path}: expected a JSON object, got {type(data).__name__}")
                continue

            # Extract the model selection info from the first key that has it
            model_selection = {}
            for key in candidate_keys:
                section = data.get(key)
                if isinstance(section, dict) and isinstance(section.get("model_selection"), dict):
                    model_selection = section["model_selection"]
                    break

            entry = {
                "file": file_path,
                "selected_model": model_selection.get("selected", "unknown"),
            }
            entry.update({name: model_selection.get(name) for name in metric_names})
            results.append(entry)

    return results
957
+
958
def create_confusion_matrix_plot(export: bool = False, show: bool = True) -> None:
    """Create a confusion matrix plot for model selection accuracy analysis.

    Analyzes regression results from test5 synthetic datasets to assess
    model selection accuracy. Works with organized directory structure:

    tests/data/test5/
    ├── linear/output/{timestamp}/
    │   ├── linear_01/
    │   │   └── linear_01_out_regression_results.json
    │   ├── linear_02/
    │   └── ...
    └── powerlaw/output/{timestamp}/
        ├── powerlaw_01/
        └── ...

    The heatmap has the actual (generating) model on the x-axis and the
    predicted (selected) model on the y-axis. Any result whose selection is
    not the expected one, including "unknown", is counted as a
    misclassification into the other class.

    :param export: Whether to export the figure to share/confusion_matrix_heatmap.pdf
    :type export: bool
    :param show: Whether to display the figure
    :type show: bool
    """
    set_matplotlib_global_params()

    # Define paths - searches recursively through all subdirectories
    base_path = "tests/data/test5"
    linear_dir = os.path.join(base_path, "linear", "output")
    powerlaw_dir = os.path.join(base_path, "powerlaw", "output")

    print("Loading model selection results...")

    # Load results from both directories
    linear_results = load_model_selection_results(linear_dir)
    powerlaw_results = load_model_selection_results(powerlaw_dir)

    print(f"Loaded {len(linear_results)} linear results")
    print(f"Loaded {len(powerlaw_results)} powerlaw results")

    # Analyze results
    linear_success = sum(1 for r in linear_results if r['selected_model'] == 'linear')
    linear_failure = len(linear_results) - linear_success

    powerlaw_success = sum(1 for r in powerlaw_results if r['selected_model'] == 'power_law')
    powerlaw_failure = len(powerlaw_results) - powerlaw_success

    # Create confusion matrix
    # imshow draws rows along the y-axis and columns along the x-axis, so to
    # match the axis labels below the layout must be:
    #   rows    -> predicted model (Linear, Power Law)
    #   columns -> actual model    (Linear, Power Law)
    # Linear datasets that were not selected as linear therefore belong in
    # (row=Power Law, col=Linear), and vice versa.
    confusion_matrix = np.array([
        [linear_success, powerlaw_failure],   # Predicted Linear:    [actual Linear, actual Power Law]
        [linear_failure, powerlaw_success],   # Predicted Power Law: [actual Linear, actual Power Law]
    ])

    # Create the plot
    fig, ax = plt.subplots(figsize=(8, 6))

    # Create custom colormap from white to UK blue
    colors = ['white', '#76d6ff']
    n_bins = 30
    cmap = LinearSegmentedColormap.from_list('custom', colors, N=n_bins)

    # Create heatmap
    im = ax.imshow(confusion_matrix, interpolation='nearest', cmap=cmap)

    # Add colorbar
    cbar = ax.figure.colorbar(im, ax=ax)
    cbar.set_label('Count', rotation=270, labelpad=20)

    # Set ticks and labels
    ax.set_xticks([0, 1])
    ax.set_yticks([0, 1])
    ax.set_xticklabels(['Linear', 'Power Law'])  # Actual model (x-axis)
    ax.set_yticklabels(['Linear', 'Power Law'])  # Predicted model (y-axis)

    # Add text annotations (ax.text takes x=column, y=row)
    thresh = confusion_matrix.max() / 2.
    for i in range(confusion_matrix.shape[0]):
        for j in range(confusion_matrix.shape[1]):
            ax.text(j, i, format(confusion_matrix[i, j], 'd'),
                    ha="center", va="center",
                    color="white" if confusion_matrix[i, j] > thresh else "black",
                    fontsize=16, fontweight='bold')

    # Labels and title
    ax.set_xlabel('Actual Model', fontsize=16)
    ax.set_ylabel('Predicted Model', fontsize=16)
    ax.set_title('Model Selection Confusion Matrix', fontsize=18, fontweight='bold')

    # Calculate and display accuracy (0 when no results were found)
    total_tests = len(linear_results) + len(powerlaw_results)
    total_successes = linear_success + powerlaw_success
    overall_accuracy = total_successes / total_tests if total_tests > 0 else 0

    # Add accuracy text
    ax.text(0.5, -0.25, f'Overall Accuracy: {overall_accuracy:.3f} ({total_successes}/{total_tests})',
            transform=ax.transAxes, ha='center', fontsize=14, fontweight='bold')

    plt.tight_layout()

    if export:
        fig.savefig(
            "share/confusion_matrix_heatmap.pdf",
            dpi=400,
            bbox_inches="tight",
        )
        print("Confusion matrix saved as share/confusion_matrix_heatmap.pdf")

    if show:
        plt.show()
1067
+
701
1068
  #´:°•.°+.*•´.*:˚.°*.˚•´.°:°•.°•.*•´.*:˚.°*.˚•´.°:°•.°+.*•´.*:#
702
1069
  # MAIN #
703
1070
  #.•°:°.´+˚.*°.˚:*.´•*.+°.•°:´*.´•*.•°.•°:°.´:•˚°.*°.˚:*.´+°.•#
@@ -733,22 +1100,25 @@ def main(export: bool = False) -> None:
733
1100
  # Main plot
734
1101
  plot_main_figure(df, models, export=export)
735
1102
 
736
- # Size plot
1103
+ # # Size plot
737
1104
  size_plot(df, export=export)
738
1105
 
739
- # Anomalous diffusion plot
1106
+ # # Anomalous diffusion plot
740
1107
  anomalous_diffusion_plot(export=export)
741
1108
 
742
- # Synthetic data plot
1109
+ # # Synthetic data plot
743
1110
  synth_df = load_synthetic_data_df()
744
1111
  synth_models = load_synthetic_data_models()
745
1112
  synthetic_data_plot(synth_df, synth_models, export=export)
746
1113
 
747
- # UK time windows plot
1114
+ # # UK time windows plot
748
1115
  additional_uk_stats = load_additional_uk_stats()
749
1116
  additional_uk_models = load_additional_uk_models()
750
1117
  plot_uk_time_windows(additional_uk_stats, additional_uk_models, export=export)
751
1118
 
1119
+ # # Confusion matrix plot
1120
+ create_confusion_matrix_plot(export=export)
1121
+
752
1122
 
753
1123
  if __name__ == "__main__":
754
1124