pythonflex 0.3.4__py3-none-any.whl → 0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pythonflex/analysis.py CHANGED
@@ -11,7 +11,6 @@ from pathlib import Path
11
11
  from art import tprint
12
12
  from bitarray import bitarray
13
13
  from joblib import Parallel, delayed, dump, load
14
- import matplotlib.pyplot as plt
15
14
  from numba import njit, prange
16
15
  import numpy as np
17
16
  import pandas as pd
@@ -20,8 +19,8 @@ from tqdm import tqdm
20
19
 
21
20
  # Local/application-specific imports
22
21
  from .logging_config import log
23
- from .preprocessing import filter_matrix_by_genes
24
- from .utils import dsave, dload, _sanitize
22
+ from .preprocessing import filter_matrix_by_genes, filter_duplicate_terms
23
+ from .utils import dsave, dload, _sanitize, normalize_analysis_genes
25
24
 
26
25
  import matplotlib as mpl
27
26
 
@@ -36,6 +35,8 @@ def deep_update(source, overrides):
36
35
 
37
36
  def initialize(config={}):
38
37
 
38
+ user_overrides = config if isinstance(config, dict) else {}
39
+
39
40
  default_config = {
40
41
  "min_genes_in_complex": 3,
41
42
  "min_genes_per_complex_analysis": 2,
@@ -43,8 +44,10 @@ def initialize(config={}):
43
44
  "gold_standard": "CORUM",
44
45
  "color_map": "RdYlBu",
45
46
  "jaccard": True,
46
- "jaccard_threshold": 1.0,
47
- "use_common_genes": True,
47
+ # Which genes are used for analysis (drives used_genes intersection)
48
+ # - 'shared' : use genes common to all datasets (common_genes)
49
+ # - 'dataset_specific' : use genes present in each dataset individually
50
+ "analysis_genes": "shared",
48
51
  "plotting": {
49
52
  "save_plot": True,
50
53
  "show_plot": True,
@@ -56,6 +59,11 @@ def initialize(config={}):
56
59
  "drop_na": False,
57
60
  },
58
61
  "corr_function": "numpy",
62
+ "per_complex": {
63
+ "n_jobs": 4,
64
+ "chunk_size": 200,
65
+ "max_nbytes": "100M",
66
+ },
59
67
  "logging": { # Added: Default logging config
60
68
  "visible_levels": ["DONE"] # if needed #, "PROGRESS", "STARTED", "INFO"
61
69
  }
@@ -66,6 +74,26 @@ def initialize(config={}):
66
74
  config = deep_update(default_config, config)
67
75
  else:
68
76
  config = default_config
77
+
78
+ # Backward compatibility: if user provided legacy key but not the new one,
79
+ # map it to analysis_genes. (We must look at the original overrides, because
80
+ # defaults always include analysis_genes.)
81
+ analysis_genes_provided = (
82
+ isinstance(user_overrides, dict)
83
+ and "analysis_genes" in user_overrides
84
+ and user_overrides.get("analysis_genes") is not None
85
+ and str(user_overrides.get("analysis_genes")).strip() != ""
86
+ )
87
+ if (
88
+ isinstance(user_overrides, dict)
89
+ and "use_common_genes" in user_overrides
90
+ and not analysis_genes_provided
91
+ ):
92
+ config["analysis_genes"] = (
93
+ "shared" if bool(user_overrides.get("use_common_genes")) else "dataset_specific"
94
+ )
95
+
96
+ config["analysis_genes"] = normalize_analysis_genes(config.get("analysis_genes"))
69
97
 
70
98
  # Extract visible_levels from the merged config and set logging visibility immediately (before any logs)
71
99
  visible_levels = config.get("logging", {}).get("visible_levels", ["DONE"])
@@ -111,12 +139,15 @@ def update_matploblib_config(config=None, font_family="Arial", layout="single"):
111
139
  if config is None:
112
140
  config = {}
113
141
  # Fallback if chosen font missing
142
+ requested_font_family = font_family
114
143
  try:
115
144
  from matplotlib.font_manager import findfont, FontProperties
116
145
  findfont(FontProperties(family=font_family))
117
146
  except Exception:
118
147
  font_family = "Helvetica" # Nature prefers Helvetica if Arial unavailable
119
- print(f"Warning: '{font_family}' not found, falling back to 'Helvetica'.")
148
+ log.warning(
149
+ f"Font '{requested_font_family}' not found; falling back to '{font_family}'."
150
+ )
120
151
 
121
152
  # Figure size presets (Nature: single ≈ 89 mm, double ≈ 183 mm at 25.4 mm/inch)
122
153
  if isinstance(layout, tuple):
@@ -191,50 +222,114 @@ def update_matploblib_config(config=None, font_family="Arial", layout="single"):
191
222
  "svg.fonttype": "none",
192
223
  })
193
224
 
194
- def pra(dataset_name, matrix, is_corr=False):
195
- log.info(f"******************** {dataset_name} ********************")
196
- log.started(f"** Global Precision-Recall Analysis - {dataset_name} **")
225
+
226
+ def _sort_ascending_for_dataset(dataset_name):
227
+ sorting = dload("input", "sorting")
228
+ if not isinstance(sorting, dict):
229
+ return False
230
+ sort_order = str(sorting.get(dataset_name, "high")).strip().lower()
231
+ return sort_order == "low"
232
+
233
+
234
+ def prepare_terms_for_dataset(dataset_name, matrix):
235
+ """Prepare dataset-specific gold-standard terms and filtered matrix.
236
+
237
+ This computes:
238
+ - terms['used_genes'] as the intersection of terms['all_genes'] with either
239
+ shared genes (config['analysis_genes']=='shared') or the dataset genes
240
+ (config['analysis_genes']=='dataset_specific').
241
+ - genes_present_in_terms_<dataset_name>
242
+
243
+ Side effects:
244
+ - stores dataset-specific terms and genes list under:
245
+ dsave(..., 'common', f'terms_{dataset_name}')
246
+ dsave(..., 'common', f'genes_present_in_terms_{dataset_name}')
247
+
248
+ Returns:
249
+ (terms, genes_present, matrix_filtered)
250
+ """
197
251
  config = dload("config")
198
- use_common_genes = config.get("use_common_genes", True)
252
+ if config is None:
253
+ raise RuntimeError(
254
+ "prepare_terms_for_dataset(): config not found. Run initialize() first."
255
+ )
199
256
 
200
257
  terms_data = dload("common", "terms")
201
258
  if terms_data is None or not isinstance(terms_data, pd.DataFrame):
202
- raise ValueError("Expected 'terms' to be a DataFrame, but got None or invalid type.")
259
+ raise ValueError(
260
+ "prepare_terms_for_dataset(): expected 'terms' to be a DataFrame, but got None or invalid type. "
261
+ "Make sure to run load_gold_standard() first."
262
+ )
203
263
  terms = terms_data.copy()
204
- sorting = dload("input", "sorting")
205
- sort_order = sorting.get(dataset_name, "high")
206
-
207
- if not is_corr:
208
- matrix = perform_corr(matrix, config.get("corr_function"))
209
264
 
