pythonflex 0.3.3__py3-none-any.whl → 0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pythonflex/analysis.py CHANGED
@@ -11,7 +11,6 @@ from pathlib import Path
11
11
  from art import tprint
12
12
  from bitarray import bitarray
13
13
  from joblib import Parallel, delayed, dump, load
14
- import matplotlib.pyplot as plt
15
14
  from numba import njit, prange
16
15
  import numpy as np
17
16
  import pandas as pd
@@ -20,8 +19,8 @@ from tqdm import tqdm
20
19
 
21
20
  # Local/application-specific imports
22
21
  from .logging_config import log
23
- from .preprocessing import filter_matrix_by_genes
24
- from .utils import dsave, dload, _sanitize
22
+ from .preprocessing import filter_matrix_by_genes, filter_duplicate_terms
23
+ from .utils import dsave, dload, _sanitize, normalize_analysis_genes
25
24
 
26
25
  import matplotlib as mpl
27
26
 
@@ -36,6 +35,8 @@ def deep_update(source, overrides):
36
35
 
37
36
  def initialize(config={}):
38
37
 
38
+ user_overrides = config if isinstance(config, dict) else {}
39
+
39
40
  default_config = {
40
41
  "min_genes_in_complex": 3,
41
42
  "min_genes_per_complex_analysis": 2,
@@ -43,7 +44,10 @@ def initialize(config={}):
43
44
  "gold_standard": "CORUM",
44
45
  "color_map": "RdYlBu",
45
46
  "jaccard": True,
46
- "use_common_genes": True,
47
+ # Which genes are used for analysis (drives used_genes intersection)
48
+ # - 'shared' : use genes common to all datasets (common_genes)
49
+ # - 'dataset_specific' : use genes present in each dataset individually
50
+ "analysis_genes": "shared",
47
51
  "plotting": {
48
52
  "save_plot": True,
49
53
  "show_plot": True,
@@ -55,6 +59,11 @@ def initialize(config={}):
55
59
  "drop_na": False,
56
60
  },
57
61
  "corr_function": "numpy",
62
+ "per_complex": {
63
+ "n_jobs": 4,
64
+ "chunk_size": 200,
65
+ "max_nbytes": "100M",
66
+ },
58
67
  "logging": { # Added: Default logging config
59
68
  "visible_levels": ["DONE"] # if needed #, "PROGRESS", "STARTED", "INFO"
60
69
  }
@@ -65,6 +74,26 @@ def initialize(config={}):
65
74
  config = deep_update(default_config, config)
66
75
  else:
67
76
  config = default_config
77
+
78
+ # Backward compatibility: if user provided legacy key but not the new one,
79
+ # map it to analysis_genes. (We must look at the original overrides, because
80
+ # defaults always include analysis_genes.)
81
+ analysis_genes_provided = (
82
+ isinstance(user_overrides, dict)
83
+ and "analysis_genes" in user_overrides
84
+ and user_overrides.get("analysis_genes") is not None
85
+ and str(user_overrides.get("analysis_genes")).strip() != ""
86
+ )
87
+ if (
88
+ isinstance(user_overrides, dict)
89
+ and "use_common_genes" in user_overrides
90
+ and not analysis_genes_provided
91
+ ):
92
+ config["analysis_genes"] = (
93
+ "shared" if bool(user_overrides.get("use_common_genes")) else "dataset_specific"
94
+ )
95
+
96
+ config["analysis_genes"] = normalize_analysis_genes(config.get("analysis_genes"))
68
97
 
69
98
  # Extract visible_levels from the merged config and set logging visibility immediately (before any logs)
70
99
  visible_levels = config.get("logging", {}).get("visible_levels", ["DONE"])
@@ -110,12 +139,15 @@ def update_matploblib_config(config=None, font_family="Arial", layout="single"):
110
139
  if config is None:
111
140
  config = {}
112
141
  # Fallback if chosen font missing
142
+ requested_font_family = font_family
113
143
  try:
114
144
  from matplotlib.font_manager import findfont, FontProperties
115
145
  findfont(FontProperties(family=font_family))
116
146
  except Exception:
117
147
  font_family = "Helvetica" # Nature prefers Helvetica if Arial unavailable
118
- print(f"Warning: '{font_family}' not found, falling back to 'Helvetica'.")
148
+ log.warning(
149
+ f"Font '{requested_font_family}' not found; falling back to '{font_family}'."
150
+ )
119
151
 
120
152
  # Figure size presets (Nature: single ≈ 89 mm, double ≈ 183 mm at 25.4 mm/inch)
121
153
  if isinstance(layout, tuple):
@@ -190,50 +222,114 @@ def update_matploblib_config(config=None, font_family="Arial", layout="single"):
190
222
  "svg.fonttype": "none",
191
223
  })
192
224
 
193
- def pra(dataset_name, matrix, is_corr=False):
194
- log.info(f"******************** {dataset_name} ********************")
195
- log.started(f"** Global Precision-Recall Analysis - {dataset_name} **")
225
+
226
+ def _sort_ascending_for_dataset(dataset_name):
227
+ sorting = dload("input", "sorting")
228
+ if not isinstance(sorting, dict):
229
+ return False
230
+ sort_order = str(sorting.get(dataset_name, "high")).strip().lower()
231
+ return sort_order == "low"
232
+
233
+
234
+ def prepare_terms_for_dataset(dataset_name, matrix):
235
+ """Prepare dataset-specific gold-standard terms and filtered matrix.
236
+
237
+ This computes:
238
+ - terms['used_genes'] as the intersection of terms['all_genes'] with either
239
+ shared genes (config['analysis_genes']=='shared') or the dataset genes
240
+ (config['analysis_genes']=='dataset_specific').
241
+ - genes_present_in_terms_<dataset_name>
242
+
243
+ Side effects:
244
+ - stores dataset-specific terms and genes list under:
245
+ dsave(..., 'common', f'terms_{dataset_name}')
246
+ dsave(..., 'common', f'genes_present_in_terms_{dataset_name}')
247
+
248
+ Returns:
249
+ (terms, genes_present, matrix_filtered)
250
+ """
196
251
  config = dload("config")
197
- use_common_genes = config.get("use_common_genes", True)
252
+ if config is None:
253
+ raise RuntimeError(
254
+ "prepare_terms_for_dataset(): config not found. Run initialize() first."
255
+ )
198
256
 
199
257
  terms_data = dload("common", "terms")
200
258
  if terms_data is None or not isinstance(terms_data, pd.DataFrame):
201
- raise ValueError("Expected 'terms' to be a DataFrame, but got None or invalid type.")
259
+ raise ValueError(
260
+ "prepare_terms_for_dataset(): expected 'terms' to be a DataFrame, but got None or invalid type. "
261
+ "Make sure to run load_gold_standard() first."
262
+ )
202
263
  terms = terms_data.copy()
203
- sorting = dload("input", "sorting")
204
- sort_order = sorting.get(dataset_name, "high")
205
-
206
- if not is_corr:
207
- matrix = perform_corr(matrix, config.get("corr_function"))
208
264
 