210
- # Apply per-dataset gene filtering based on use_common_genes setting
211
- if use_common_genes:
212
- # Use common genes approach (current behavior)
213
- common_genes = dload("common", "common_genes")
214
- if not common_genes:
215
- raise ValueError("Common genes not found.")
216
- common_genes_set = set(common_genes)
217
- terms["used_genes"] = terms["all_genes"].apply(lambda x: list(set(x) & common_genes_set))
218
- log.info(f"Using common genes approach: {len(common_genes)} genes")
219
- else:
220
- # Use per-dataset approach (new behavior)
221
- dataset_genes_set = set(matrix.index)
222
- terms["used_genes"] = terms["all_genes"].apply(lambda x: list(set(x) & dataset_genes_set))
223
- log.info(f"Using per-dataset approach for {dataset_name}: {len(dataset_genes_set)} genes in dataset")
265
+ analysis_genes = normalize_analysis_genes(config.get("analysis_genes"))
266
+
267
+ if analysis_genes == "shared":
268
+ common_genes = dload("common", "common_genes")
269
+ if common_genes is None:
270
+ raise ValueError(
271
+ "prepare_terms_for_dataset(): common genes not found. "
272
+ "Run get_common_genes() or set analysis_genes='dataset_specific'."
273
+ )
274
+
275
+ common_genes_list = list(common_genes)
276
+ if len(common_genes_list) == 0:
277
+ raise ValueError(
278
+ "prepare_terms_for_dataset(): common genes is empty. "
279
+ "Run get_common_genes() or set analysis_genes='dataset_specific'."
280
+ )
281
+
282
+ gene_universe = set(common_genes_list)
283
+ log.info(f"Using shared genes approach: {len(gene_universe)} genes")
284
+ else:
285
+ gene_universe = set(matrix.index)
286
+ log.info(
287
+ f"Using dataset-specific approach for {dataset_name}: {len(gene_universe)} genes in dataset"
288
+ )
289
+
290
+ terms["used_genes"] = terms["all_genes"].apply(
291
+ lambda genes: list(set(genes) & gene_universe)
292
+ )
224
293
 
225
- # Filter terms by minimum genes after dataset-specific filtering
294
+ min_genes_raw = config.get("min_genes_in_complex", 3)
295
+ min_genes = int(min_genes_raw) if min_genes_raw is not None else 3
226
296
  terms["n_used_genes"] = terms["used_genes"].apply(len)
227
- terms = terms[terms["n_used_genes"] >= config['min_genes_in_complex']]
228
-
229
- # Get genes present in terms for this specific dataset
230
- genes_present = list(set([gene for genes_list in terms["used_genes"] for gene in genes_list]))
297
+ terms = terms[terms["n_used_genes"] >= min_genes]
298
+
299
+ if bool(config.get("jaccard", False)):
300
+ before = len(terms)
301
+ terms = filter_duplicate_terms(terms)
302
+ log.done(
303
+ f"After Jaccard duplicate used_genes filtering for {dataset_name}: "
304
+ f"{len(terms)} terms ({before - len(terms)} removed)"
305
+ )
306
+
307
+ genes_present = list(
308
+ set(gene for genes_list in terms["used_genes"] for gene in genes_list)
309
+ )
231
310
  log.info(f"Genes present in terms for {dataset_name}: {len(genes_present)}")
232
-
233
- matrix = filter_matrix_by_genes(matrix, genes_present)
311
+
312
+ matrix_filtered = filter_matrix_by_genes(matrix, genes_present)
313
+
314
+ dsave(terms, "common", f"terms_{dataset_name}")
315
+ dsave(genes_present, "common", f"genes_present_in_terms_{dataset_name}")
316
+
317
+ return terms, genes_present, matrix_filtered
318
+
319
+ def pra(dataset_name, matrix, is_corr=False):
320
+ log.info(f"******************** {dataset_name} ********************")
321
+ log.started(f"** Global Precision-Recall Analysis - {dataset_name} **")
322
+ config = dload("config")
323
+ ascending = _sort_ascending_for_dataset(dataset_name)
324
+
325
+ if not is_corr:
326
+ matrix = perform_corr(matrix, config.get("corr_function"))
327
+
328
+ terms, _genes_present, matrix = prepare_terms_for_dataset(dataset_name, matrix)
234
329
  log.info(f"Matrix shape: {matrix.shape}")
235
330
  df = binary(matrix)
236
331
  log.info(f"Pair-wise shape: {df.shape}")
237
- df = quick_sort(df, ascending=(sort_order == "low"))
332
+ df = quick_sort(df, ascending=ascending)
238
333
 
239
334
  log.started("Building gene-to-pair indices")
240
335
  gold_pair_to_complex = _build_gold_pair_to_complex(terms)
@@ -252,23 +347,22 @@ def pra(dataset_name, matrix, is_corr=False):
252
347
  if df["prediction"].sum() == 0:
253
348
  log.info("No true positives found in dataset.")
254
349
  pr_auc = np.nan
350
+ df["tp"] = 0
351
+ df["precision"] = np.nan
352
+ df["recall"] = np.nan
255
353
  else:
256
354
  tp = df["prediction"].cumsum()
257
355
  df["tp"] = tp
258
356
  precision = tp / (np.arange(len(df)) + 1)
259
357
  recall = tp / tp.iloc[-1]
260
- pr_auc = metrics.auc(recall, precision)
261
358
  df["precision"] = precision
262
359
  df["recall"] = recall
360
+ pr_auc = metrics.auc(recall, precision) if len(recall) >= 2 else np.nan
263
361
 
264
362
  log.info(f"PR-AUC: {pr_auc:.4f}, Number of true positives: {df['prediction'].sum()}")
265
363
  dsave(df, "pra", dataset_name)
266
364
  dsave(pr_auc, "pr_auc", dataset_name)
267
- dsave( _corrected_auc(df) , "corrected_pr_auc", dataset_name)
268
-
269
- # Save dataset-specific terms for per-complex analysis
270
- dsave(terms, "common", f"terms_{dataset_name}")
271
- dsave(genes_present, "common", f"genes_present_in_terms_{dataset_name}")
365
+ dsave(_corrected_auc(df), "corrected_pr_auc", dataset_name)
272
366
 
273
367
  log.done(f"Global PRA completed for {dataset_name}")
274
368
  return df
@@ -278,7 +372,12 @@ def pra(dataset_name, matrix, is_corr=False):
278
372
  # --------------------------------------------------------------------------
279
373
 
280
374
  def _corrected_auc(df: pd.DataFrame) -> float:
281
- return np.trapz(df["precision"], df["recall"]) - df["precision"].iloc[-1]
375
+ if df.empty or "precision" not in df.columns or "recall" not in df.columns:
376
+ return np.nan
377
+ valid = df[["precision", "recall"]].replace([np.inf, -np.inf], np.nan).dropna()
378
+ if len(valid) < 2:
379
+ return np.nan
380
+ return np.trapz(valid["precision"], valid["recall"]) - valid["precision"].iloc[-1]
282
381
 
283
382
  def _build_gene_to_pair_indices(pairwise_df):
284
383
  indices = pairwise_df.index.values
@@ -395,28 +494,52 @@ def _process_chunk(chunk_terms, min_genes, memmap_path, gene_to_pair_indices):
395
494
  # Return error info for debugging
396
495
  return {'error': str(e), 'chunk_size': len(chunk_terms)}
397
496
 
398
- def pra_percomplex(dataset_name, matrix, is_corr=False, chunk_size=200):
497
+ def pra_percomplex(dataset_name, matrix, is_corr=False, chunk_size=None, n_jobs=None):
399
498
  log.started(f"*** Per-complex PRA started - {dataset_name} ***")
400
499
  config = dload("config")
401
-
402
- # Use dataset-specific terms and genes from pra function
403
- terms = dload("common", f"terms_{dataset_name}")
404
- genes_present = dload("common", f"genes_present_in_terms_{dataset_name}")
405
-
406
- if terms is None:
407
- log.warning(f"No dataset-specific terms found for {dataset_name}, using global terms")
408
- terms = dload("common", "terms")
409
- genes_present = dload("common", "genes_present_in_terms")
410
-
411
- sorting = dload("input", "sorting")
412
- sort_order = sorting.get(dataset_name, "high")
500
+ ascending = _sort_ascending_for_dataset(dataset_name)
501
+ per_complex_config = config.get("per_complex", {})
502
+ if not isinstance(per_complex_config, dict):
503
+ per_complex_config = {}
504
+ chunk_size_value = (
505
+ chunk_size if chunk_size is not None else per_complex_config.get("chunk_size", 200)
506
+ )
507
+ n_jobs_value = n_jobs if n_jobs is not None else per_complex_config.get("n_jobs", 4)
508
+ max_nbytes = per_complex_config.get("max_nbytes", "100M")
509
+
510
+ try:
511
+ effective_chunk_size = int(chunk_size_value)
512
+ effective_n_jobs = int(n_jobs_value)
513
+ except (TypeError, ValueError) as exc:
514
+ raise ValueError(
515
+ "per-complex chunk_size and n_jobs must be integer-compatible values."
516
+ ) from exc
517
+
518
+ if effective_chunk_size <= 0:
519
+ raise ValueError("per-complex chunk_size must be greater than 0.")
520
+ if effective_n_jobs <= 0:
521
+ raise ValueError("per-complex n_jobs must be greater than 0.")
522
+
413
523
  if not is_corr:
414
524
  matrix = perform_corr(matrix, config.get("corr_function"))
415
- matrix = filter_matrix_by_genes(matrix, genes_present)
525
+
526
+ # Prefer terms prepared by pra(); if absent, prepare them here so direct
527
+ # pra_percomplex() calls use the same dataset-specific gene universe.
528
+ terms = dload("common", f"terms_{dataset_name}")
529
+ genes_present = dload("common", f"genes_present_in_terms_{dataset_name}")
530
+
531
+ if not isinstance(terms, pd.DataFrame) or genes_present is None:
532
+ log.warning(
533
+ f"No dataset-specific terms found for {dataset_name}; preparing them now."
534
+ )
535
+ terms, genes_present, matrix = prepare_terms_for_dataset(dataset_name, matrix)
536
+ else:
537
+ matrix = filter_matrix_by_genes(matrix, genes_present)
538
+
416
539
  log.info(f"Matrix shape: {matrix.shape}")
417
540
  df = binary(matrix)
418
541
  log.info(f"Pair-wise shape: {df.shape}")
419
- df = quick_sort(df, ascending=(sort_order == "low"))
542
+ df = quick_sort(df, ascending=ascending)
420
543
  pairwise_df = df.copy()
421
544
  pairwise_df['gene1'] = pairwise_df['gene1'].astype("category")
422
545
  pairwise_df['gene2'] = pairwise_df['gene2'].astype("category")
@@ -434,26 +557,42 @@ def pra_percomplex(dataset_name, matrix, is_corr=False, chunk_size=200):
434
557
  pairwise_df = _precompute_complex_ids(pairwise_df, gold_pair_to_complex)
435
558
  log.done("Precomputing complex IDs") #
436
559
 
560
+ chunks = [
561
+ terms.iloc[i:i + effective_chunk_size]
562
+ for i in range(0, len(terms), effective_chunk_size)
563
+ ]
564
+ min_genes = config["min_genes_per_complex_analysis"]
565
+
566
+ if not chunks:
567
+ terms["auc_score"] = pd.Series(dtype=float)
568
+ terms["corrected_auc_score"] = pd.Series(dtype=float)
569
+ dsave(terms, "pra_percomplex", dataset_name)
570
+ log.done("Per-complex PRA completed with no eligible terms.")
571
+ return terms
572
+
437
573
  log.info('Dumping pairwise_df to memmap')
438
574
  memmap_path = _dump_pairwise_memmap(pairwise_df, dataset_name)
439
575
  log.done('Dumping pairwise_df to memmap')
440
576
 
441
- # choose smaller chunks now that pickling cost is gone
442
- chunks = [terms.iloc[i:i+chunk_size] for i in range(0, len(terms), chunk_size)]
443
- min_genes = config["min_genes_per_complex_analysis"]
444
-
445
577
  # Initialize results variable
446
578
  results = None
447
579
 
448
580
  try:
449
581
  # Compatible parallel execution for older joblib versions
450
582
  log.started("Processing chunks in parallel")
451
-
583
+ actual_n_jobs = min(effective_n_jobs, len(chunks))
584
+ log.info(
585
+ "Per-complex parallel config: "
586
+ f"n_jobs={actual_n_jobs}, requested_n_jobs={effective_n_jobs}, "
587
+ f"chunk_size={effective_chunk_size}, chunks={len(chunks)}, "
588
+ f"max_nbytes={max_nbytes}"
589
+ )
590
+
452
591
  # Use a more conservative approach with older joblib
453
592
  results = Parallel(
454
- n_jobs=min(4, len(chunks)), # Limit to 4 workers or number of chunks
593
+ n_jobs=actual_n_jobs,
455
594
  temp_folder=os.path.dirname(memmap_path),
456
- max_nbytes='100M', # Set memory limit
595
+ max_nbytes=max_nbytes,
457
596
  verbose=1 # Show progress
458
597
  )(delayed(_process_chunk)(chunk, min_genes, memmap_path, gene_to_pair_indices)
459
598
  for chunk in tqdm(chunks, desc="Per-complex PRA"))
@@ -484,11 +623,13 @@ def pra_percomplex(dataset_name, matrix, is_corr=False, chunk_size=200):
484
623
  # Merge results with enhanced error handling
485
624
  auc_scores = {}
486
625
  corrected_auc_scores = {}
626
+ errors = []
487
627
  if results:
488
628
  for i, res in enumerate(results):
489
629
  if isinstance(res, dict):
490
630
  if 'error' in res:
491
631
  log.error(f"Error in chunk {i}: {res['error']}")
632
+ errors.append(f"chunk {i}: {res['error']}")
492
633
  elif 'auc' in res and 'corrected_auc' in res:
493
634
  # New format with both AUC types
494
635
  auc_scores.update(res['auc'])
@@ -498,8 +639,15 @@ def pra_percomplex(dataset_name, matrix, is_corr=False, chunk_size=200):
498
639
  auc_scores.update(res)
499
640
  elif isinstance(res, tuple) and len(res) >= 2 and res[0] is None:
500
641
  log.error(f"Chunk {i} error: {res[1]}")
642
+ errors.append(f"chunk {i}: {res[1]}")
501
643
  else:
502
644
  log.warning(f"Unexpected result type from chunk {i}: {type(res)} - {res}")
645
+ errors.append(f"chunk {i}: unexpected result type {type(res)}")
646
+
647
+ if errors:
648
+ preview = "; ".join(errors[:3])
649
+ extra = f" ({len(errors) - 3} more)" if len(errors) > 3 else ""
650
+ raise RuntimeError(f"Per-complex PRA failed in worker chunks: {preview}{extra}")
503
651
 
504
652
  # Add the computed AUC scores to the terms DataFrame.
505
653
  terms["auc_score"] = pd.Series(auc_scores)
@@ -511,10 +659,20 @@ def pra_percomplex(dataset_name, matrix, is_corr=False, chunk_size=200):
511
659
  def complex_contributions(name):
512
660
  log.info(f"Computing complex contributions (Greedy) for dataset: {name}")
513
661
  pra = dload("pra", name)
514
- terms = dload("common", "terms")
662
+ terms = dload("common", f"terms_{name}")
663
+ if not isinstance(terms, pd.DataFrame):
664
+ # Fallback for backward compatibility
665
+ terms = dload("common", "terms")
666
+ if not isinstance(pra, pd.DataFrame) or pra.empty:
667
+ raise RuntimeError(f"complex_contributions(): PRA data for dataset '{name}' not found.")
668
+ if not isinstance(terms, pd.DataFrame) or terms.empty:
669
+ raise RuntimeError(f"complex_contributions(): terms for dataset '{name}' not found.")
515
670
 