209
- # Apply per-dataset gene filtering based on use_common_genes setting
210
- if use_common_genes:
211
- # Use common genes approach (current behavior)
212
- common_genes = dload("common", "common_genes")
213
- if not common_genes:
214
- raise ValueError("Common genes not found.")
215
- common_genes_set = set(common_genes)
216
- terms["used_genes"] = terms["all_genes"].apply(lambda x: list(set(x) & common_genes_set))
217
- log.info(f"Using common genes approach: {len(common_genes)} genes")
218
- else:
219
- # Use per-dataset approach (new behavior)
220
- dataset_genes_set = set(matrix.index)
221
- terms["used_genes"] = terms["all_genes"].apply(lambda x: list(set(x) & dataset_genes_set))
222
- log.info(f"Using per-dataset approach for {dataset_name}: {len(dataset_genes_set)} genes in dataset")
265
+ analysis_genes = normalize_analysis_genes(config.get("analysis_genes"))
266
+
267
+ if analysis_genes == "shared":
268
+ common_genes = dload("common", "common_genes")
269
+ if common_genes is None:
270
+ raise ValueError(
271
+ "prepare_terms_for_dataset(): common genes not found. "
272
+ "Run get_common_genes() or set analysis_genes='dataset_specific'."
273
+ )
274
+
275
+ common_genes_list = list(common_genes)
276
+ if len(common_genes_list) == 0:
277
+ raise ValueError(
278
+ "prepare_terms_for_dataset(): common genes is empty. "
279
+ "Run get_common_genes() or set analysis_genes='dataset_specific'."
280
+ )
281
+
282
+ gene_universe = set(common_genes_list)
283
+ log.info(f"Using shared genes approach: {len(gene_universe)} genes")
284
+ else:
285
+ gene_universe = set(matrix.index)
286
+ log.info(
287
+ f"Using dataset-specific approach for {dataset_name}: {len(gene_universe)} genes in dataset"
288
+ )
289
+
290
+ terms["used_genes"] = terms["all_genes"].apply(
291
+ lambda genes: list(set(genes) & gene_universe)
292
+ )
223
293
 
224
- # Filter terms by minimum genes after dataset-specific filtering
294
+ min_genes_raw = config.get("min_genes_in_complex", 3)
295
+ min_genes = int(min_genes_raw) if min_genes_raw is not None else 3
225
296
  terms["n_used_genes"] = terms["used_genes"].apply(len)
226
- terms = terms[terms["n_used_genes"] >= config['min_genes_in_complex']]
227
-
228
- # Get genes present in terms for this specific dataset
229
- genes_present = list(set([gene for genes_list in terms["used_genes"] for gene in genes_list]))
297
+ terms = terms[terms["n_used_genes"] >= min_genes]
298
+
299
+ if bool(config.get("jaccard", False)):
300
+ before = len(terms)
301
+ terms = filter_duplicate_terms(terms)
302
+ log.done(
303
+ f"After Jaccard duplicate used_genes filtering for {dataset_name}: "
304
+ f"{len(terms)} terms ({before - len(terms)} removed)"
305
+ )
306
+
307
+ genes_present = list(
308
+ set(gene for genes_list in terms["used_genes"] for gene in genes_list)
309
+ )
230
310
  log.info(f"Genes present in terms for {dataset_name}: {len(genes_present)}")
231
-
232
- matrix = filter_matrix_by_genes(matrix, genes_present)
311
+
312
+ matrix_filtered = filter_matrix_by_genes(matrix, genes_present)
313
+
314
+ dsave(terms, "common", f"terms_{dataset_name}")
315
+ dsave(genes_present, "common", f"genes_present_in_terms_{dataset_name}")
316
+
317
+ return terms, genes_present, matrix_filtered
318
+
319
+ def pra(dataset_name, matrix, is_corr=False):
320
+ log.info(f"******************** {dataset_name} ********************")
321
+ log.started(f"** Global Precision-Recall Analysis - {dataset_name} **")
322
+ config = dload("config")
323
+ ascending = _sort_ascending_for_dataset(dataset_name)
324
+
325
+ if not is_corr:
326
+ matrix = perform_corr(matrix, config.get("corr_function"))
327
+
328
+ terms, _genes_present, matrix = prepare_terms_for_dataset(dataset_name, matrix)
233
329
  log.info(f"Matrix shape: {matrix.shape}")
234
330
  df = binary(matrix)
235
331
  log.info(f"Pair-wise shape: {df.shape}")
236
- df = quick_sort(df, ascending=(sort_order == "low"))
332
+ df = quick_sort(df, ascending=ascending)
237
333
 
238
334
  log.started("Building gene-to-pair indices")
239
335
  gold_pair_to_complex = _build_gold_pair_to_complex(terms)
@@ -251,23 +347,22 @@ def pra(dataset_name, matrix, is_corr=False):
251
347
  if df["prediction"].sum() == 0:
252
348
  log.info("No true positives found in dataset.")
253
349
  pr_auc = np.nan
350
+ df["tp"] = 0
351
+ df["precision"] = np.nan
352
+ df["recall"] = np.nan
254
353
  else:
255
354
  tp = df["prediction"].cumsum()
256
355
  df["tp"] = tp
257
356
  precision = tp / (np.arange(len(df)) + 1)
258
357
  recall = tp / tp.iloc[-1]
259
- pr_auc = metrics.auc(recall, precision)
260
358
  df["precision"] = precision
261
359
  df["recall"] = recall
360
+ pr_auc = metrics.auc(recall, precision) if len(recall) >= 2 else np.nan
262
361
 
263
362
  log.info(f"PR-AUC: {pr_auc:.4f}, Number of true positives: {df['prediction'].sum()}")
264
363
  dsave(df, "pra", dataset_name)
265
364
  dsave(pr_auc, "pr_auc", dataset_name)
266
- dsave( _corrected_auc(df) , "corrected_pr_auc", dataset_name)
267
-
268
- # Save dataset-specific terms for per-complex analysis
269
- dsave(terms, "common", f"terms_{dataset_name}")
270
- dsave(genes_present, "common", f"genes_present_in_terms_{dataset_name}")
365
+ dsave(_corrected_auc(df), "corrected_pr_auc", dataset_name)
271
366
 
272
367
  log.done(f"Global PRA completed for {dataset_name}")
273
368
  return df
@@ -277,7 +372,12 @@ def pra(dataset_name, matrix, is_corr=False):
277
372
  # --------------------------------------------------------------------------
278
373
 
279
374
  def _corrected_auc(df: pd.DataFrame) -> float:
280
- return np.trapz(df["precision"], df["recall"]) - df["precision"].iloc[-1]
375
+ if df.empty or "precision" not in df.columns or "recall" not in df.columns:
376
+ return np.nan
377
+ valid = df[["precision", "recall"]].replace([np.inf, -np.inf], np.nan).dropna()
378
+ if len(valid) < 2:
379
+ return np.nan
380
+ return np.trapz(valid["precision"], valid["recall"]) - valid["precision"].iloc[-1]
281
381
 
282
382
  def _build_gene_to_pair_indices(pairwise_df):
283
383
  indices = pairwise_df.index.values
@@ -394,28 +494,52 @@ def _process_chunk(chunk_terms, min_genes, memmap_path, gene_to_pair_indices):
394
494
  # Return error info for debugging
395
495
  return {'error': str(e), 'chunk_size': len(chunk_terms)}
396
496
 
397
- def pra_percomplex(dataset_name, matrix, is_corr=False, chunk_size=200):
497
+ def pra_percomplex(dataset_name, matrix, is_corr=False, chunk_size=None, n_jobs=None):
398
498
  log.started(f"*** Per-complex PRA started - {dataset_name} ***")
399
499
  config = dload("config")
400
-
401
- # Use dataset-specific terms and genes from pra function
402
- terms = dload("common", f"terms_{dataset_name}")
403
- genes_present = dload("common", f"genes_present_in_terms_{dataset_name}")
404
-
405
- if terms is None:
406
- log.warning(f"No dataset-specific terms found for {dataset_name}, using global terms")
407
- terms = dload("common", "terms")
408
- genes_present = dload("common", "genes_present_in_terms")
409
-
410
- sorting = dload("input", "sorting")
411
- sort_order = sorting.get(dataset_name, "high")
500
+ ascending = _sort_ascending_for_dataset(dataset_name)
501
+ per_complex_config = config.get("per_complex", {})
502
+ if not isinstance(per_complex_config, dict):
503
+ per_complex_config = {}
504
+ chunk_size_value = (
505
+ chunk_size if chunk_size is not None else per_complex_config.get("chunk_size", 200)
506
+ )
507
+ n_jobs_value = n_jobs if n_jobs is not None else per_complex_config.get("n_jobs", 4)
508
+ max_nbytes = per_complex_config.get("max_nbytes", "100M")
509
+
510
+ try:
511
+ effective_chunk_size = int(chunk_size_value)
512
+ effective_n_jobs = int(n_jobs_value)
513
+ except (TypeError, ValueError) as exc:
514
+ raise ValueError(
515
+ "per-complex chunk_size and n_jobs must be integer-compatible values."
516
+ ) from exc
517
+
518
+ if effective_chunk_size <= 0:
519
+ raise ValueError("per-complex chunk_size must be greater than 0.")
520
+ if effective_n_jobs <= 0:
521
+ raise ValueError("per-complex n_jobs must be greater than 0.")
522
+
412
523
  if not is_corr:
413
524
  matrix = perform_corr(matrix, config.get("corr_function"))
414
- matrix = filter_matrix_by_genes(matrix, genes_present)
525
+
526
+ # Prefer terms prepared by pra(); if absent, prepare them here so direct
527
+ # pra_percomplex() calls use the same dataset-specific gene universe.
528
+ terms = dload("common", f"terms_{dataset_name}")
529
+ genes_present = dload("common", f"genes_present_in_terms_{dataset_name}")
530
+
531
+ if not isinstance(terms, pd.DataFrame) or genes_present is None:
532
+ log.warning(
533
+ f"No dataset-specific terms found for {dataset_name}; preparing them now."
534
+ )
535
+ terms, genes_present, matrix = prepare_terms_for_dataset(dataset_name, matrix)
536
+ else:
537
+ matrix = filter_matrix_by_genes(matrix, genes_present)
538
+
415
539
  log.info(f"Matrix shape: {matrix.shape}")
416
540
  df = binary(matrix)
417
541
  log.info(f"Pair-wise shape: {df.shape}")
418
- df = quick_sort(df, ascending=(sort_order == "low"))
542
+ df = quick_sort(df, ascending=ascending)
419
543
  pairwise_df = df.copy()
420
544
  pairwise_df['gene1'] = pairwise_df['gene1'].astype("category")
421
545
  pairwise_df['gene2'] = pairwise_df['gene2'].astype("category")
@@ -433,26 +557,42 @@ def pra_percomplex(dataset_name, matrix, is_corr=False, chunk_size=200):
433
557
  pairwise_df = _precompute_complex_ids(pairwise_df, gold_pair_to_complex)
434
558
  log.done("Precomputing complex IDs") #
435
559
 
560
+ chunks = [
561
+ terms.iloc[i:i + effective_chunk_size]
562
+ for i in range(0, len(terms), effective_chunk_size)
563
+ ]
564
+ min_genes = config["min_genes_per_complex_analysis"]
565
+
566
+ if not chunks:
567
+ terms["auc_score"] = pd.Series(dtype=float)
568
+ terms["corrected_auc_score"] = pd.Series(dtype=float)
569
+ dsave(terms, "pra_percomplex", dataset_name)
570
+ log.done("Per-complex PRA completed with no eligible terms.")
571
+ return terms
572
+
436
573
  log.info('Dumping pairwise_df to memmap')
437
574
  memmap_path = _dump_pairwise_memmap(pairwise_df, dataset_name)
438
575
  log.done('Dumping pairwise_df to memmap')
439
576
 
440
- # choose smaller chunks now that pickling cost is gone
441
- chunks = [terms.iloc[i:i+chunk_size] for i in range(0, len(terms), chunk_size)]
442
- min_genes = config["min_genes_per_complex_analysis"]
443
-
444
577
  # Initialize results variable
445
578
  results = None
446
579
 
447
580
  try:
448
581
  # Compatible parallel execution for older joblib versions
449
582
  log.started("Processing chunks in parallel")
450
-
583
+ actual_n_jobs = min(effective_n_jobs, len(chunks))
584
+ log.info(
585
+ "Per-complex parallel config: "
586
+ f"n_jobs={actual_n_jobs}, requested_n_jobs={effective_n_jobs}, "
587
+ f"chunk_size={effective_chunk_size}, chunks={len(chunks)}, "
588
+ f"max_nbytes={max_nbytes}"
589
+ )
590
+
451
591
  # Use a more conservative approach with older joblib
452
592
  results = Parallel(
453
- n_jobs=min(4, len(chunks)), # Limit to 4 workers or number of chunks
593
+ n_jobs=actual_n_jobs,
454
594
  temp_folder=os.path.dirname(memmap_path),
455
- max_nbytes='100M', # Set memory limit
595
+ max_nbytes=max_nbytes,
456
596
  verbose=1 # Show progress
457
597
  )(delayed(_process_chunk)(chunk, min_genes, memmap_path, gene_to_pair_indices)
458
598
  for chunk in tqdm(chunks, desc="Per-complex PRA"))
@@ -483,11 +623,13 @@ def pra_percomplex(dataset_name, matrix, is_corr=False, chunk_size=200):
483
623
  # Merge results with enhanced error handling
484
624
  auc_scores = {}
485
625
  corrected_auc_scores = {}
626
+ errors = []
486
627
  if results:
487
628
  for i, res in enumerate(results):
488
629
  if isinstance(res, dict):
489
630
  if 'error' in res:
490
631
  log.error(f"Error in chunk {i}: {res['error']}")
632
+ errors.append(f"chunk {i}: {res['error']}")
491
633
  elif 'auc' in res and 'corrected_auc' in res:
492
634
  # New format with both AUC types
493
635
  auc_scores.update(res['auc'])
@@ -497,8 +639,15 @@ def pra_percomplex(dataset_name, matrix, is_corr=False, chunk_size=200):
497
639
  auc_scores.update(res)
498
640
  elif isinstance(res, tuple) and len(res) >= 2 and res[0] is None:
499
641
  log.error(f"Chunk {i} error: {res[1]}")
642
+ errors.append(f"chunk {i}: {res[1]}")
500
643
  else:
501
644
  log.warning(f"Unexpected result type from chunk {i}: {type(res)} - {res}")
645
+ errors.append(f"chunk {i}: unexpected result type {type(res)}")
646
+
647
+ if errors:
648
+ preview = "; ".join(errors[:3])
649
+ extra = f" ({len(errors) - 3} more)" if len(errors) > 3 else ""
650
+ raise RuntimeError(f"Per-complex PRA failed in worker chunks: {preview}{extra}")
502
651
 
503
652
  # Add the computed AUC scores to the terms DataFrame.
504
653
  terms["auc_score"] = pd.Series(auc_scores)
@@ -510,10 +659,20 @@ def pra_percomplex(dataset_name, matrix, is_corr=False, chunk_size=200):
510
659
  def complex_contributions(name):
511
660
  log.info(f"Computing complex contributions (Greedy) for dataset: {name}")
512
661
  pra = dload("pra", name)
513
- terms = dload("common", "terms")
662
+ terms = dload("common", f"terms_{name}")
663
+ if not isinstance(terms, pd.DataFrame):
664
+ # Fallback for backward compatibility
665
+ terms = dload("common", "terms")
666
+ if not isinstance(pra, pd.DataFrame) or pra.empty:
667
+ raise RuntimeError(f"complex_contributions(): PRA data for dataset '{name}' not found.")
668
+ if not isinstance(terms, pd.DataFrame) or terms.empty:
669
+ raise RuntimeError(f"complex_contributions(): terms for dataset '{name}' not found.")
514
670
 