516
- # Ensure pra is sorted by score descending (matches R's order by predicted descending)
517
- pra = pra.sort_values(by='score', ascending=False).reset_index(drop=True)
671
+ # Respect the dataset's score direction: high scores by default, low scores if configured.
672
+ pra = pra.sort_values(
673
+ by='score',
674
+ ascending=_sort_ascending_for_dataset(name),
675
+ ).reset_index(drop=True)
518
676
 
519
677
  # Compute cumulative TP and precision (matches R's TP.count = cumsum(true), Precision = TP / (1:n))
520
678
  pra['cumTP'] = pra['prediction'].cumsum()
@@ -809,6 +967,9 @@ def binary(corr):
809
967
 
810
968
  stack = corr.stack().rename_axis(index=['gene1', 'gene2']).\
811
969
  reset_index().rename(columns={0: 'score'})
970
+ if stack.empty:
971
+ log.done("Pair-wise conversion.")
972
+ return stack
812
973
  if has_mirror_of_first_pair(stack):
813
974
  log.info("Mirror pairs detected. Dropping them to ensure unique gene pairs.")
814
975
  stack = drop_mirror_pairs(stack)
@@ -859,7 +1020,7 @@ def save_results_to_csv(categories = ["complex_contributions", "pr_auc", "pra_pe
859
1020
  continue
860
1021
 
861
1022
  if category == "mpr_complexes_auc" and isinstance(data, dict):
862
- # Dict[dataset_name -> Dict[filter_key -> auc]]
1023
+ # Dict[dataset_name -> Dict[variant_key -> auc]]
863
1024
  try:
864
1025
  df = pd.DataFrame.from_dict(data, orient="index")
865
1026
  df.index.name = "Dataset"
@@ -902,21 +1063,10 @@ def save_results_to_csv(categories = ["complex_contributions", "pr_auc", "pra_pe
902
1063
 
903
1064
  log.done("Results saved to CSV files in the output folder.")
904
1065
 
905
- ################### mPR
906
- ################### mPR ###################
907
-
908
-
909
-
910
- ################### mPR ###################
911
- ################### mPR ###################
912
-
913
1066
  # -----------------------------------------------------------------------------
914
1067
  # mPR preparation (module-level precision–recall, Fig. 1E / 1F)
915
1068
  # -----------------------------------------------------------------------------
916
1069
 
917
- import numpy as np
918
- import pandas as pd
919
-
920
1070
 
921
1071
  def _mpr_get_mtRibo_ETCI_ids(terms_like):
922
1072
  """
@@ -971,34 +1121,16 @@ def _mpr_get_small_high_auprc_ids(
971
1121
  # Helpers implementing the FLEX stepwise module-level PR logic
972
1122
  # -------------------------------------------------------------------------
973
1123
 
974
- """
975
- CORRECT FIX for _mpr_build_pairs in analysis.py
976
-
977
- The issue: The current code marks filtered TPs as true=0, which makes them
978
- count as False Positives and dramatically lowers precision.
979
-
980
- The R code (getSubsetOfCoAnnRemoveIDs with replace=FALSE) REMOVES the
981
- filtered positive pairs entirely from the dataset.
982
-
983
- This is the key difference:
984
- - Current Python: Keeps all rows, filtered TPs become FPs → precision tanks
985
- - R code: Removes filtered TP rows → they don't affect precision at all
986
- """
987
-
988
- import numpy as np
989
- import pandas as pd
990
-
991
-
992
- def _mpr_build_pairs(pra, removed_ids=None):
1124
+ def _mpr_build_pairs(pra, removed_ids=None, ascending=False):
993
1125
  """
994
1126
  Build a Pairs.in.data-like table for mPR / stepwise contributions.
995
-
996
- FIXED: Removes rows that contain filtered complex IDs (matching R behavior)
997
- instead of marking them as true=0.
1127
+
1128
+ Rows containing filtered positive complex IDs are removed from the ranking,
1129
+ matching the FLEX stepwise module-level precision-recall behavior.
998
1130
 
999
1131
  Input:
1000
1132
  pra : DataFrame with at least columns
1001
- - 'score' : ranking score (higher = better)
1133
+ - 'score' : ranking score
1002
1134
  - 'complex_id' : complex annotations
1003
1135
  removed_ids : set[int] of complexes to remove
1004
1136
 
@@ -1057,11 +1189,9 @@ def _mpr_build_pairs(pra, removed_ids=None):
1057
1189
  out["complex_ids"] = df[cid_col].apply(normalize_ids)
1058
1190
  out["true"] = out["complex_ids"].apply(lambda ids: 1 if len(ids) > 0 else 0)
1059
1191
 
1060
- # KEY FIX: Remove rows that are TPs AND contain a removed complex ID
1061
- # This matches the R behavior of getSubsetOfCoAnnRemoveIDs with replace=FALSE
1192
+ # Remove rows that are TPs and contain a removed complex ID.
1062
1193
  if removed_ids:
1063
1194
  should_remove_mask = df[cid_col].apply(should_remove)
1064
- # Only remove if it's a TP (true=1)
1065
1195
  remove_mask = should_remove_mask & (out["true"] == 1)
1066
1196
  out = out[~remove_mask].copy()
1067
1197
 
@@ -1071,41 +1201,11 @@ def _mpr_build_pairs(pra, removed_ids=None):
1071
1201
  lambda ids: [cid for cid in ids if cid not in removed_ids]
1072
1202
  )
1073
1203
 
1074
- # Sort by predicted descending
1075
- out = out.sort_values("predicted", ascending=False).reset_index(drop=True)
1204
+ # Sort by the dataset's configured score direction.
1205
+ out = out.sort_values("predicted", ascending=ascending).reset_index(drop=True)
1076
1206
  return out
1077
1207
 
1078
1208
 
1079
- # ============================================================================
1080
- # HOW TO APPLY THIS FIX
1081
- # ============================================================================
1082
- #
1083
- # In analysis.py, replace the _mpr_build_pairs function (around line 962-1025)
1084
- # with the _mpr_build_pairs_fixed function above.
1085
- #
1086
- # The key changes are:
1087
- #
1088
- # 1. REMOVE rows instead of marking true=0:
1089
- #
1090
- # OLD CODE:
1091
- # if removed_ids:
1092
- # ids = [cid for cid in ids if cid not in removed_ids]
1093
- # return ids
1094
- # ...
1095
- # out["true"] = out["complex_ids"].apply(lambda ids: 1 if len(ids) > 0 else 0)
1096
- #
1097
- # NEW CODE:
1098
- # # First compute true normally
1099
- # out["true"] = out["complex_ids"].apply(lambda ids: 1 if len(ids) > 0 else 0)
1100
- #
1101
- # # Then REMOVE rows that are TPs and contain removed IDs
1102
- # if removed_ids:
1103
- # should_remove_mask = df[cid_col].apply(should_remove)
1104
- # remove_mask = should_remove_mask & (out["true"] == 1)
1105
- # out = out[~remove_mask].copy()
1106
- #
1107
- # ============================================================================
1108
-
1109
1209
  def _mpr_precision_cutoffs_from_pairs(pairs, step=0.025):
1110
1210
  """
1111
1211
  Choose precision cutoffs similar to FLEX:
@@ -1137,7 +1237,7 @@ def _mpr_precision_cutoffs_from_pairs(pairs, step=0.025):
1137
1237
  return np.array(cuts, dtype=float)
1138
1238
 
1139
1239
 
1140
- def _mpr_stepwise_contributions(pairs, precision_cutoffs):
1240
+ def _mpr_stepwise_contributions(pairs, precision_cutoffs, ascending=False):
1141
1241
  """
1142
1242
  Greedy, stepwise TP allocation per complex at each precision cutoff.
1143
1243
 
@@ -1152,7 +1252,7 @@ def _mpr_stepwise_contributions(pairs, precision_cutoffs):
1152
1252
  contrib_df : DataFrame [complex_id x cutoff] with TP counts
1153
1253
  """
1154
1254
  pairs = pairs.copy()
1155
- pairs = pairs.sort_values("predicted", ascending=False).reset_index(drop=True)
1255
+ pairs = pairs.sort_values("predicted", ascending=ascending).reset_index(drop=True)
1156
1256
 
1157
1257
  true = pairs["true"].to_numpy(dtype=int)
1158
1258
  n = len(true)
@@ -1272,16 +1372,29 @@ def _mpr_module_coverage(contrib_df, terms, tp_th=1, percent_th=0.1):
1272
1372
  row = terms.loc[cid_int]
1273
1373
 
1274
1374
  n_genes = None
1275
-
1276
- # FIXED: Handle all_genes as list (how it's stored in preprocessing)
1277
- if "all_genes" in row.index:
1375
+
1376
+ # Prefer used_genes (genes actually in the dataset) for a fair coverage
1377
+ # fraction. This matters for GOBP/PATHWAY where all_genes >> used_genes.
1378
+ if "used_genes" in row.index:
1379
+ genes = row["used_genes"]
1380
+ if isinstance(genes, (list, np.ndarray)) and len(genes) > 0:
1381
+ n_genes = len(genes)
1382
+ if n_genes is None and "n_used_genes" in row.index:
1383
+ try:
1384
+ v = int(row["n_used_genes"])
1385
+ if v > 0:
1386
+ n_genes = v
1387
+ except (ValueError, TypeError):
1388
+ pass
1389
+
1390
+ # Fallback: all_genes (how it's stored in preprocessing)
1391
+ if n_genes is None and "all_genes" in row.index:
1278
1392
  genes = row["all_genes"]
1279
- if isinstance(genes, list):
1393
+ if isinstance(genes, (list, np.ndarray)):
1280
1394
  n_genes = len(genes)
1281
1395
  elif isinstance(genes, str):
1282
- # Fallback if stored as string
1283
1396
  n_genes = len([g for g in genes.split(";") if g])
1284
-
1397
+
1285
1398
  # Fallback to Genes column (original string from CORUM)
1286
1399
  if n_genes is None and "Genes" in row.index:
1287
1400
  genes_str = row["Genes"]
@@ -1338,7 +1451,7 @@ def _mpr_complexes_auc(
1338
1451
 
1339
1452
  We compute a normalized AUC by integrating precision over the *normalized*
1340
1453
  coverage axis:
1341
- AUC = \int y \, d(x/max_complexes)
1454
+ AUC = integral y d(x/max_complexes)
1342
1455
 
1343
1456
  This yields a score in [0, 1] (or NaN if insufficient data).
1344
1457
  """
@@ -1348,7 +1461,7 @@ def _mpr_complexes_auc(
1348
1461
  if cov.size == 0 or prec.size == 0:
1349
1462
  return 0.0
1350
1463
 
1351
- # Match plot_mpr_complexes_multi(): only count cov>0 (log-x cannot show 0)
1464
+ # Match plot_mpr_complex_coverage_curve(): only count cov>0 (log-x cannot show 0)
1352
1465
  mask = (
1353
1466
  np.isfinite(cov)
1354
1467
  & np.isfinite(prec)
@@ -1410,26 +1523,31 @@ def mpr_prepare(
1410
1523
  """
1411
1524
  pra = dload("pra", name)
1412
1525
  pra_percomplex = dload("pra_percomplex", name)
1413
- terms = dload("common", "terms")
1526
+ terms = dload("common", f"terms_{name}")
1527
+ if not isinstance(terms, pd.DataFrame):
1528
+ # Fallback for backward compatibility
1529
+ terms = dload("common", "terms")
1414
1530
 
1415
- if pra is None:
1531
+ if pra is None or not isinstance(pra, pd.DataFrame) or pra.empty:
1416
1532
  raise RuntimeError(
1417
1533
  f"mpr_prepare(): PRA data for dataset '{name}' not found "
1418
1534
  "(dload('pra', name))."
1419
1535
  )
1420
- if pra_percomplex is None:
1536
+ if pra_percomplex is None or not isinstance(pra_percomplex, pd.DataFrame) or pra_percomplex.empty:
1421
1537
  raise RuntimeError(
1422
1538
  f"mpr_prepare(): per-complex PRA data for dataset '{name}' not found "
1423
1539
  "(dload('pra_percomplex', name))."
1424
1540
  )
1425
- if terms is None:
1541
+ if terms is None or not isinstance(terms, pd.DataFrame) or terms.empty:
1426
1542
  raise RuntimeError(
1427
1543
  "mpr_prepare(): CORUM 'terms' table not found (dload('common', 'terms'))."
1428
1544
  )
1429
1545
 
1430
- # sort by score descending (ranking)
1546
+ ascending = _sort_ascending_for_dataset(name)
1547
+
1548
+ # Sort by the dataset's configured score direction.
1431
1549
  if "score" in pra.columns:
1432
- pra = pra.sort_values("score", ascending=False).reset_index(drop=True)
1550
+ pra = pra.sort_values("score", ascending=ascending).reset_index(drop=True)
1433
1551
  else:
1434
1552
  pra = pra.reset_index(drop=True)
1435
1553
 
@@ -1455,7 +1573,7 @@ def mpr_prepare(
1455
1573
 
1456
1574
  for label, removed in filter_sets.items():
1457
1575
  # 1) Build pairs table after removing complexes in `removed`
1458
- pairs = _mpr_build_pairs(pra, removed_ids=removed)
1576
+ pairs = _mpr_build_pairs(pra, removed_ids=removed, ascending=ascending)
1459
1577
 
1460
1578
  true = pairs["true"].to_numpy(dtype=int)
1461
1579
  n = len(true)
@@ -1482,13 +1600,24 @@ def mpr_prepare(
1482
1600
  if precision_cutoffs is None:
1483
1601
  precision_cutoffs = _mpr_precision_cutoffs_from_pairs(pairs)
1484
1602
 
1485
- contrib_df = _mpr_stepwise_contributions(pairs, precision_cutoffs)
1603
+ contrib_df = _mpr_stepwise_contributions(
1604
+ pairs,
1605
+ precision_cutoffs,
1606
+ ascending=ascending,
1607
+ )
1486
1608
  cov = _mpr_module_coverage(
1487
1609
  contrib_df,
1488
1610
  terms,
1489
1611
  tp_th=tp_th,
1490
1612
  percent_th=percent_th,
1491
1613
  )
1614
+ # precision_cutoffs are sorted ascending (low → high).
1615
+ # Coverage must be non-increasing in that direction: a more permissive
1616
+ # threshold (lower precision) should never yield fewer covered terms.
1617
+ # The independent greedy allocation per cutoff can violate this, so
1618
+ # enforce monotonicity by propagating the max from right to left.
1619
+ if cov.size > 0:
1620
+ cov = np.maximum.accumulate(cov[::-1])[::-1]
1492
1621
  coverage_curves[label] = cov
1493
1622
  complexes_auc[label] = _mpr_complexes_auc(
1494
1623
  cov,
@@ -1516,424 +1645,3 @@ def mpr_prepare(
1516
1645
 
1517
1646
  # Convenience: store AUCs as their own category for easy export / plotting.
1518
1647
  dsave(complexes_auc, "mpr_complexes_auc", name)
1519
-
1520
-
1521
-
1522
- ### OLD FUNCTIONS
1523
-
1524
- # new but withoutparallel
1525
-
1526
- # def pra_percomplex(dataset_name, matrix, is_corr=False):
1527
- # log.started(f"*** Per-complex PRA started - {dataset_name} ***")
1528
- # config = dload("config")
1529
- # terms = dload("tmp", "terms")
1530
- # genes_present = dload("tmp", "genes_present_in_terms")
1531
- # sorting = dload("input", "sorting")
1532
- # sort_order = sorting.get(dataset_name, "high")
1533
- # if not is_corr:
1534
- # matrix = perform_corr(matrix, config.get("corr_function"))
1535
- # matrix = filter_matrix_by_genes(matrix, genes_present)
1536
- # log.info(f"Matrix shape: {matrix.shape}")
1537
- # df = binary(matrix)
1538
- # log.info(f"Pair-wise shape: {df.shape}")
1539
- # df = quick_sort(df, ascending=(sort_order == "low"))
1540
- # pairwise_df = df.copy()
1541
- # pairwise_df['gene1'] = pairwise_df['gene1'].astype("category")
1542
- # pairwise_df['gene2'] = pairwise_df['gene2'].astype("category")
1543
-
1544
- # # Precompute a mapping from each gene to the row indices in the pairwise DataFrame where it appears.
1545
- # gene_to_pair_indices = {}
1546
- # for i, (gene_a, gene_b) in enumerate(zip(pairwise_df["gene1"], pairwise_df["gene2"])):
1547
- # gene_to_pair_indices.setdefault(gene_a, []).append(i)
1548
- # gene_to_pair_indices.setdefault(gene_b, []).append(i)
1549
- # log.done
1550
-
1551
- # # Build gold_pair_to_complex using sets for efficiency
1552
- # gold_pair_to_complex = defaultdict(set)
1553
- # for idx, row in terms.iterrows():
1554
- # genes = row.used_genes
1555
- # if len(genes) < 2:
1556
- # continue
1557
- # for i, g1 in enumerate(genes):
1558
- # for g2 in genes[i + 1:]:
1559
- # pair = tuple(sorted((g1, g2)))
1560
- # gold_pair_to_complex[pair].add(idx)
1561
-
1562
- # # Precompute complex_ids as semicolon-separated strings in pairwise_df
1563
- # pairs = [tuple(sorted((g1, g2))) for g1, g2 in zip(pairwise_df["gene1"], pairwise_df["gene2"])]
1564
- # pairwise_df['complex_ids'] = [';'.join(map(str, sorted(gold_pair_to_complex.get(pair, set())))) for pair in pairs]
1565
-
1566
- # # Initialize AUC scores
1567
- # auc_scores = {}
1568
- # # Loop over each gene complex
1569
- # for idx, row in tqdm(terms.iterrows()):
1570
- # gene_set = set(row.used_genes)
1571
- # if config["min_genes_per_complex_analysis"] > len(gene_set):
1572
- # continue
1573
- # # Collect all row indices in the pairwise data where either gene belongs to the complex.
1574
- # candidate_indices = bitarray(len(pairwise_df))
1575
- # for gene in gene_set:
1576
- # if gene in gene_to_pair_indices:
1577
- # candidate_indices[gene_to_pair_indices[gene]] = True
1578
-
1579
- # if not candidate_indices.any():
1580
- # continue
1581
-
1582
- # # Select only the relevant pairwise comparisons.
1583
- # selected_rows = np.unpackbits(candidate_indices).view(bool)[:len(pairwise_df)]
1584
- # sub_df = pairwise_df.iloc[selected_rows]
1585
-
1586
- # # Get current complex ID (assuming idx is the ID; adjust if row['ID'] is different)
1587
- # complex_id = str(idx) # Or str(row['ID']) if available
1588
-
1589
- # # Create true_label: 1 if complex_id in complex_ids (vectorized with str.contains)
1590
- # #true_label = sub_df['complex_ids'].str.contains(complex_id, regex=False).astype(int)
1591
-
1592
- # # Inside the loop, for each complex:
1593
- # # Inside the loop:
1594
- # complex_id = str(idx)
1595
- # # Use (?:^|;) and (?:;|$) to avoid capturing groups
1596
- # pattern = r'(?:^|;)' + re.escape(complex_id) + r'(?:;|$)'
1597
- # true_label = sub_df['complex_ids'].str.contains(pattern, regex=True).astype(int)
1598
- # # Filter to keep verified negatives (complex_ids == "") or positives for this complex (true_label == 1)
1599
- # complex_mask = (sub_df['complex_ids'] == "") | (true_label == 1)
1600
-
1601
- # # Use the masked true labels for AUPRC (avoids SettingWithCopyWarning)
1602
- # predictions = true_label[complex_mask]
1603
-
1604
- # if predictions.sum() == 0:
1605
- # continue
1606
- # # Compute cumulative true positives and derive precision and recall.
1607
- # true_positive_cumsum = predictions.cumsum()
1608
- # precision = true_positive_cumsum / (np.arange(len(predictions)) + 1)
1609
- # recall = true_positive_cumsum / true_positive_cumsum.iloc[-1]
1610
-
1611
- # if len(recall) < 2 or recall.iloc[-1] == 0:
1612
- # continue
1613
- # auc_scores[idx] = metrics.auc(recall, precision)
1614
-
1615
- # # Add the computed AUC scores to the terms DataFrame.
1616
- # terms["auc_score"] = pd.Series(auc_scores)
1617
- # terms.drop(columns=["hash"], inplace=True)
1618
- # dsave(terms, "pra_percomplex", dataset_name)
1619
- # log.done(f"Per-complex PRA completed.")
1620
- # return terms
1621
-
1622
- # it works quick but only maps 1 complex to each pair
1623
-
1624
- # def pra_percomplex_old_type_filtering(dataset_name, matrix, is_corr=False):
1625
- # log.started(f"*** Per-complex PRA started - {dataset_name} ***")
1626
- # config = dload("config")
1627
- # terms = dload("tmp", "terms")
1628
- # genes_present = dload("tmp", "genes_present_in_terms")
1629
- # sorting = dload("input", "sorting")
1630
- # sort_order = sorting.get(dataset_name, "high")
1631
- # if not is_corr:
1632
- # matrix = perform_corr(matrix, config.get("corr_function"))
1633
- # matrix = filter_matrix_by_genes(matrix, genes_present)
1634
- # log.info(f"Matrix shape: {matrix.shape}")
1635
- # df = binary(matrix)
1636
- # log.info(f"Pair-wise shape: {df.shape}")
1637
- # df = quick_sort(df, ascending=(sort_order == "low"))
1638
- # pairwise_df = df.copy()
1639
- # pairwise_df['gene1'] = pairwise_df['gene1'].astype("category")
1640
- # pairwise_df['gene2'] = pairwise_df['gene2'].astype("category")
1641
- # # Precompute a mapping from each gene to the row indices in the pairwise DataFrame where it appears.
1642
- # gene_to_pair_indices = {}
1643
- # for i, (gene_a, gene_b) in enumerate(zip(pairwise_df["gene1"], pairwise_df["gene2"])):
1644
- # gene_to_pair_indices.setdefault(gene_a, []).append(i)
1645
- # gene_to_pair_indices.setdefault(gene_b, []).append(i)
1646
- # # Initialize AUC scores (one for each complex) with NaNs.
1647
- # #auc_scores = np.full(len(terms), np.nan)
1648
- # auc_scores = {}
1649
- # # Loop over each gene complex
1650
- # for idx, row in tqdm(terms.iterrows()):
1651
- # gene_set = set(row.used_genes)
1652
-
1653
- # if config["min_genes_per_complex_analysis"] > len(gene_set):
1654
- # continue
1655
- # # Collect all row indices in the pairwise data where either gene belongs to the complex.
1656
- # candidate_indices = bitarray(len(pairwise_df))
1657
- # for gene in gene_set:
1658
- # if gene in gene_to_pair_indices:
1659
- # candidate_indices[gene_to_pair_indices[gene]] = True
1660
- # if not candidate_indices.any():
1661
- # continue
1662
- # # Select only the relevant pairwise comparisons.
1663
- # selected_rows = np.unpackbits(candidate_indices).view(bool)[:len(pairwise_df)]
1664
- # sub_df = pairwise_df.iloc[selected_rows]
1665
- # # A prediction is 1 if both genes in the pair are in the complex; otherwise 0.
1666
- # predictions = (sub_df["gene1"].isin(gene_set) & sub_df["gene2"].isin(gene_set)).astype(int)
1667
- # if predictions.sum() == 0:
1668
- # continue
1669
- # # Compute cumulative true positives and derive precision and recall.
1670
- # true_positive_cumsum = predictions.cumsum()
1671
- # precision = true_positive_cumsum / (np.arange(len(predictions)) + 1)
1672
- # recall = true_positive_cumsum / true_positive_cumsum.iloc[-1]
1673
-
1674
- # if len(recall) < 2 or recall.iloc[-1] == 0:
1675
- # continue
1676
- # auc_scores[idx] = metrics.auc(recall, precision)
1677
- # # Add the computed AUC scores to the terms DataFrame.
1678
- # terms["auc_score"] = pd.Series(auc_scores)
1679
- # terms.drop(columns=["hash"], inplace=True)
1680
- # dsave(terms, "pra_percomplex", dataset_name)
1681
- # log.done(f"Per-complex PRA completed.")
1682
- # return terms
1683
-
1684
- # OLD
1685
- # def pra_percomplex(dataset_name, matrix, is_corr=False):
1686
- # log.started(f"*** Per-complex PRA started for {dataset_name} ***")
1687
- # config = dload("config")
1688
- # terms = dload("tmp", "terms")
1689
- # genes_present = dload("tmp", "genes_present_in_terms")
1690
- # sorting = dload("input", "sorting")
1691
- # sort_order = sorting.get(dataset_name, "high")
1692
-
1693
- # if not is_corr:
1694
- # matrix = perform_corr(matrix, "numpy")
1695
- # matrix = filter_matrix_by_genes(matrix, genes_present)
1696
- # log.info(f"Matrix shape: {matrix.shape}")
1697
- # df = binary(matrix)
1698
- # log.info(f"Pair-wise shape: {df.shape}")
1699
- # df = quick_sort(df, ascending=(sort_order == "low"))
1700
- # # Precompute gene → row indices
1701
- # gene_to_rows = {}
1702
- # for i, (g1, g2) in enumerate(zip(df["gene1"], df["gene2"])):
1703
- # gene_to_rows.setdefault(g1, []).append(i)
1704
- # gene_to_rows.setdefault(g2, []).append(i)
1705
- # aucs = np.full(len(terms), np.nan)
1706
- # N = len(df)
1707
- # for idx, row in tqdm(terms.iterrows()):
1708
- # genes = set(row.used_genes)
1709
- # if len(genes) < config["min_complex_size_for_percomplex"]: # Skip small complexes
1710
- # continue
1711
- # # Get all row indices where either gene is in the complex
1712
- # candidate_idxs = set()
1713
- # for g in genes:
1714
- # candidate_idxs.update(gene_to_rows.get(g, []))
1715
- # candidate_idxs = sorted(candidate_idxs)
1716
- # if not candidate_idxs:
1717
- # continue
1718
- # # Use only relevant rows for prediction
1719
- # sub = df.loc[candidate_idxs]
1720
- # preds = (sub["gene1"].isin(genes) & sub["gene2"].isin(genes)).astype(int)
1721
- # if preds.sum() == 0:
1722
- # continue
1723
- # tp = preds.cumsum()
1724
- # prec = tp / (np.arange(len(preds)) + 1)
1725
- # recall = tp / tp.iloc[-1]
1726
- # if len(recall) < 2 or recall.iloc[-1] == 0:
1727
- # continue
1728
- # aucs[idx] = metrics.auc(recall, prec)
1729
- # terms["auc_score"] = aucs
1730
- # terms.drop(columns=["list", "set", "hash"], inplace=True)
1731
- # dsave(terms, "pra_percomplex", dataset_name)
1732
- # log.done(f"Per-complex PRA completed.")
1733
- # return terms
1734
-
1735
- # without greedy
1736
- # def complex_contributions(name):
1737
- # log.info(f"Computing complex contributions for dataset: {name}")
1738
-
1739
- # pra = dload("pra", name)
1740
- # terms = dload("tmp", "terms")
1741
- # d = pra.query('prediction == 1').drop(columns=['gene1', 'gene2'])
1742
- # results = {}
1743
- # thresholds = [round(i, 2) for i in np.arange(1, 0.0001, -0.025)]
1744
- # for cid in terms.ID.to_list():
1745
- # arr = []
1746
- # for threshold in thresholds:
1747
- # r = d[d.complex_id == cid].query('precision >= @threshold')
1748
- # arr.append(r.shape[0])
1749
- # results[cid] = arr
1750
-
1751
- # r = pd.DataFrame(results, index=thresholds).T
1752
- # t = terms[['ID', 'Name']].set_index('ID')
1753
- # r['Name'] = r.index.map(t.Name)
1754
- # r = r[list(reversed(list(r.columns)))]
1755
- # r = r.reset_index(drop=True)
1756
- # dsave(r, "complex_contributions", name)
1757
- # log.info(f"Complex contributions computation completed for dataset: {name}")
1758
- # return r
1759
-
1760
- # # new
1761
- # def complex_contributions(name):
1762
- # log.info(f"Computing complex contributions using R-style greedy logic for dataset: {name}")
1763
- # pra = dload("pra", name)
1764
- # terms = dload("common", "terms")
1765
-
1766
- # # Ensure pra is sorted by score descending
1767
- # pra = pra.sort_values(by='score', ascending=False).reset_index(drop=True)
1768
-
1769
- # # Compute cumulative TP and precision if not present
1770
- # pra['cumTP'] = pra['prediction'].cumsum()
1771
- # pra['rank'] = pra.index + 1
1772
- # pra['precision'] = pra['cumTP'] / pra['rank']
1773
-
1774
- # # R-style precision thresholds
1775
- # prec_min = pra['precision'].min()
1776
- # prec_max = pra['precision'].max()
1777
- # precision_cutoffs = [round(prec_min, 3)]
1778
- # cutoffs_range = np.arange(0.1, prec_max + 0.001, 0.025)
1779
- # precision_cutoffs += [round(t, 3) for t in cutoffs_range if t > prec_min]
1780
- # thresholds = sorted(set(precision_cutoffs)) # Ensure unique and sorted
1781
-
1782
- # results = {}
1783
- # for t in thresholds:
1784
- # if pra['precision'].max() < t:
1785
- # continue
1786
- # cand = pra[pra['precision'] >= t]
1787
- # if cand.empty:
1788
- # continue
1789
- # k = cand.index.max() # rightmost index where precision >= t
1790
- # tp_target = pra.loc[k, 'cumTP']
1791
- # # Find the smallest m where cumTP[m] >= tp_target
1792
- # ind = pra[pra['cumTP'] >= tp_target].index.min()
1793
- # if pd.isna(ind):
1794
- # continue
1795
- # # Select top (ind+1) rows
1796
- # tmp = pra.iloc[0:ind + 1].copy()
1797
- # # Filter for predicted positives (true == 1)
1798
- # tmp = tmp[tmp['prediction'] == 1]
1799
- # tmp = tmp[tmp["complex_id"].notnull()]
1800
- # tmp["ID"] = tmp["complex_id"].apply(lambda ids: ";".join(str(int(i)) for i in ids if pd.notnull(i)))
1801
- # # Now greedy logic
1802
- # final_contrib = {}
1803
- # while not tmp.empty:
1804
- # all_ids = tmp["ID"].str.split(";").explode()
1805
- # contrib = all_ids.value_counts()
1806
- # if contrib.empty:
1807
- # break
1808
- # top_id = contrib.idxmax()
1809
- # final_contrib[top_id] = contrib[top_id]
1810
- # tmp = tmp[~tmp["ID"].str.contains(rf"\b{top_id}\b", regex=True)]
1811
- # for cid, count in final_contrib.items():
1812
- # if cid not in results:
1813
- # results[cid] = [0] * len(thresholds)
1814
- # results[cid][thresholds.index(t)] = count
1815
-
1816
- # # Add back gold standard complexes with 0 contribution
1817
- # gold_ids = set(terms.index.astype(str))
1818
- # all_ids = set(results.keys())
1819
- # missing_ids = gold_ids - all_ids
1820
- # for cid in missing_ids:
1821
- # results[cid] = [0] * len(thresholds)
1822
-
1823
- # # Build result DataFrame
1824
- # r = pd.DataFrame(results, index=thresholds).T
1825
- # r['Name'] = r.index.astype(int).map(terms['Name'])
1826
- # r = r[['Name'] + [c for c in r.columns if c != 'Name']] # Name as first col
1827
- # r = r[(r.drop(columns="Name").sum(axis=1) > 0)]
1828
- # # Move ID to first column, keep Name second, then precision columns in order
1829
- # dsave(r, "complex_contributions", name)
1830
- # log.info(f"Greedy R-style complex contribution completed for dataset: {name}")
1831
- # return r
1832
-
1833
- # def pra(dataset_name, matrix, is_corr=False):
1834
- # log.info(f"******************** {dataset_name} ********************")
1835
- # log.started(f"** Global Precision-Recall Analysis - {dataset_name} **")
1836
- # config = dload("config")
1837
-
1838
- # terms_data = dload("tmp", "terms")
1839
- # if terms_data is None or not isinstance(terms_data, pd.DataFrame):
1840
- # raise ValueError("Expected 'terms' to be a DataFrame, but got None or invalid type.")
1841
- # terms = terms_data
1842
- # genes_present = dload("tmp", "genes_present_in_terms")
1843
- # sorting = dload("input", "sorting")
1844
- # sort_order = sorting.get(dataset_name, "high")
1845
-
1846
- # if not is_corr:
1847
- # matrix = perform_corr(matrix, config.get("corr_function"))
1848
-
1849
- # matrix = filter_matrix_by_genes(matrix, genes_present)
1850
-
1851
- # log.info(f"Matrix shape: {matrix.shape}")
1852
- # df = binary(matrix)
1853
- # log.info(f"Pair-wise shape: {df.shape}")
1854
- # df = quick_sort(df, ascending=(sort_order == "low"))
1855
-
1856
- # gold_pair_to_complex = defaultdict(list)
1857
- # for idx, row in terms.iterrows():
1858
- # genes = row.used_genes
1859
- # if len(genes) < 2:
1860
- # continue
1861
- # for i, g1 in enumerate(genes):
1862
- # for g2 in genes[i + 1:]:
1863
- # pair = tuple(sorted((g1, g2)))
1864
- # gold_pair_to_complex[pair].append(idx)
1865
-
1866
- # # Label predictions and complex IDs
1867
- # complex_ids = []
1868
- # predictions = []
1869
- # for g1, g2 in zip(df["gene1"], df["gene2"]):
1870
- # pair = tuple(sorted((g1, g2)))
1871
- # ids = gold_pair_to_complex.get(pair, [])
1872
- # if ids:
1873
- # predictions.append(1)
1874
- # complex_ids.append(ids)
1875
- # else:
1876
- # predictions.append(0)
1877
- # complex_ids.append([])
1878
-
1879
- # df["prediction"] = predictions
1880
- # df["complex_id"] = complex_ids
1881
-
1882
- # if df["prediction"].sum() == 0:
1883
- # log.info("No true positives found in dataset.")
1884
- # pr_auc = np.nan
1885
- # else:
1886
- # tp = df["prediction"].cumsum()
1887
- # df["tp"] = tp
1888
- # precision = tp / (np.arange(len(df)) + 1)
1889
- # recall = tp / tp.iloc[-1]
1890
- # pr_auc = metrics.auc(recall, precision)
1891
- # df["precision"] = precision
1892
- # df["recall"] = recall
1893
-
1894
- # log.info(f"PR-AUC: {pr_auc:.4f}, Number of true positives: {df['prediction'].sum()}")
1895
- # dsave(df, "pra", dataset_name)
1896
- # dsave(pr_auc, "pr_auc", dataset_name)
1897
- # log.done(f"Global PRA completed for {dataset_name}")
1898
- # return df, pr_auc
1899
-
1900
- # def compute_pra(df):
1901
- # log.info("Calculating precision-recall and AUC score.")
1902
- # if df.empty:
1903
- # log.warning("Empty DataFrame encountered in compute_pra. Returning empty DataFrame.")
1904
- # return df
1905
- # df["tp"] = df["prediction"].cumsum()
1906
- # df.reset_index(drop=True, inplace=True)
1907
- # df["precision"] = df["tp"] / (df.index + 1)
1908
- # df["recall"] = df["tp"] / df["tp"].iloc[-1]
1909
- # log.info("DONE: Calculating precision-recall AUC score.")
1910
- # return df
1911
-
1912
- # def pra(dataset_name, matrix, is_corr=False):
1913
- # log.info(f"PRA computation started for {dataset_name}.")
1914
- # genes_present_in_terms = dload("tmp", "genes_present_in_terms")
1915
- # #terms_hash_table = dload("tmp", "terms_hash_table")
1916
- # sorting_prefs = dload("input", "sorting")
1917
- # sort_order = sorting_prefs.get(dataset_name, "high")
1918
- # if not is_corr: matrix = perform_corr(matrix, "numpy")
1919
- # matrix = filter_matrix_by_genes(matrix, genes_present_in_terms)
1920
- # stack = binary(matrix)
1921
-
1922
- # log.info("Checking gene pairs against the gold standard.")
1923
- # gene_pairs = list(zip(stack["gene1"], stack["gene2"]))
1924
- # hashed_pairs = [hash(pair) for pair in gene_pairs]
1925
- # stack["complex_id"] = [terms_hash_table.get(h, 0) for h in hashed_pairs]
1926
- # stack["prediction"] = [1 if h in terms_hash_table else 0 for h in hashed_pairs]
1927
-
1928
- # annotated = stack.copy()
1929
- # if sort_order == "low":
1930
- # ann_sorted = quick_sort(annotated, ascending=True)
1931
- # else:
1932
- # ann_sorted = quick_sort(annotated)
1933
-
1934
- # pra = compute_pra(ann_sorted)
1935
- # pr_auc = metrics.auc(pra.recall, pra.precision)
1936
- # dsave(pra, "pra", dataset_name)
1937
- # dsave(pr_auc, "pr_auc", dataset_name)
1938
- # log.info(f"PRA computation completed for {dataset_name} (Sorting: {sort_order}).")
1939
- # return pra, pr_auc