515
- # Ensure pra is sorted by score descending (matches R's order by predicted descending)
516
- pra = pra.sort_values(by='score', ascending=False).reset_index(drop=True)
671
+ # Respect the dataset's score direction: high scores by default, low scores if configured.
672
+ pra = pra.sort_values(
673
+ by='score',
674
+ ascending=_sort_ascending_for_dataset(name),
675
+ ).reset_index(drop=True)
517
676
 
518
677
  # Compute cumulative TP and precision (matches R's TP.count = cumsum(true), Precision = TP / (1:n))
519
678
  pra['cumTP'] = pra['prediction'].cumsum()
@@ -808,6 +967,9 @@ def binary(corr):
808
967
 
809
968
  stack = corr.stack().rename_axis(index=['gene1', 'gene2']).\
810
969
  reset_index().rename(columns={0: 'score'})
970
+ if stack.empty:
971
+ log.done("Pair-wise conversion.")
972
+ return stack
811
973
  if has_mirror_of_first_pair(stack):
812
974
  log.info("Mirror pairs detected. Dropping them to ensure unique gene pairs.")
813
975
  stack = drop_mirror_pairs(stack)
@@ -858,7 +1020,7 @@ def save_results_to_csv(categories = ["complex_contributions", "pr_auc", "pra_pe
858
1020
  continue
859
1021
 
860
1022
  if category == "mpr_complexes_auc" and isinstance(data, dict):
861
- # Dict[dataset_name -> Dict[filter_key -> auc]]
1023
+ # Dict[dataset_name -> Dict[variant_key -> auc]]
862
1024
  try:
863
1025
  df = pd.DataFrame.from_dict(data, orient="index")
864
1026
  df.index.name = "Dataset"
@@ -901,21 +1063,10 @@ def save_results_to_csv(categories = ["complex_contributions", "pr_auc", "pra_pe
901
1063
 
902
1064
  log.done("Results saved to CSV files in the output folder.")
903
1065
 
904
- ################### mPR
905
- ################### mPR ###################
906
-
907
-
908
-
909
- ################### mPR ###################
910
- ################### mPR ###################
911
-
912
1066
  # -----------------------------------------------------------------------------
913
1067
  # mPR preparation (module-level precision–recall, Fig. 1E / 1F)
914
1068
  # -----------------------------------------------------------------------------
915
1069
 
916
- import numpy as np
917
- import pandas as pd
918
-
919
1070
 
920
1071
  def _mpr_get_mtRibo_ETCI_ids(terms_like):
921
1072
  """
@@ -970,34 +1121,16 @@ def _mpr_get_small_high_auprc_ids(
970
1121
  # Helpers implementing the FLEX stepwise module-level PR logic
971
1122
  # -------------------------------------------------------------------------
972
1123
 
973
- """
974
- CORRECT FIX for _mpr_build_pairs in analysis.py
975
-
976
- The issue: The current code marks filtered TPs as true=0, which makes them
977
- count as False Positives and dramatically lowers precision.
978
-
979
- The R code (getSubsetOfCoAnnRemoveIDs with replace=FALSE) REMOVES the
980
- filtered positive pairs entirely from the dataset.
981
-
982
- This is the key difference:
983
- - Current Python: Keeps all rows, filtered TPs become FPs → precision tanks
984
- - R code: Removes filtered TP rows → they don't affect precision at all
985
- """
986
-
987
- import numpy as np
988
- import pandas as pd
989
-
990
-
991
- def _mpr_build_pairs(pra, removed_ids=None):
1124
+ def _mpr_build_pairs(pra, removed_ids=None, ascending=False):
992
1125
  """
993
1126
  Build a Pairs.in.data-like table for mPR / stepwise contributions.
994
-
995
- FIXED: Removes rows that contain filtered complex IDs (matching R behavior)
996
- instead of marking them as true=0.
1127
+
1128
+ Rows containing filtered positive complex IDs are removed from the ranking,
1129
+ matching the FLEX stepwise module-level precision-recall behavior.
997
1130
 
998
1131
  Input:
999
1132
  pra : DataFrame with at least columns
1000
- - 'score' : ranking score (higher = better)
1133
+ - 'score' : ranking score
1001
1134
  - 'complex_id' : complex annotations
1002
1135
  removed_ids : set[int] of complexes to remove
1003
1136
 
@@ -1056,11 +1189,9 @@ def _mpr_build_pairs(pra, removed_ids=None):
1056
1189
  out["complex_ids"] = df[cid_col].apply(normalize_ids)
1057
1190
  out["true"] = out["complex_ids"].apply(lambda ids: 1 if len(ids) > 0 else 0)
1058
1191
 
1059
- # KEY FIX: Remove rows that are TPs AND contain a removed complex ID
1060
- # This matches the R behavior of getSubsetOfCoAnnRemoveIDs with replace=FALSE
1192
+ # Remove rows that are TPs and contain a removed complex ID.
1061
1193
  if removed_ids:
1062
1194
  should_remove_mask = df[cid_col].apply(should_remove)
1063
- # Only remove if it's a TP (true=1)
1064
1195
  remove_mask = should_remove_mask & (out["true"] == 1)
1065
1196
  out = out[~remove_mask].copy()
1066
1197
 
@@ -1070,41 +1201,11 @@ def _mpr_build_pairs(pra, removed_ids=None):
1070
1201
  lambda ids: [cid for cid in ids if cid not in removed_ids]
1071
1202
  )
1072
1203
 
1073
- # Sort by predicted descending
1074
- out = out.sort_values("predicted", ascending=False).reset_index(drop=True)
1204
+ # Sort by the dataset's configured score direction.
1205
+ out = out.sort_values("predicted", ascending=ascending).reset_index(drop=True)
1075
1206
  return out
1076
1207
 
1077
1208
 
1078
- # ============================================================================
1079
- # HOW TO APPLY THIS FIX
1080
- # ============================================================================
1081
- #
1082
- # In analysis.py, replace the _mpr_build_pairs function (around line 962-1025)
1083
- # with the _mpr_build_pairs_fixed function above.
1084
- #
1085
- # The key changes are:
1086
- #
1087
- # 1. REMOVE rows instead of marking true=0:
1088
- #
1089
- # OLD CODE:
1090
- # if removed_ids:
1091
- # ids = [cid for cid in ids if cid not in removed_ids]
1092
- # return ids
1093
- # ...
1094
- # out["true"] = out["complex_ids"].apply(lambda ids: 1 if len(ids) > 0 else 0)
1095
- #
1096
- # NEW CODE:
1097
- # # First compute true normally
1098
- # out["true"] = out["complex_ids"].apply(lambda ids: 1 if len(ids) > 0 else 0)
1099
- #
1100
- # # Then REMOVE rows that are TPs and contain removed IDs
1101
- # if removed_ids:
1102
- # should_remove_mask = df[cid_col].apply(should_remove)
1103
- # remove_mask = should_remove_mask & (out["true"] == 1)
1104
- # out = out[~remove_mask].copy()
1105
- #
1106
- # ============================================================================
1107
-
1108
1209
  def _mpr_precision_cutoffs_from_pairs(pairs, step=0.025):
1109
1210
  """
1110
1211
  Choose precision cutoffs similar to FLEX:
@@ -1136,7 +1237,7 @@ def _mpr_precision_cutoffs_from_pairs(pairs, step=0.025):
1136
1237
  return np.array(cuts, dtype=float)
1137
1238
 
1138
1239
 
1139
- def _mpr_stepwise_contributions(pairs, precision_cutoffs):
1240
+ def _mpr_stepwise_contributions(pairs, precision_cutoffs, ascending=False):
1140
1241
  """
1141
1242
  Greedy, stepwise TP allocation per complex at each precision cutoff.
1142
1243
 
@@ -1151,7 +1252,7 @@ def _mpr_stepwise_contributions(pairs, precision_cutoffs):
1151
1252
  contrib_df : DataFrame [complex_id x cutoff] with TP counts
1152
1253
  """
1153
1254
  pairs = pairs.copy()
1154
- pairs = pairs.sort_values("predicted", ascending=False).reset_index(drop=True)
1255
+ pairs = pairs.sort_values("predicted", ascending=ascending).reset_index(drop=True)
1155
1256
 
1156
1257
  true = pairs["true"].to_numpy(dtype=int)
1157
1258
  n = len(true)
@@ -1271,16 +1372,29 @@ def _mpr_module_coverage(contrib_df, terms, tp_th=1, percent_th=0.1):
1271
1372
  row = terms.loc[cid_int]
1272
1373
 
1273
1374
  n_genes = None
1274
-
1275
- # FIXED: Handle all_genes as list (how it's stored in preprocessing)
1276
- if "all_genes" in row.index:
1375
+
1376
+ # Prefer used_genes (genes actually in the dataset) for a fair coverage
1377
+ # fraction. This matters for GOBP/PATHWAY where all_genes >> used_genes.
1378
+ if "used_genes" in row.index:
1379
+ genes = row["used_genes"]
1380
+ if isinstance(genes, (list, np.ndarray)) and len(genes) > 0:
1381
+ n_genes = len(genes)
1382
+ if n_genes is None and "n_used_genes" in row.index:
1383
+ try:
1384
+ v = int(row["n_used_genes"])
1385
+ if v > 0:
1386
+ n_genes = v
1387
+ except (ValueError, TypeError):
1388
+ pass
1389
+
1390
+ # Fallback: all_genes (how it's stored in preprocessing)
1391
+ if n_genes is None and "all_genes" in row.index:
1277
1392
  genes = row["all_genes"]
1278
- if isinstance(genes, list):
1393
+ if isinstance(genes, (list, np.ndarray)):
1279
1394
  n_genes = len(genes)
1280
1395
  elif isinstance(genes, str):
1281
- # Fallback if stored as string
1282
1396
  n_genes = len([g for g in genes.split(";") if g])
1283
-
1397
+
1284
1398
  # Fallback to Genes column (original string from CORUM)
1285
1399
  if n_genes is None and "Genes" in row.index:
1286
1400
  genes_str = row["Genes"]
@@ -1337,7 +1451,7 @@ def _mpr_complexes_auc(
1337
1451
 
1338
1452
  We compute a normalized AUC by integrating precision over the *normalized*
1339
1453
  coverage axis:
1340
- AUC = \int y \, d(x/max_complexes)
1454
+ AUC = integral y d(x/max_complexes)
1341
1455
 
1342
1456
  This yields a score in [0, 1] (or NaN if insufficient data).
1343
1457
  """
@@ -1347,7 +1461,7 @@ def _mpr_complexes_auc(
1347
1461
  if cov.size == 0 or prec.size == 0:
1348
1462
  return 0.0
1349
1463
 
1350
- # Match plot_mpr_complexes_multi(): only count cov>0 (log-x cannot show 0)
1464
+ # Match plot_mpr_complex_coverage_curve(): only count cov>0 (log-x cannot show 0)
1351
1465
  mask = (
1352
1466
  np.isfinite(cov)
1353
1467
  & np.isfinite(prec)
@@ -1409,26 +1523,31 @@ def mpr_prepare(
1409
1523
  """
1410
1524
  pra = dload("pra", name)
1411
1525
  pra_percomplex = dload("pra_percomplex", name)
1412
- terms = dload("common", "terms")
1526
+ terms = dload("common", f"terms_{name}")
1527
+ if not isinstance(terms, pd.DataFrame):
1528
+ # Fallback for backward compatibility
1529
+ terms = dload("common", "terms")
1413
1530
 
1414
- if pra is None:
1531
+ if pra is None or not isinstance(pra, pd.DataFrame) or pra.empty:
1415
1532
  raise RuntimeError(
1416
1533
  f"mpr_prepare(): PRA data for dataset '{name}' not found "
1417
1534
  "(dload('pra', name))."
1418
1535
  )
1419
- if pra_percomplex is None:
1536
+ if pra_percomplex is None or not isinstance(pra_percomplex, pd.DataFrame) or pra_percomplex.empty:
1420
1537
  raise RuntimeError(
1421
1538
  f"mpr_prepare(): per-complex PRA data for dataset '{name}' not found "
1422
1539
  "(dload('pra_percomplex', name))."
1423
1540
  )
1424
- if terms is None:
1541
+ if terms is None or not isinstance(terms, pd.DataFrame) or terms.empty:
1425
1542
  raise RuntimeError(
1426
1543
  "mpr_prepare(): CORUM 'terms' table not found (dload('common', 'terms'))."
1427
1544
  )
1428
1545
 
1429
- # sort by score descending (ranking)
1546
+ ascending = _sort_ascending_for_dataset(name)
1547
+
1548
+ # Sort by the dataset's configured score direction.
1430
1549
  if "score" in pra.columns:
1431
- pra = pra.sort_values("score", ascending=False).reset_index(drop=True)
1550
+ pra = pra.sort_values("score", ascending=ascending).reset_index(drop=True)
1432
1551
  else:
1433
1552
  pra = pra.reset_index(drop=True)
1434
1553
 
@@ -1454,7 +1573,7 @@ def mpr_prepare(
1454
1573
 
1455
1574
  for label, removed in filter_sets.items():
1456
1575
  # 1) Build pairs table after removing complexes in `removed`
1457
- pairs = _mpr_build_pairs(pra, removed_ids=removed)
1576
+ pairs = _mpr_build_pairs(pra, removed_ids=removed, ascending=ascending)
1458
1577
 
1459
1578
  true = pairs["true"].to_numpy(dtype=int)
1460
1579
  n = len(true)
@@ -1481,13 +1600,24 @@ def mpr_prepare(
1481
1600
  if precision_cutoffs is None:
1482
1601
  precision_cutoffs = _mpr_precision_cutoffs_from_pairs(pairs)
1483
1602
 
1484
- contrib_df = _mpr_stepwise_contributions(pairs, precision_cutoffs)
1603
+ contrib_df = _mpr_stepwise_contributions(
1604
+ pairs,
1605
+ precision_cutoffs,
1606
+ ascending=ascending,
1607
+ )
1485
1608
  cov = _mpr_module_coverage(
1486
1609
  contrib_df,
1487
1610
  terms,
1488
1611
  tp_th=tp_th,
1489
1612
  percent_th=percent_th,
1490
1613
  )
1614
+ # precision_cutoffs are sorted ascending (low → high).
1615
+ # Coverage must be non-increasing in that direction: a more permissive
1616
+ # threshold (lower precision) should never yield fewer covered terms.
1617
+ # The independent greedy allocation per cutoff can violate this, so
1618
+ # enforce monotonicity by propagating the max from right to left.
1619
+ if cov.size > 0:
1620
+ cov = np.maximum.accumulate(cov[::-1])[::-1]
1491
1621
  coverage_curves[label] = cov
1492
1622
  complexes_auc[label] = _mpr_complexes_auc(
1493
1623
  cov,
@@ -1515,424 +1645,3 @@ def mpr_prepare(
1515
1645
 
1516
1646
  # Convenience: store AUCs as their own category for easy export / plotting.
1517
1647
  dsave(complexes_auc, "mpr_complexes_auc", name)
1518
-
1519
-
1520
-
1521
- ### OLD FUNCTIONS
1522
-
1523
- # new but withoutparallel
1524
-
1525
- # def pra_percomplex(dataset_name, matrix, is_corr=False):
1526
- # log.started(f"*** Per-complex PRA started - {dataset_name} ***")
1527
- # config = dload("config")
1528
- # terms = dload("tmp", "terms")
1529
- # genes_present = dload("tmp", "genes_present_in_terms")
1530
- # sorting = dload("input", "sorting")
1531
- # sort_order = sorting.get(dataset_name, "high")
1532
- # if not is_corr:
1533
- # matrix = perform_corr(matrix, config.get("corr_function"))
1534
- # matrix = filter_matrix_by_genes(matrix, genes_present)
1535
- # log.info(f"Matrix shape: {matrix.shape}")
1536
- # df = binary(matrix)
1537
- # log.info(f"Pair-wise shape: {df.shape}")
1538
- # df = quick_sort(df, ascending=(sort_order == "low"))
1539
- # pairwise_df = df.copy()
1540
- # pairwise_df['gene1'] = pairwise_df['gene1'].astype("category")
1541
- # pairwise_df['gene2'] = pairwise_df['gene2'].astype("category")
1542
-
1543
- # # Precompute a mapping from each gene to the row indices in the pairwise DataFrame where it appears.
1544
- # gene_to_pair_indices = {}
1545
- # for i, (gene_a, gene_b) in enumerate(zip(pairwise_df["gene1"], pairwise_df["gene2"])):
1546
- # gene_to_pair_indices.setdefault(gene_a, []).append(i)
1547
- # gene_to_pair_indices.setdefault(gene_b, []).append(i)
1548
- # log.done
1549
-
1550
- # # Build gold_pair_to_complex using sets for efficiency
1551
- # gold_pair_to_complex = defaultdict(set)
1552
- # for idx, row in terms.iterrows():
1553
- # genes = row.used_genes
1554
- # if len(genes) < 2:
1555
- # continue
1556
- # for i, g1 in enumerate(genes):
1557
- # for g2 in genes[i + 1:]:
1558
- # pair = tuple(sorted((g1, g2)))
1559
- # gold_pair_to_complex[pair].add(idx)
1560
-
1561
- # # Precompute complex_ids as semicolon-separated strings in pairwise_df
1562
- # pairs = [tuple(sorted((g1, g2))) for g1, g2 in zip(pairwise_df["gene1"], pairwise_df["gene2"])]
1563
- # pairwise_df['complex_ids'] = [';'.join(map(str, sorted(gold_pair_to_complex.get(pair, set())))) for pair in pairs]
1564
-
1565
- # # Initialize AUC scores
1566
- # auc_scores = {}
1567
- # # Loop over each gene complex
1568
- # for idx, row in tqdm(terms.iterrows()):
1569
- # gene_set = set(row.used_genes)
1570
- # if config["min_genes_per_complex_analysis"] > len(gene_set):
1571
- # continue
1572
- # # Collect all row indices in the pairwise data where either gene belongs to the complex.
1573
- # candidate_indices = bitarray(len(pairwise_df))
1574
- # for gene in gene_set:
1575
- # if gene in gene_to_pair_indices:
1576
- # candidate_indices[gene_to_pair_indices[gene]] = True
1577
-
1578
- # if not candidate_indices.any():
1579
- # continue
1580
-
1581
- # # Select only the relevant pairwise comparisons.
1582
- # selected_rows = np.unpackbits(candidate_indices).view(bool)[:len(pairwise_df)]
1583
- # sub_df = pairwise_df.iloc[selected_rows]
1584
-
1585
- # # Get current complex ID (assuming idx is the ID; adjust if row['ID'] is different)
1586
- # complex_id = str(idx) # Or str(row['ID']) if available
1587
-
1588
- # # Create true_label: 1 if complex_id in complex_ids (vectorized with str.contains)
1589
- # #true_label = sub_df['complex_ids'].str.contains(complex_id, regex=False).astype(int)
1590
-
1591
- # # Inside the loop, for each complex:
1592
- # # Inside the loop:
1593
- # complex_id = str(idx)
1594
- # # Use (?:^|;) and (?:;|$) to avoid capturing groups
1595
- # pattern = r'(?:^|;)' + re.escape(complex_id) + r'(?:;|$)'
1596
- # true_label = sub_df['complex_ids'].str.contains(pattern, regex=True).astype(int)
1597
- # # Filter to keep verified negatives (complex_ids == "") or positives for this complex (true_label == 1)
1598
- # complex_mask = (sub_df['complex_ids'] == "") | (true_label == 1)
1599
-
1600
- # # Use the masked true labels for AUPRC (avoids SettingWithCopyWarning)
1601
- # predictions = true_label[complex_mask]
1602
-
1603
- # if predictions.sum() == 0:
1604
- # continue
1605
- # # Compute cumulative true positives and derive precision and recall.
1606
- # true_positive_cumsum = predictions.cumsum()
1607
- # precision = true_positive_cumsum / (np.arange(len(predictions)) + 1)
1608
- # recall = true_positive_cumsum / true_positive_cumsum.iloc[-1]
1609
-
1610
- # if len(recall) < 2 or recall.iloc[-1] == 0:
1611
- # continue
1612
- # auc_scores[idx] = metrics.auc(recall, precision)
1613
-
1614
- # # Add the computed AUC scores to the terms DataFrame.
1615
- # terms["auc_score"] = pd.Series(auc_scores)
1616
- # terms.drop(columns=["hash"], inplace=True)
1617
- # dsave(terms, "pra_percomplex", dataset_name)
1618
- # log.done(f"Per-complex PRA completed.")
1619
- # return terms
1620
-
1621
- # it works quick but only maps 1 complex to each pair
1622
-
1623
- # def pra_percomplex_old_type_filtering(dataset_name, matrix, is_corr=False):
1624
- # log.started(f"*** Per-complex PRA started - {dataset_name} ***")
1625
- # config = dload("config")
1626
- # terms = dload("tmp", "terms")
1627
- # genes_present = dload("tmp", "genes_present_in_terms")
1628
- # sorting = dload("input", "sorting")
1629
- # sort_order = sorting.get(dataset_name, "high")
1630
- # if not is_corr:
1631
- # matrix = perform_corr(matrix, config.get("corr_function"))
1632
- # matrix = filter_matrix_by_genes(matrix, genes_present)
1633
- # log.info(f"Matrix shape: {matrix.shape}")
1634
- # df = binary(matrix)
1635
- # log.info(f"Pair-wise shape: {df.shape}")
1636
- # df = quick_sort(df, ascending=(sort_order == "low"))
1637
- # pairwise_df = df.copy()
1638
- # pairwise_df['gene1'] = pairwise_df['gene1'].astype("category")
1639
- # pairwise_df['gene2'] = pairwise_df['gene2'].astype("category")
1640
- # # Precompute a mapping from each gene to the row indices in the pairwise DataFrame where it appears.
1641
- # gene_to_pair_indices = {}
1642
- # for i, (gene_a, gene_b) in enumerate(zip(pairwise_df["gene1"], pairwise_df["gene2"])):
1643
- # gene_to_pair_indices.setdefault(gene_a, []).append(i)
1644
- # gene_to_pair_indices.setdefault(gene_b, []).append(i)
1645
- # # Initialize AUC scores (one for each complex) with NaNs.
1646
- # #auc_scores = np.full(len(terms), np.nan)
1647
- # auc_scores = {}
1648
- # # Loop over each gene complex
1649
- # for idx, row in tqdm(terms.iterrows()):
1650
- # gene_set = set(row.used_genes)
1651
-
1652
- # if config["min_genes_per_complex_analysis"] > len(gene_set):
1653
- # continue
1654
- # # Collect all row indices in the pairwise data where either gene belongs to the complex.
1655
- # candidate_indices = bitarray(len(pairwise_df))
1656
- # for gene in gene_set:
1657
- # if gene in gene_to_pair_indices:
1658
- # candidate_indices[gene_to_pair_indices[gene]] = True
1659
- # if not candidate_indices.any():
1660
- # continue
1661
- # # Select only the relevant pairwise comparisons.
1662
- # selected_rows = np.unpackbits(candidate_indices).view(bool)[:len(pairwise_df)]
1663
- # sub_df = pairwise_df.iloc[selected_rows]
1664
- # # A prediction is 1 if both genes in the pair are in the complex; otherwise 0.
1665
- # predictions = (sub_df["gene1"].isin(gene_set) & sub_df["gene2"].isin(gene_set)).astype(int)
1666
- # if predictions.sum() == 0:
1667
- # continue
1668
- # # Compute cumulative true positives and derive precision and recall.
1669
- # true_positive_cumsum = predictions.cumsum()
1670
- # precision = true_positive_cumsum / (np.arange(len(predictions)) + 1)
1671
- # recall = true_positive_cumsum / true_positive_cumsum.iloc[-1]
1672
-
1673
- # if len(recall) < 2 or recall.iloc[-1] == 0:
1674
- # continue
1675
- # auc_scores[idx] = metrics.auc(recall, precision)
1676
- # # Add the computed AUC scores to the terms DataFrame.
1677
- # terms["auc_score"] = pd.Series(auc_scores)
1678
- # terms.drop(columns=["hash"], inplace=True)
1679
- # dsave(terms, "pra_percomplex", dataset_name)
1680
- # log.done(f"Per-complex PRA completed.")
1681
- # return terms
1682
-
1683
- # OLD
1684
- # def pra_percomplex(dataset_name, matrix, is_corr=False):
1685
- # log.started(f"*** Per-complex PRA started for {dataset_name} ***")
1686
- # config = dload("config")
1687
- # terms = dload("tmp", "terms")
1688
- # genes_present = dload("tmp", "genes_present_in_terms")
1689
- # sorting = dload("input", "sorting")
1690
- # sort_order = sorting.get(dataset_name, "high")
1691
-
1692
- # if not is_corr:
1693
- # matrix = perform_corr(matrix, "numpy")
1694
- # matrix = filter_matrix_by_genes(matrix, genes_present)
1695
- # log.info(f"Matrix shape: {matrix.shape}")
1696
- # df = binary(matrix)
1697
- # log.info(f"Pair-wise shape: {df.shape}")
1698
- # df = quick_sort(df, ascending=(sort_order == "low"))
1699
- # # Precompute gene → row indices
1700
- # gene_to_rows = {}
1701
- # for i, (g1, g2) in enumerate(zip(df["gene1"], df["gene2"])):
1702
- # gene_to_rows.setdefault(g1, []).append(i)
1703
- # gene_to_rows.setdefault(g2, []).append(i)
1704
- # aucs = np.full(len(terms), np.nan)
1705
- # N = len(df)
1706
- # for idx, row in tqdm(terms.iterrows()):
1707
- # genes = set(row.used_genes)
1708
- # if len(genes) < config["min_complex_size_for_percomplex"]: # Skip small complexes
1709
- # continue
1710
- # # Get all row indices where either gene is in the complex
1711
- # candidate_idxs = set()
1712
- # for g in genes:
1713
- # candidate_idxs.update(gene_to_rows.get(g, []))
1714
- # candidate_idxs = sorted(candidate_idxs)
1715
- # if not candidate_idxs:
1716
- # continue
1717
- # # Use only relevant rows for prediction
1718
- # sub = df.loc[candidate_idxs]
1719
- # preds = (sub["gene1"].isin(genes) & sub["gene2"].isin(genes)).astype(int)
1720
- # if preds.sum() == 0:
1721
- # continue
1722
- # tp = preds.cumsum()
1723
- # prec = tp / (np.arange(len(preds)) + 1)
1724
- # recall = tp / tp.iloc[-1]
1725
- # if len(recall) < 2 or recall.iloc[-1] == 0:
1726
- # continue
1727
- # aucs[idx] = metrics.auc(recall, prec)
1728
- # terms["auc_score"] = aucs
1729
- # terms.drop(columns=["list", "set", "hash"], inplace=True)
1730
- # dsave(terms, "pra_percomplex", dataset_name)
1731
- # log.done(f"Per-complex PRA completed.")
1732
- # return terms
1733
-
1734
- # without greedy
1735
- # def complex_contributions(name):
1736
- # log.info(f"Computing complex contributions for dataset: {name}")
1737
-
1738
- # pra = dload("pra", name)
1739
- # terms = dload("tmp", "terms")
1740
- # d = pra.query('prediction == 1').drop(columns=['gene1', 'gene2'])
1741
- # results = {}
1742
- # thresholds = [round(i, 2) for i in np.arange(1, 0.0001, -0.025)]
1743
- # for cid in terms.ID.to_list():
1744
- # arr = []
1745
- # for threshold in thresholds:
1746
- # r = d[d.complex_id == cid].query('precision >= @threshold')
1747
- # arr.append(r.shape[0])
1748
- # results[cid] = arr
1749
-
1750
- # r = pd.DataFrame(results, index=thresholds).T
1751
- # t = terms[['ID', 'Name']].set_index('ID')
1752
- # r['Name'] = r.index.map(t.Name)
1753
- # r = r[list(reversed(list(r.columns)))]
1754
- # r = r.reset_index(drop=True)
1755
- # dsave(r, "complex_contributions", name)
1756
- # log.info(f"Complex contributions computation completed for dataset: {name}")
1757
- # return r
1758
-
1759
- # # new
1760
- # def complex_contributions(name):
1761
- # log.info(f"Computing complex contributions using R-style greedy logic for dataset: {name}")
1762
- # pra = dload("pra", name)
1763
- # terms = dload("common", "terms")
1764
-
1765
- # # Ensure pra is sorted by score descending
1766
- # pra = pra.sort_values(by='score', ascending=False).reset_index(drop=True)
1767
-
1768
- # # Compute cumulative TP and precision if not present
1769
- # pra['cumTP'] = pra['prediction'].cumsum()
1770
- # pra['rank'] = pra.index + 1
1771
- # pra['precision'] = pra['cumTP'] / pra['rank']
1772
-
1773
- # # R-style precision thresholds
1774
- # prec_min = pra['precision'].min()
1775
- # prec_max = pra['precision'].max()
1776
- # precision_cutoffs = [round(prec_min, 3)]
1777
- # cutoffs_range = np.arange(0.1, prec_max + 0.001, 0.025)
1778
- # precision_cutoffs += [round(t, 3) for t in cutoffs_range if t > prec_min]
1779
- # thresholds = sorted(set(precision_cutoffs)) # Ensure unique and sorted
1780
-
1781
- # results = {}
1782
- # for t in thresholds:
1783
- # if pra['precision'].max() < t:
1784
- # continue
1785
- # cand = pra[pra['precision'] >= t]
1786
- # if cand.empty:
1787
- # continue
1788
- # k = cand.index.max() # rightmost index where precision >= t
1789
- # tp_target = pra.loc[k, 'cumTP']
1790
- # # Find the smallest m where cumTP[m] >= tp_target
1791
- # ind = pra[pra['cumTP'] >= tp_target].index.min()
1792
- # if pd.isna(ind):
1793
- # continue
1794
- # # Select top (ind+1) rows
1795
- # tmp = pra.iloc[0:ind + 1].copy()
1796
- # # Filter for predicted positives (true == 1)
1797
- # tmp = tmp[tmp['prediction'] == 1]
1798
- # tmp = tmp[tmp["complex_id"].notnull()]
1799
- # tmp["ID"] = tmp["complex_id"].apply(lambda ids: ";".join(str(int(i)) for i in ids if pd.notnull(i)))
1800
- # # Now greedy logic
1801
- # final_contrib = {}
1802
- # while not tmp.empty:
1803
- # all_ids = tmp["ID"].str.split(";").explode()
1804
- # contrib = all_ids.value_counts()
1805
- # if contrib.empty:
1806
- # break
1807
- # top_id = contrib.idxmax()
1808
- # final_contrib[top_id] = contrib[top_id]
1809
- # tmp = tmp[~tmp["ID"].str.contains(rf"\b{top_id}\b", regex=True)]
1810
- # for cid, count in final_contrib.items():
1811
- # if cid not in results:
1812
- # results[cid] = [0] * len(thresholds)
1813
- # results[cid][thresholds.index(t)] = count
1814
-
1815
- # # Add back gold standard complexes with 0 contribution
1816
- # gold_ids = set(terms.index.astype(str))
1817
- # all_ids = set(results.keys())
1818
- # missing_ids = gold_ids - all_ids
1819
- # for cid in missing_ids:
1820
- # results[cid] = [0] * len(thresholds)
1821
-
1822
- # # Build result DataFrame
1823
- # r = pd.DataFrame(results, index=thresholds).T
1824
- # r['Name'] = r.index.astype(int).map(terms['Name'])
1825
- # r = r[['Name'] + [c for c in r.columns if c != 'Name']] # Name as first col
1826
- # r = r[(r.drop(columns="Name").sum(axis=1) > 0)]
1827
- # # Move ID to first column, keep Name second, then precision columns in order
1828
- # dsave(r, "complex_contributions", name)
1829
- # log.info(f"Greedy R-style complex contribution completed for dataset: {name}")
1830
- # return r
1831
-
1832
- # def pra(dataset_name, matrix, is_corr=False):
1833
- # log.info(f"******************** {dataset_name} ********************")
1834
- # log.started(f"** Global Precision-Recall Analysis - {dataset_name} **")
1835
- # config = dload("config")
1836
-
1837
- # terms_data = dload("tmp", "terms")
1838
- # if terms_data is None or not isinstance(terms_data, pd.DataFrame):
1839
- # raise ValueError("Expected 'terms' to be a DataFrame, but got None or invalid type.")
1840
- # terms = terms_data
1841
- # genes_present = dload("tmp", "genes_present_in_terms")
1842
- # sorting = dload("input", "sorting")
1843
- # sort_order = sorting.get(dataset_name, "high")
1844
-
1845
- # if not is_corr:
1846
- # matrix = perform_corr(matrix, config.get("corr_function"))
1847
-
1848
- # matrix = filter_matrix_by_genes(matrix, genes_present)
1849
-
1850
- # log.info(f"Matrix shape: {matrix.shape}")
1851
- # df = binary(matrix)
1852
- # log.info(f"Pair-wise shape: {df.shape}")
1853
- # df = quick_sort(df, ascending=(sort_order == "low"))
1854
-
1855
- # gold_pair_to_complex = defaultdict(list)
1856
- # for idx, row in terms.iterrows():
1857
- # genes = row.used_genes
1858
- # if len(genes) < 2:
1859
- # continue
1860
- # for i, g1 in enumerate(genes):
1861
- # for g2 in genes[i + 1:]:
1862
- # pair = tuple(sorted((g1, g2)))
1863
- # gold_pair_to_complex[pair].append(idx)
1864
-
1865
- # # Label predictions and complex IDs
1866
- # complex_ids = []
1867
- # predictions = []
1868
- # for g1, g2 in zip(df["gene1"], df["gene2"]):
1869
- # pair = tuple(sorted((g1, g2)))
1870
- # ids = gold_pair_to_complex.get(pair, [])
1871
- # if ids:
1872
- # predictions.append(1)
1873
- # complex_ids.append(ids)
1874
- # else:
1875
- # predictions.append(0)
1876
- # complex_ids.append([])
1877
-
1878
- # df["prediction"] = predictions
1879
- # df["complex_id"] = complex_ids
1880
-
1881
- # if df["prediction"].sum() == 0:
1882
- # log.info("No true positives found in dataset.")
1883
- # pr_auc = np.nan
1884
- # else:
1885
- # tp = df["prediction"].cumsum()
1886
- # df["tp"] = tp
1887
- # precision = tp / (np.arange(len(df)) + 1)
1888
- # recall = tp / tp.iloc[-1]
1889
- # pr_auc = metrics.auc(recall, precision)
1890
- # df["precision"] = precision
1891
- # df["recall"] = recall
1892
-
1893
- # log.info(f"PR-AUC: {pr_auc:.4f}, Number of true positives: {df['prediction'].sum()}")
1894
- # dsave(df, "pra", dataset_name)
1895
- # dsave(pr_auc, "pr_auc", dataset_name)
1896
- # log.done(f"Global PRA completed for {dataset_name}")
1897
- # return df, pr_auc
1898
-
1899
- # def compute_pra(df):
1900
- # log.info("Calculating precision-recall and AUC score.")
1901
- # if df.empty:
1902
- # log.warning("Empty DataFrame encountered in compute_pra. Returning empty DataFrame.")
1903
- # return df
1904
- # df["tp"] = df["prediction"].cumsum()
1905
- # df.reset_index(drop=True, inplace=True)
1906
- # df["precision"] = df["tp"] / (df.index + 1)
1907
- # df["recall"] = df["tp"] / df["tp"].iloc[-1]
1908
- # log.info("DONE: Calculating precision-recall AUC score.")
1909
- # return df
1910
-
1911
- # def pra(dataset_name, matrix, is_corr=False):
1912
- # log.info(f"PRA computation started for {dataset_name}.")
1913
- # genes_present_in_terms = dload("tmp", "genes_present_in_terms")
1914
- # #terms_hash_table = dload("tmp", "terms_hash_table")
1915
- # sorting_prefs = dload("input", "sorting")
1916
- # sort_order = sorting_prefs.get(dataset_name, "high")
1917
- # if not is_corr: matrix = perform_corr(matrix, "numpy")
1918
- # matrix = filter_matrix_by_genes(matrix, genes_present_in_terms)
1919
- # stack = binary(matrix)
1920
-
1921
- # log.info("Checking gene pairs against the gold standard.")
1922
- # gene_pairs = list(zip(stack["gene1"], stack["gene2"]))
1923
- # hashed_pairs = [hash(pair) for pair in gene_pairs]
1924
- # stack["complex_id"] = [terms_hash_table.get(h, 0) for h in hashed_pairs]
1925
- # stack["prediction"] = [1 if h in terms_hash_table else 0 for h in hashed_pairs]
1926
-
1927
- # annotated = stack.copy()
1928
- # if sort_order == "low":
1929
- # ann_sorted = quick_sort(annotated, ascending=True)
1930
- # else:
1931
- # ann_sorted = quick_sort(annotated)
1932
-
1933
- # pra = compute_pra(ann_sorted)
1934
- # pr_auc = metrics.auc(pra.recall, pra.precision)
1935
- # dsave(pra, "pra", dataset_name)
1936
- # dsave(pr_auc, "pr_auc", dataset_name)
1937
- # log.info(f"PRA computation completed for {dataset_name} (Sorting: {sort_order}).")
1938
- # return pra, pr_auc