pythonflex 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pythonflex/analysis.py ADDED
@@ -0,0 +1,1299 @@
1
+ # Standard library imports
2
+ import gc
3
+ import os
4
+ import re
5
+ import shutil
6
+ import time
7
+ from collections import defaultdict, OrderedDict
8
+ from pathlib import Path
9
+
10
+ # Third-party imports
11
+ from art import tprint
12
+ from bitarray import bitarray
13
+ from joblib import Parallel, delayed, dump, load
14
+ import matplotlib.pyplot as plt
15
+ from numba import njit, prange
16
+ import numpy as np
17
+ import pandas as pd
18
+ from sklearn import metrics
19
+ from tqdm import tqdm
20
+
21
+ # Local/application-specific imports
22
+ from .logging_config import log
23
+ from .preprocessing import filter_matrix_by_genes
24
+ from .utils import dsave, dload, _sanitize
25
+
26
+
27
+
28
def deep_update(source, overrides):
    """Merge ``overrides`` into ``source`` in place and return ``source``.

    When both sides hold a dict under the same key the merge recurses into
    that nested dict; in every other case the override value replaces (or
    adds) the entry. ``overrides`` is only read. Values copied over from
    ``overrides`` are stored by reference, not copied.
    """
    for key, incoming in overrides.items():
        both_are_dicts = (
            isinstance(incoming, dict)
            and key in source
            and isinstance(source[key], dict)
        )
        if not both_are_dicts:
            source[key] = incoming
            continue
        deep_update(source[key], incoming)
    return source
36
+
37
+
38
+
39
def initialize(config=None):
    """Set up a pyFLEX run: merge configuration, reset scratch state, prepare output.

    Steps, in order:
      1. Merge ``config`` over the built-in defaults (nested dicts are merged
         key by key via ``deep_update``).
      2. Apply ``logging.visible_levels`` before anything else is logged.
      3. Delete an existing ``.tmp`` folder so results of an earlier run are
         not picked up.
      4. Persist the merged config with ``dsave(config, "config")`` and apply
         the matplotlib defaults.
      5. Create the output folder.

    Args:
        config: Optional dict of overrides. ``None`` (the default) or ``{}``
            keeps the defaults. The default is ``None`` rather than a shared
            mutable ``{}``; the function only reads this argument.
    """
    default_config = {
        "min_genes_in_complex": 3,
        "min_genes_per_complex_analysis": 3,
        "output_folder": "output",
        "gold_standard": "CORUM",
        "color_map": "RdYlBu",
        "jaccard": True,
        "plotting": {
            "save_plot": True,
            "show_plot": True,
            "output_type": "png",
        },
        "preprocessing": {
            "normalize": False,
            "fill_na": False,
            "drop_na": False,
        },
        "corr_function": "numpy",
        "logging": {  # Default logging config
            "visible_levels": ["DONE"]  # optionally also "PROGRESS", "STARTED", "INFO"
        },
    }

    # Merge early so user overrides (including logging.visible_levels) are in
    # effect before the first log line is emitted.
    config = default_config if config is None else deep_update(default_config, config)

    visible_levels = config.get("logging", {}).get("visible_levels", ["DONE"])
    log.set_visible_levels(visible_levels)

    log.info("******************************************************************")
    log.info("🧬 pyFLEX: Systematic CRISPR screen benchmarking framework")
    log.info("******************************************************************")
    log.started("Initialization")

    # Start from a clean slate so stale results from a previous run are not reused.
    tmp_folder = ".tmp"
    if os.path.exists(tmp_folder):
        log.info(f"Removing existing '{tmp_folder}' folder for a clean start.")
        shutil.rmtree(tmp_folder)
        log.done(f"'{tmp_folder}' folder removed successfully.")

    log.progress("Saving configuration settings.")

    dsave(config, "config")
    update_matploblib_config(config)
    output_folder = config.get("output_folder", "output")
    os.makedirs(output_folder, exist_ok=True)
    log.progress(f"Output folder '{output_folder}' ensured to exist.")
    log.done("Initialization completed. ")
    tprint("pyFLEX", font="standard")
95
+
96
+
97
+
98
def update_matploblib_config(config=None):
    """Apply the package-wide matplotlib ``rcParams`` defaults.

    Args:
        config: Optional config dict. Only ``config["color_map"]`` is read and
            used as the default image colormap. A missing key (or calling with
            no argument) falls back to "RdYlBu", the default used by
            ``initialize``. Previously the ``{}`` default always raised
            KeyError.
    """
    if config is None:
        config = {}
    log.progress("Updating matplotlib settings.")
    plt.rcParams.update({
        "font.family": "DejaVu Sans",  # change if you prefer Arial, etc.
        "mathtext.fontset": "dejavusans",
        'font.size': 7,  # General font size
        'axes.titlesize': 10,  # Title size
        'axes.labelsize': 7,  # Axis labels (xlabel/ylabel)
        'legend.fontsize': 7,  # Legend text
        'xtick.labelsize': 6,  # X-axis tick labels
        'ytick.labelsize': 6,  # Y-axis tick labels
        'lines.linewidth': 1.5,  # Line width for plots
        'figure.dpi': 300,  # Figure resolution
        'figure.figsize': (8, 6),  # Default figure size
        'grid.linestyle': '--',  # Grid line style
        'grid.linewidth': 0.5,  # Grid line width
        'grid.alpha': 0.2,  # Grid transparency
        'axes.spines.right': False,  # Hide right spine
        'axes.spines.top': False,  # Hide top spine
        'image.cmap': config.get('color_map', 'RdYlBu'),  # Default colormap
        'axes.edgecolor': 'black',  # Axis edge color
        'axes.facecolor': 'none',  # Transparent axes background
        'text.usetex': False  # Ensure LaTeX is off
    })
    log.done("Matplotlib settings updated.")
123
+
124
+
125
+
126
+
127
+
128
def pra(dataset_name, matrix, is_corr=False):
    """Global precision-recall analysis (PRA) for one dataset.

    The (optionally pre-computed) gene-gene correlation matrix is reduced to
    genes present in the gold-standard terms and flattened into unique gene
    pairs. The pairs are ranked by score (ascending when the dataset's
    ``sorting`` entry is "low", otherwise descending). A pair counts as a
    positive when it co-occurs in at least one gold-standard complex.

    Args:
        dataset_name: key used to look up the sort order and to save results.
        matrix: input data matrix, or a correlation matrix if ``is_corr``.
        is_corr: skip ``perform_corr`` when ``matrix`` is already a correlation matrix.

    Returns:
        ``(pairs, pr_auc)`` where ``pairs`` holds gene1/gene2/score plus
        ``complex_ids``, ``prediction`` and ``complex_id``. When positives
        exist it also holds ``tp``, ``precision`` and ``recall``.
        ``pr_auc`` is NaN when there are no positives. Both values are
        persisted with ``dsave``.

    Raises:
        ValueError: if the stored ``terms`` object is missing or not a DataFrame.
    """
    log.info(f"******************** {dataset_name} ********************")
    log.started(f"** Global Precision-Recall Analysis - {dataset_name} **")
    config = dload("config")

    terms = dload("common", "terms")
    if terms is None or not isinstance(terms, pd.DataFrame):
        raise ValueError("Expected 'terms' to be a DataFrame, but got None or invalid type.")
    genes_present = dload("common", "genes_present_in_terms")
    sort_order = dload("input", "sorting").get(dataset_name, "high")

    corr_matrix = matrix if is_corr else perform_corr(matrix, config.get("corr_function"))
    corr_matrix = filter_matrix_by_genes(corr_matrix, genes_present)

    log.info(f"Matrix shape: {corr_matrix.shape}")
    pairs = binary(corr_matrix)
    log.info(f"Pair-wise shape: {pairs.shape}")
    pairs = quick_sort(pairs, ascending=sort_order == "low")

    log.started("Building gene-to-pair indices")
    pair_to_complexes = _build_gold_pair_to_complex(terms)
    log.done("Gene-to-pair indices built.")

    log.started("Precomputing complex IDs")
    pairs = _precompute_complex_ids(pairs, pair_to_complexes)
    log.done("Complex IDs precomputed.")

    # Any non-empty ';'-joined ID string marks a gold-standard positive.
    id_strings = pairs["complex_ids"]
    pairs["prediction"] = id_strings.astype(bool).astype(int)
    pairs["complex_id"] = id_strings.apply(
        lambda joined: [int(part) for part in joined.split(";")] if joined else []
    )

    n_positive = pairs["prediction"].sum()
    if n_positive == 0:
        log.info("No true positives found in dataset.")
        pr_auc = np.nan
    else:
        cum_tp = pairs["prediction"].cumsum()
        pairs["tp"] = cum_tp
        ranks = np.arange(len(pairs)) + 1
        precision = cum_tp / ranks
        recall = cum_tp / cum_tp.iloc[-1]
        pr_auc = metrics.auc(recall, precision)
        pairs["precision"] = precision
        pairs["recall"] = recall

    log.info(f"PR-AUC: {pr_auc:.4f}, Number of true positives: {n_positive}")
    dsave(pairs, "pra", dataset_name)
    dsave(pr_auc, "pr_auc", dataset_name)
    log.done(f"Global PRA completed for {dataset_name}")
    return pairs, pr_auc
181
+
182
+
183
+
184
+
185
+
186
+
187
+
188
+ # --------------------------------------------------------------------------
189
+ # helper functions for PRA per-complex analysis
190
+ # --------------------------------------------------------------------------
191
+
192
+ def _build_gene_to_pair_indices(pairwise_df):
193
+ indices = pairwise_df.index.values
194
+ genes = pd.concat([pairwise_df['gene1'], pairwise_df['gene2']], ignore_index=True)
195
+ stacked_indices = np.concatenate([indices, indices])
196
+ idx_series = pd.Series(stacked_indices, index=range(len(genes)))
197
+ gene_to_pair_indices = defaultdict(list)
198
+ for gene, group in idx_series.groupby(genes):
199
+ gene_to_pair_indices[gene] = group.values.tolist()
200
+ return gene_to_pair_indices
201
+
202
+
203
+ def _build_gold_pair_to_complex(terms):
204
+ pair_map = defaultdict(set)
205
+ for comp_id, genes in zip(terms.index, terms['used_genes']):
206
+ genes = list(genes)
207
+ if len(genes) < 2: continue
208
+ for i in range(len(genes)):
209
+ for j in range(i+1, len(genes)):
210
+ g1, g2 = sorted([genes[i], genes[j]])
211
+ pair_map[(g1, g2)].add(comp_id)
212
+ return pair_map
213
+
214
+
215
+ def _precompute_complex_ids(pairwise_df, gold_pair_to_complex):
216
+ if not gold_pair_to_complex:
217
+ pairwise_df['complex_ids'] = ''
218
+ return pairwise_df
219
+
220
+ # Precompute pairs as tuples
221
+ g1 = pairwise_df['gene1']
222
+ g2 = pairwise_df['gene2']
223
+ pairs = [tuple(sorted((a, b))) for a, b in zip(g1, g2)]
224
+ pairwise_df['complex_ids'] = [
225
+ ';'.join(map(str, sorted(gold_pair_to_complex[p])))
226
+ if p in gold_pair_to_complex else ''
227
+ for p in pairs
228
+ ]
229
+ return pairwise_df
230
+
231
+
232
+
233
def _dump_pairwise_memmap(df: pd.DataFrame, tag: str) -> Path:
    """Write ``df`` uncompressed with joblib under ``.tmp/mmap/`` and return the path.

    The file name is ``.pairwise_<sanitized tag>.pkl``. The directory is
    created if needed. The output is a joblib pickle; whether readers
    memory-map it depends on how they call ``load``.
    """
    target_dir = Path(".tmp") / "mmap"
    target_dir.mkdir(parents=True, exist_ok=True)
    out_path = target_dir / f".pairwise_{_sanitize(tag)}.pkl"
    dump(df, out_path, compress=0)  # compress=0 keeps the payload raw
    return out_path
239
+
240
+
241
+
242
+
243
def _init_worker(memmap_path, gene_to_pair_indices):
    """Worker initializer: install the shared inputs read by ``_process_chunk``.

    Loads the pair table written by ``_dump_pairwise_memmap`` into the module
    global ``PAIRWISE_DF`` and stores the gene lookup in ``GENE2IDX``.
    ``load`` is called without ``mmap_mode``, so presumably each worker keeps
    its own in-memory copy — confirm if memory use matters.
    """
    global PAIRWISE_DF
    global GENE2IDX
    loaded_pairs = load(memmap_path)
    PAIRWISE_DF = loaded_pairs
    GENE2IDX = gene_to_pair_indices
247
+
248
+
249
+
250
def delete_memmap(memmap_path, log, wait_seconds=0.1):
    """Best-effort removal of a temporary memmap/pickle file.

    Runs a garbage-collection pass and sleeps ``wait_seconds`` first. The
    presumed intent is to let lingering references or file handles be released
    before deleting. An ``OSError`` from the removal is reported via
    ``log.warning`` and not raised.

    Args:
        memmap_path: path of the file to delete.
        log: logger-like object exposing ``info`` and ``warning``.
        wait_seconds: pause before attempting the removal.
    """
    gc.collect()
    time.sleep(wait_seconds)

    removal_error = None
    try:
        os.remove(memmap_path)
        log.info(f"Cleaned up temporary memmap file: {memmap_path}")
    except OSError as exc:
        removal_error = exc
    if removal_error is not None:
        log.warning(f"Original error: {removal_error}")
260
+
261
+
262
+
263
+ # --------------------------------------------------------------------------
264
+ # Process each chunk of terms
265
+ # --------------------------------------------------------------------------
266
def _process_chunk(chunk_terms, min_genes):
    """Compute per-complex PR-AUC scores for one chunk of gold-standard terms.

    Meant to run in a worker prepared by ``_init_worker``. It reads the ranked
    pair table from the module global ``PAIRWISE_DF`` and the gene -> pair
    lookup from ``GENE2IDX``.

    For each complex with at least ``min_genes`` distinct members:
      * select every pair touching any member, keeping the global ranking;
      * label pairs annotated with this complex as positives and unannotated
        pairs as negatives, and drop pairs that belong only to other complexes;
      * record the area under the resulting precision-recall curve.

    Args:
        chunk_terms: slice of the terms DataFrame. The index holds complex IDs
            and ``used_genes`` holds the member genes.
        min_genes: minimum number of distinct member genes to score a complex.

    Returns:
        dict of complex ID (index label) -> PR-AUC. Complexes with no selected
        pairs, no positives, or fewer than two remaining pairs are omitted.
    """
    pairwise_df = PAIRWISE_DF
    gene_to_pair_indices = GENE2IDX
    n_pairs = len(pairwise_df)
    local_auc_scores = {}

    for idx, row in chunk_terms.iterrows():
        gene_set = set(row.used_genes)
        if len(gene_set) < min_genes:
            continue

        # Zero-initialised boolean row mask. This replaces ``bitarray(n)``,
        # which older bitarray releases did not zero-fill (risking stray
        # selected rows). NumPy also handles list indexing natively.
        # assumes GENE2IDX holds positional indices (true when PAIRWISE_DF has
        # a RangeIndex, as produced by quick_sort) — confirm for other callers.
        selected = np.zeros(n_pairs, dtype=bool)
        for g in gene_set:
            if g in gene_to_pair_indices:
                selected[gene_to_pair_indices[g]] = True
        if not selected.any():
            continue

        sub_df = pairwise_df.iloc[selected]

        # Match the ID as a whole ';'-delimited token so "1" does not match "12".
        pattern = r'(?:^|;)' + re.escape(str(idx)) + r'(?:;|$)'
        true_label = sub_df["complex_ids"].str.contains(pattern, regex=True).astype(int)
        # Keep positives for this complex and unannotated pairs (negatives).
        mask = (sub_df["complex_ids"] == "") | (true_label == 1)
        preds = true_label[mask]

        if preds.sum() == 0:
            continue

        tp_cum = preds.cumsum()
        precision = tp_cum / (np.arange(len(preds)) + 1)
        recall = tp_cum / tp_cum.iloc[-1]
        if len(recall) >= 2 and recall.iloc[-1] != 0:
            local_auc_scores[idx] = metrics.auc(recall, precision)

    return local_auc_scores
302
+
303
+
304
+
305
def pra_percomplex(dataset_name, matrix, is_corr=False, chunk_size=200, n_jobs=8):
    """Per-complex precision-recall analysis, parallelised over chunks of terms.

    Builds the same ranked, complex-annotated pair table as ``pra``. It writes
    the table to a temporary joblib file for the workers, then scores every
    complex with ``_process_chunk``.

    Args:
        dataset_name: key used to look up the sort order and to save results.
        matrix: input data matrix, or a correlation matrix if ``is_corr``.
        is_corr: skip ``perform_corr`` when ``matrix`` is already a correlation matrix.
        chunk_size: number of terms handed to each parallel task.
        n_jobs: number of joblib workers (was hard-coded to 8; still the default).

    Returns:
        A copy of the terms table without the ``hash`` column and with an added
        ``auc_score`` column (NaN for complexes that were not scored). The
        stored ``terms`` object is not modified.
    """
    log.started(f"*** Per-complex PRA started - {dataset_name} ***")
    config = dload("config")
    terms = dload("common", "terms")
    genes_present = dload("common", "genes_present_in_terms")
    sorting = dload("input", "sorting")
    # Same default as ``pra``: only "low" reverses the ranking.
    sort_order = sorting.get(dataset_name, "high")
    if not is_corr:
        matrix = perform_corr(matrix, config.get("corr_function"))
    matrix = filter_matrix_by_genes(matrix, genes_present)
    log.info(f"Matrix shape: {matrix.shape}")
    df = binary(matrix)
    log.info(f"Pair-wise shape: {df.shape}")
    df = quick_sort(df, ascending=(sort_order == "low"))
    pairwise_df = df.copy()
    pairwise_df['gene1'] = pairwise_df['gene1'].astype("category")
    pairwise_df['gene2'] = pairwise_df['gene2'].astype("category")

    log.started("Building gene-to-pair indices")
    gene_to_pair_indices = _build_gene_to_pair_indices(pairwise_df)
    log.done("Building gene-to-pair indices")

    log.started("Building gold pair to complex mapping")
    gold_pair_to_complex = _build_gold_pair_to_complex(terms)
    log.done("Building gold pair to complex mapping")

    log.started("Precomputing complex IDs")
    pairwise_df = _precompute_complex_ids(pairwise_df, gold_pair_to_complex)
    log.done("Precomputing complex IDs")

    log.info('Dumping pairwise_df to memmap')
    memmap_path = _dump_pairwise_memmap(pairwise_df, dataset_name)
    log.done('Dumping pairwise_df to memmap')

    chunks = [terms.iloc[i:i + chunk_size] for i in range(0, len(terms), chunk_size)]
    min_genes = config["min_genes_per_complex_analysis"]

    results = None
    try:
        log.started("Processing chunks in parallel")
        with tqdm(total=len(chunks), desc="Per-complex PRA") as pbar:
            results = Parallel(
                n_jobs=n_jobs,
                temp_folder=os.path.dirname(memmap_path),
                max_nbytes=None,
                mmap_mode="r",
                initializer=_init_worker,
                initargs=(memmap_path, gene_to_pair_indices),
                verbose=0,
            )(delayed(_process_chunk)(chunk, min_genes) for chunk in chunks)
            # Results arrive all at once, so the bar only advances at the end.
            pbar.update(len(chunks))
        log.done("Processing chunks in parallel")
    except Exception as e:
        log.error(f"Error during parallel processing: {e}")
        raise
    finally:
        # Single cleanup path for both success and failure.
        try:
            if os.path.exists(memmap_path):
                os.remove(memmap_path)
                log.info(f"Cleaned up temporary memmap file: {memmap_path}")
        except OSError as e:
            log.warning(f"Failed to remove memmap file {memmap_path}: {e}")

    auc_scores = {}
    for res in results or []:
        if isinstance(res, dict):
            auc_scores.update(res)
        elif isinstance(res, tuple) and res[0] is None:
            log.error(res[1])  # error message reported by the chunk
        else:
            log.error(f"Ignoring unexpected chunk result: {res}")

    # Build a new frame instead of mutating the loaded terms in place. A
    # missing "hash" column (e.g. already dropped) must not raise KeyError.
    result = terms.drop(columns=["hash"], errors="ignore")
    result["auc_score"] = pd.Series(auc_scores)
    dsave(result, "pra_percomplex", dataset_name)
    log.done("Per-complex PRA completed.")
    return result
403
+
404
+
405
+
406
+
407
+
408
def _parse_complex_ids(ids):
    """Normalise one ``complex_id`` cell into a list of complex-ID strings.

    Accepts a ';'-separated string, in which float-like tokens such as "12.0"
    become "12". It also accepts an iterable of numbers (the list form
    stored by ``pra``), skipping null entries.
    """
    if isinstance(ids, str):
        return [str(int(float(i.strip()))) for i in ids.split(';') if i.strip()]
    return [str(int(i)) for i in ids if pd.notnull(i)]


def _greedy_contributions_at_threshold(pra_df, t, global_position):
    """Greedy per-complex TP attribution for a single precision threshold.

    Takes the ranked pairs up to the first rank that reaches the TP count of
    the last rank whose precision is still >= ``t``. It then repeatedly picks
    the complex covering the most not-yet-covered true-positive pairs, and
    ties are broken by first appearance. The first pick uses
    ``global_position``; later picks use the order among the remaining rows.

    Returns:
        dict of complex-ID string -> number of pairs attributed to it. The
        dict is empty when nothing qualifies.
    """
    # Rightmost rank whose precision reaches t (R: cand.ind[length(cand.ind)]).
    cand_mask = pra_df['precision'] >= t
    if not cand_mask.any():
        return {}
    k = pra_df.index[cand_mask].max()
    tp_target = pra_df.loc[k, 'cumTP']
    if tp_target <= 0:
        return {}

    # First rank that reaches the same TP count (R: tmp.ind[1]).
    matching_inds = pra_df[pra_df['cumTP'] == tp_target].index
    if matching_inds.empty:
        return {}
    ind = matching_inds.min()

    tmp = pra_df.iloc[0:ind + 1]
    tmp = tmp[(tmp['prediction'] == 1) & tmp['complex_id'].notnull()].reset_index(drop=True)
    if tmp.empty:
        return {}

    row_to_cids = [_parse_complex_ids(ids) for ids in tmp['complex_id']]
    cid_to_rows = defaultdict(list)
    for row_idx, cids in enumerate(row_to_cids):
        for cid in cids:
            cid_to_rows[cid].append(row_idx)

    current_size = {cid: len(rows) for cid, rows in cid_to_rows.items()}
    covered = np.zeros(len(row_to_cids), dtype=bool)
    remaining_rows = list(range(len(row_to_cids)))
    final_contrib = {}
    is_first = True  # the very first pick uses the global tie-break order

    while current_size:
        max_contrib = max(current_size.values())
        candidates = [cid for cid, cnt in current_size.items() if cnt == max_contrib]

        if len(candidates) == 1:
            top_cid = candidates[0]
        elif is_first:
            # Global appearance order (matches R's global matrix row order).
            positions = {cid: global_position[cid] for cid in candidates if cid in global_position}
            top_cid = min(positions, key=positions.get)
        else:
            # First appearance among the rows that are still uncovered.
            local_ids = OrderedDict.fromkeys(
                cid for ri in remaining_rows for cid in row_to_cids[ri]
            )
            local_position = {cid: pos for pos, cid in enumerate(local_ids)}
            positions = {cid: local_position[cid] for cid in candidates if cid in local_position}
            top_cid = min(positions, key=positions.get)

        contrib = current_size[top_cid]
        if contrib <= 0:
            current_size.pop(top_cid, None)
            continue

        # Cover this complex's rows and discount them from every complex sharing them.
        for row_idx in cid_to_rows[top_cid]:
            if not covered[row_idx]:
                covered[row_idx] = True
                for cid in row_to_cids[row_idx]:
                    if cid in current_size:
                        current_size[cid] -= 1
                        if current_size[cid] <= 0:
                            current_size.pop(cid, None)

        remaining_rows = [ri for ri in remaining_rows if not covered[ri]]
        final_contrib[top_cid] = contrib
        is_first = False

    return final_contrib


def complex_contributions(name):
    """Greedy attribution of true-positive pairs to gold-standard complexes.

    For a grid of precision thresholds (the rounded minimum precision, then
    0.1, 0.125, ... up to the maximum), counts how many true-positive pairs
    each complex explains, using a greedy set cover. The procedure mirrors an
    R implementation, including its tie-breaking order.

    Args:
        name: dataset name. It is used to load the saved ``"pra"`` table and
            the dataset's sort order, and to save the result.

    Returns:
        DataFrame with a ``Name`` column plus one ``Precision_<t>`` column per
        threshold. There is one row per contributing complex, ordered like the
        terms table, then sorted (stably) by the last valid threshold
        column, descending. Rows are deduplicated by ``Name``. The result is
        also saved via ``dsave``.
    """
    log.info(f"Computing complex contributions (Greedy) for dataset: {name}")
    pra_df = dload("pra", name)
    terms = dload("common", "terms")
    sorting = dload("input", "sorting") or {}

    # Rank exactly as ``pra`` did: datasets configured as "low" rank lowest
    # scores first (the old code always sorted descending, which inverted
    # those datasets). The stable sort keeps pra's order among tied scores.
    ascending = sorting.get(name, "high") == "low"
    pra_df = pra_df.sort_values(by='score', ascending=ascending, kind='stable').reset_index(drop=True)

    # Cumulative TP and precision (R: TP.count = cumsum(true), Precision = TP / (1:n)).
    pra_df['cumTP'] = pra_df['prediction'].cumsum()
    pra_df['rank'] = pra_df.index + 1
    pra_df['precision'] = pra_df['cumTP'] / pra_df['rank']

    # R-style thresholds: c(min, seq(0.1, max, 0.025)), rounded, unique, sorted.
    prec_min = pra_df['precision'].min()
    prec_max = pra_df['precision'].max()
    precision_cutoffs = [round(prec_min, 3)]
    cutoffs_range = np.arange(0.1, prec_max + 0.001, 0.025)
    precision_cutoffs += [round(c, 3) for c in cutoffs_range if c > prec_min]
    thresholds = sorted(set(precision_cutoffs))

    # First-appearance rank of each complex ID across all positives.
    positives = pra_df[pra_df['prediction'] == 1].reset_index(drop=True)
    global_ids = OrderedDict.fromkeys(
        cid for ids in positives['complex_id'] for cid in _parse_complex_ids(ids)
    )
    global_position = {cid: pos for pos, cid in enumerate(global_ids)}

    results = {}
    valid_thresholds = []  # R: ind.valid.precision

    with tqdm(total=len(thresholds), desc="Processing thresholds", unit="thresh") as pbar:
        for thresh_idx, t in enumerate(thresholds):
            # Rounding can push a cutoff outside [min, max]; such cutoffs are skipped.
            if prec_min <= t <= prec_max:
                valid_thresholds.append(thresh_idx)
                contributions = _greedy_contributions_at_threshold(pra_df, t, global_position)
                for cid, count in contributions.items():
                    if cid not in results:
                        results[cid] = [0] * len(thresholds)
                    results[cid][thresh_idx] = count
            pbar.update(1)

    # Rows = complex IDs (as str), columns = thresholds.
    r = pd.DataFrame(results, index=thresholds).T
    r.index = r.index.astype(str)

    # Keep complexes with any contribution (R: nonzero.cont.ind).
    r = r[r.sum(axis=1) > 0]

    # Restrict to IDs present in terms, preserving the terms order.
    contributing_ids = set(r.index)
    common_ids = [str(tid) for tid in terms.index if str(tid) in contributing_ids]
    r = r.loc[common_ids]

    name_by_id = pd.Series(terms['Name'].values, index=terms.index.astype(str))
    r.insert(0, 'Name', r.index.map(name_by_id))

    r.columns = ['Name'] + [f"Precision_{t}" for t in thresholds]

    # Stable sort by the last valid threshold column (matches R's stable order).
    if valid_thresholds:
        last_valid_col = f"Precision_{thresholds[valid_thresholds[-1]]}"
        r = r.sort_values(by=last_valid_col, ascending=False, kind='stable')

    # Keep the first row per Name (R: !duplicated(Name)).
    r = r[~r['Name'].duplicated(keep='first')].reset_index(drop=True)

    dsave(r, "complex_contributions", name)
    log.info(f"Complex contribution (Greedy) completed for dataset: {name}")
    return r
582
+
583
+
584
+
585
+
586
+ # --------------------------------------------------------------------------
587
+ # Helpers
588
+ # --------------------------------------------------------------------------
589
+
590
def perform_corr(df, corr_func):
    """Correlate the rows of ``df`` pairwise and blank the diagonal.

    Args:
        df: DataFrame whose rows are the variables to correlate (genes).
        corr_func: backend to use:
            * "numpy": ``np.ma.corrcoef`` on a NaN-masked array;
            * "pandas": ``df.T.corr()``;
            * "numba": ``fast_corr``.

    Returns:
        Square DataFrame indexed and labelled by the rows of ``df`` (for the
        numba path, only rows kept by ``fast_corr``), with NaN on the diagonal.

    Raises:
        ValueError: for an unknown ``corr_func``.
    """
    if corr_func not in {"numpy", "pandas", "numba"}:
        raise ValueError("corr_func must be 'numpy', 'pandas' or 'numba'")

    log.started(f"Performing correlation using '{corr_func}' method.")

    if corr_func == "numpy":
        masked = np.ma.masked_invalid(df.values)
        arr = np.ma.corrcoef(masked).filled(np.nan)
        # Fill the diagonal on our own array: under pandas Copy-on-Write,
        # DataFrame.values can be read-only, so in-place writes through it fail.
        np.fill_diagonal(arr, np.nan)
        df_corr = pd.DataFrame(arr, index=df.index, columns=df.index)
        log.done("Correlation.")
        return df_corr

    if corr_func == "numba":
        corr = fast_corr(df)
        done_msg = "Correlation using Numba."
    else:
        corr = df.T.corr()
        done_msg = "Correlation using pandas."

    arr = corr.to_numpy(dtype=float, copy=True)
    np.fill_diagonal(arr, np.nan)
    log.done(done_msg)
    return pd.DataFrame(arr, index=corr.index, columns=corr.columns)
616
+
617
+
618
+
619
# Numba kernel for fast_corr, compiled on first use and reused afterwards.
_FAST_CORR_KERNEL = None


def _get_fast_corr_kernel():
    """Return the JIT-compiled pairwise Pearson kernel, building it only once.

    Defining the ``@njit`` function inside ``fast_corr`` created a new
    dispatcher, and so forced a fresh compilation, on every call. Caching
    the dispatcher means Numba compiles it once per input signature.
    """
    global _FAST_CORR_KERNEL
    if _FAST_CORR_KERNEL is not None:
        return _FAST_CORR_KERNEL

    @njit(parallel=True)
    def compute_corr(data):
        # data: (m observations, n variables); returns an (n, n) matrix.
        m, n = data.shape
        corr = np.full((n, n), np.nan, dtype=np.float64)
        # Off-diagonal (upper triangle, mirrored), parallel over i; pairwise-complete NaN handling.
        for i in prange(n):
            for j in range(i + 1, n):
                sum_x = 0.0
                sum_y = 0.0
                sum_xx = 0.0
                sum_yy = 0.0
                sum_xy = 0.0
                count = 0
                for k in range(m):
                    x = data[k, i]
                    y = data[k, j]
                    if not np.isnan(x) and not np.isnan(y):
                        sum_x += x
                        sum_y += y
                        sum_xx += x * x
                        sum_yy += y * y
                        sum_xy += x * y
                        count += 1
                if count >= 2:
                    # Sample variance/covariance (divide by count - 1)
                    var_x = (sum_xx - (sum_x ** 2) / count) / (count - 1)
                    var_y = (sum_yy - (sum_y ** 2) / count) / (count - 1)
                    cov = (sum_xy - (sum_x * sum_y) / count) / (count - 1)
                    denom = np.sqrt(var_x * var_y)
                    if denom > 0:  # constant columns -> NaN instead of div-by-zero
                        r = cov / denom
                    else:
                        r = np.nan
                else:
                    r = np.nan
                corr[i, j] = r
                corr[j, i] = r
        # Diagonal: 1.0 for non-constant variables with >= 2 values, else NaN.
        for i in prange(n):
            sum_x = 0.0
            sum_xx = 0.0
            count = 0
            for k in range(m):
                x = data[k, i]
                if not np.isnan(x):
                    sum_x += x
                    sum_xx += x * x
                    count += 1
            if count >= 2:
                var_x = (sum_xx - (sum_x ** 2) / count) / (count - 1)
                if var_x > 0:
                    corr[i, i] = 1.0
                else:
                    corr[i, i] = np.nan
            else:
                corr[i, i] = np.nan
        return corr

    _FAST_CORR_KERNEL = compute_corr
    return compute_corr


def fast_corr(df):
    """Pairwise Pearson correlation between the rows of ``df`` using Numba.

    Only numeric columns are used. Each pair of rows is correlated over the
    columns where both values are non-NaN (pairwise-complete). Pairs with
    fewer than two shared values, or with zero variance, give NaN.

    Returns:
        Square DataFrame indexed and labelled by ``df``'s index.
    """
    compute_corr = _get_fast_corr_kernel()
    df_numeric = df.select_dtypes(include=np.number)
    # float64 input gives Numba one stable signature (no recompiles for int frames).
    data = df_numeric.to_numpy(dtype=np.float64).T
    corr_matrix = compute_corr(data)
    return pd.DataFrame(corr_matrix, index=df_numeric.index, columns=df_numeric.index)
683
+
684
+
685
+
686
def is_symmetric(df):
    """Return True if ``df`` equals its transpose, treating NaN as equal to NaN.

    Non-square 2-D input returns False. Previously such input was either
    broadcast by ``np.allclose`` (a 1 x n constant row was reported
    symmetric) or raised a broadcasting ValueError.
    """
    shape = np.shape(df)
    if len(shape) == 2 and shape[0] != shape[1]:
        return False
    return np.allclose(df, df.T, equal_nan=True)
688
+
689
+
690
def binary(corr):
    """Flatten a score matrix into a long table of unique gene pairs.

    A symmetric matrix is first reduced to its strict upper triangle, so each
    unordered pair appears once. Entries with NaN scores are dropped.
    Remaining mirrored pairs (A,B)/(B,A) are collapsed by ``drop_mirror_pairs``
    when ``has_mirror_of_first_pair`` detects them.

    Returns:
        DataFrame with columns ``gene1``, ``gene2``, ``score``.
    """
    log.started("Converting correlation matrix to pair-wise format.")
    if is_symmetric(corr):
        corr = convert_full_to_half_matrix(corr)

    # Drop NaNs explicitly. The legacy stack() dropped them implicitly, but
    # the new implementation (default in pandas 3) keeps them. Kept NaNs would
    # re-introduce the masked lower triangle and diagonal as pairs.
    stack = (
        corr.stack()
        .dropna()
        .rename_axis(index=['gene1', 'gene2'])
        .reset_index()
        .rename(columns={0: 'score'})
    )
    if not stack.empty and has_mirror_of_first_pair(stack):
        log.info("Mirror pairs detected. Dropping them to ensure unique gene pairs.")
        stack = drop_mirror_pairs(stack)
    log.done("Pair-wise conversion.")
    return stack
702
+
703
+
704
def has_mirror_of_first_pair(df):
    """Return True if the reverse of the first row's (gene1, gene2) pair occurs later.

    Used as a cheap probe for whether a pair table contains mirrored
    duplicates. An empty table has no mirrors; previously it raised IndexError.
    """
    if df.empty:
        return False
    first = df.iloc[0]
    g1, g2 = first['gene1'], first['gene2']
    mirrored = (df['gene1'] == g2) & (df['gene2'] == g1)
    return bool(mirrored.iloc[1:].any())
708
+
709
+
710
def convert_full_to_half_matrix(df):
    """Keep only the strict upper triangle of a symmetric matrix.

    The diagonal and everything below it are set to NaN on a copy of the
    values. The input is left untouched, and labels are preserved.

    Raises:
        ValueError: if ``df`` is not symmetric.
    """
    if not is_symmetric(df):
        raise ValueError("Matrix must be symmetric to convert to half matrix.")

    log.started("Converting full correlation matrix to upper triangle (half-matrix) format.")
    values = df.values.copy()
    lower_and_diagonal = np.tril(np.ones(values.shape, dtype=bool))
    values[lower_and_diagonal] = np.nan
    log.done("Matrix conversion.")
    return pd.DataFrame(values, index=df.index, columns=df.columns)
719
+
720
+
721
def drop_mirror_pairs(df):
    """Canonicalise pair order and drop duplicate unordered pairs.

    Each row's (gene1, gene2) is rewritten into sorted order, and only the
    first occurrence of each pair is kept. The original index labels are
    retained.

    The input frame is no longer modified: the old version wrote the sorted
    genes back into the caller's frame. Whole-column assignment also works
    for categorical gene columns.
    """
    log.started("Dropping mirror pairs to ensure unique gene pairs (Optimized).")
    df = df.copy()
    gene_pairs = np.sort(df[["gene1", "gene2"]].to_numpy(), axis=1)
    df["gene1"] = gene_pairs[:, 0]
    df["gene2"] = gene_pairs[:, 1]
    df = df.loc[~df.duplicated(subset=["gene1", "gene2"], keep="first")]
    log.done("Mirror pairs are dropped.")
    return df
728
+
729
+
730
def quick_sort(df, ascending=False):
    """Return ``df`` reordered by its ``score`` column, with a fresh RangeIndex.

    Sorting uses ``np.argsort`` on the scores, negated for descending order,
    so NaN scores end up last in both directions.
    """
    log.started(f"Pair-wise matrix is sorting based on the 'score' column: ascending:{ascending}")
    if ascending:
        sign = 1
    else:
        sign = -1
    sort_keys = sign * df["score"].values
    ordering = np.argsort(sort_keys)
    reordered = df.iloc[ordering]
    reordered = reordered.reset_index(drop=True)
    log.done("Pair-wise matrix sorting.")
    return reordered
736
+
737
+
738
+
739
+
740
def save_results_to_csv(categories=("complex_contributions", "pr_auc", "pra_percomplex")):
    """Export stored results to ``<output_folder>/csv``.

    For each category, the stored object (``dload(category)``) is written as
    follows:
      * ``pr_auc`` stored as a dict: one tab-separated ``pr_auc.txt`` with an
        ``AUC`` column indexed by ``Dataset``;
      * a dict of DataFrames (presumably keyed by dataset name — confirm):
        one ``<category>_<key>.csv`` per entry;
      * a single DataFrame: ``<category>.csv``.
    Missing or unsupported data is logged and skipped.

    Args:
        categories: iterable of category names. The default is a tuple, so no
            mutable default is shared between calls.
    """
    config = dload("config")
    output_folder = Path(config.get("output_folder", "output")) / "csv"
    output_folder.mkdir(parents=True, exist_ok=True)

    for category in categories:
        data = dload(category)
        if data is None:
            log.warning(f"No data found for category '{category}'. Skipping save.")
            continue

        if category == "pr_auc" and isinstance(data, dict):
            # Scalar AUC per dataset -> two-column table (Dataset, AUC).
            try:
                df = pd.DataFrame.from_dict(data, orient='index', columns=['AUC'])
                df.index.name = 'Dataset'
                txt_path = output_folder / f"{category}.txt"
                df.to_csv(txt_path, sep='\t', index=True)
                log.info(f"Saved '{category}' as tabular TXT to {txt_path}")
            except Exception as e:
                log.warning(f"Failed to convert and save '{category}' as TXT: {e}")
            continue

        if isinstance(data, dict):
            for key, df in data.items():
                if isinstance(df, pd.DataFrame):
                    csv_path = output_folder / f"{category}_{key}.csv"
                    df.to_csv(csv_path, index=False)
                    log.info(f"Saved '{category}_{key}' to {csv_path}")
                else:
                    log.warning(f"Skipping non-DataFrame item '{key}' in '{category}'.")
        elif isinstance(data, pd.DataFrame):
            csv_path = output_folder / f"{category}.csv"
            data.to_csv(csv_path, index=False)
            log.info(f"Saved '{category}' to {csv_path}")
        else:
            log.warning(f"Unsupported data type for '{category}'. Expected DataFrame or dict of DataFrames. Skipping.")

    log.done("Results saved to CSV files in the output folder.")
784
+
785
+
786
+
787
+
788
+
789
+
790
+ ### OLD FUNCTIONS
791
+
792
+
793
+ # Previous serial (non-parallel) version of pra_percomplex, kept for reference
794
+
795
+ # def pra_percomplex(dataset_name, matrix, is_corr=False):
796
+ # log.started(f"*** Per-complex PRA started - {dataset_name} ***")
797
+ # config = dload("config")
798
+ # terms = dload("tmp", "terms")
799
+ # genes_present = dload("tmp", "genes_present_in_terms")
800
+ # sorting = dload("input", "sorting")
801
+ # sort_order = sorting.get(dataset_name, "high")
802
+ # if not is_corr:
803
+ # matrix = perform_corr(matrix, config.get("corr_function"))
804
+ # matrix = filter_matrix_by_genes(matrix, genes_present)
805
+ # log.info(f"Matrix shape: {matrix.shape}")
806
+ # df = binary(matrix)
807
+ # log.info(f"Pair-wise shape: {df.shape}")
808
+ # df = quick_sort(df, ascending=(sort_order == "low"))
809
+ # pairwise_df = df.copy()
810
+ # pairwise_df['gene1'] = pairwise_df['gene1'].astype("category")
811
+ # pairwise_df['gene2'] = pairwise_df['gene2'].astype("category")
812
+
813
+ # # Precompute a mapping from each gene to the row indices in the pairwise DataFrame where it appears.
814
+ # gene_to_pair_indices = {}
815
+ # for i, (gene_a, gene_b) in enumerate(zip(pairwise_df["gene1"], pairwise_df["gene2"])):
816
+ # gene_to_pair_indices.setdefault(gene_a, []).append(i)
817
+ # gene_to_pair_indices.setdefault(gene_b, []).append(i)
818
+ # log.done
819
+
820
+ # # Build gold_pair_to_complex using sets for efficiency
821
+ # gold_pair_to_complex = defaultdict(set)
822
+ # for idx, row in terms.iterrows():
823
+ # genes = row.used_genes
824
+ # if len(genes) < 2:
825
+ # continue
826
+ # for i, g1 in enumerate(genes):
827
+ # for g2 in genes[i + 1:]:
828
+ # pair = tuple(sorted((g1, g2)))
829
+ # gold_pair_to_complex[pair].add(idx)
830
+
831
+ # # Precompute complex_ids as semicolon-separated strings in pairwise_df
832
+ # pairs = [tuple(sorted((g1, g2))) for g1, g2 in zip(pairwise_df["gene1"], pairwise_df["gene2"])]
833
+ # pairwise_df['complex_ids'] = [';'.join(map(str, sorted(gold_pair_to_complex.get(pair, set())))) for pair in pairs]
834
+
835
+ # # Initialize AUC scores
836
+ # auc_scores = {}
837
+ # # Loop over each gene complex
838
+ # for idx, row in tqdm(terms.iterrows()):
839
+ # gene_set = set(row.used_genes)
840
+ # if config["min_genes_per_complex_analysis"] > len(gene_set):
841
+ # continue
842
+ # # Collect all row indices in the pairwise data where either gene belongs to the complex.
843
+ # candidate_indices = bitarray(len(pairwise_df))
844
+ # for gene in gene_set:
845
+ # if gene in gene_to_pair_indices:
846
+ # candidate_indices[gene_to_pair_indices[gene]] = True
847
+
848
+ # if not candidate_indices.any():
849
+ # continue
850
+
851
+ # # Select only the relevant pairwise comparisons.
852
+ # selected_rows = np.unpackbits(candidate_indices).view(bool)[:len(pairwise_df)]
853
+ # sub_df = pairwise_df.iloc[selected_rows]
854
+
855
+ # # Get current complex ID (assuming idx is the ID; adjust if row['ID'] is different)
856
+ # complex_id = str(idx) # Or str(row['ID']) if available
857
+
858
+ # # Create true_label: 1 if complex_id in complex_ids (vectorized with str.contains)
859
+ # #true_label = sub_df['complex_ids'].str.contains(complex_id, regex=False).astype(int)
860
+
861
+ # # Inside the loop, for each complex:
862
+ # # Inside the loop:
863
+ # complex_id = str(idx)
864
+ # # Use (?:^|;) and (?:;|$) to avoid capturing groups
865
+ # pattern = r'(?:^|;)' + re.escape(complex_id) + r'(?:;|$)'
866
+ # true_label = sub_df['complex_ids'].str.contains(pattern, regex=True).astype(int)
867
+ # # Filter to keep verified negatives (complex_ids == "") or positives for this complex (true_label == 1)
868
+ # complex_mask = (sub_df['complex_ids'] == "") | (true_label == 1)
869
+
870
+ # # Use the masked true labels for AUPRC (avoids SettingWithCopyWarning)
871
+ # predictions = true_label[complex_mask]
872
+
873
+ # if predictions.sum() == 0:
874
+ # continue
875
+ # # Compute cumulative true positives and derive precision and recall.
876
+ # true_positive_cumsum = predictions.cumsum()
877
+ # precision = true_positive_cumsum / (np.arange(len(predictions)) + 1)
878
+ # recall = true_positive_cumsum / true_positive_cumsum.iloc[-1]
879
+
880
+ # if len(recall) < 2 or recall.iloc[-1] == 0:
881
+ # continue
882
+ # auc_scores[idx] = metrics.auc(recall, precision)
883
+
884
+ # # Add the computed AUC scores to the terms DataFrame.
885
+ # terms["auc_score"] = pd.Series(auc_scores)
886
+ # terms.drop(columns=["hash"], inplace=True)
887
+ # dsave(terms, "pra_percomplex", dataset_name)
888
+ # log.done(f"Per-complex PRA completed.")
889
+ # return terms
890
+
891
+
892
+
893
+ # Runs quickly, but maps each gene pair to only one complex.
894
+
895
+ # def pra_percomplex_old_type_filtering(dataset_name, matrix, is_corr=False):
896
+ # log.started(f"*** Per-complex PRA started - {dataset_name} ***")
897
+ # config = dload("config")
898
+ # terms = dload("tmp", "terms")
899
+ # genes_present = dload("tmp", "genes_present_in_terms")
900
+ # sorting = dload("input", "sorting")
901
+ # sort_order = sorting.get(dataset_name, "high")
902
+ # if not is_corr:
903
+ # matrix = perform_corr(matrix, config.get("corr_function"))
904
+ # matrix = filter_matrix_by_genes(matrix, genes_present)
905
+ # log.info(f"Matrix shape: {matrix.shape}")
906
+ # df = binary(matrix)
907
+ # log.info(f"Pair-wise shape: {df.shape}")
908
+ # df = quick_sort(df, ascending=(sort_order == "low"))
909
+ # pairwise_df = df.copy()
910
+ # pairwise_df['gene1'] = pairwise_df['gene1'].astype("category")
911
+ # pairwise_df['gene2'] = pairwise_df['gene2'].astype("category")
912
+ # # Precompute a mapping from each gene to the row indices in the pairwise DataFrame where it appears.
913
+ # gene_to_pair_indices = {}
914
+ # for i, (gene_a, gene_b) in enumerate(zip(pairwise_df["gene1"], pairwise_df["gene2"])):
915
+ # gene_to_pair_indices.setdefault(gene_a, []).append(i)
916
+ # gene_to_pair_indices.setdefault(gene_b, []).append(i)
917
+ # # Initialize AUC scores (one for each complex) with NaNs.
918
+ # #auc_scores = np.full(len(terms), np.nan)
919
+ # auc_scores = {}
920
+ # # Loop over each gene complex
921
+ # for idx, row in tqdm(terms.iterrows()):
922
+ # gene_set = set(row.used_genes)
923
+
924
+ # if config["min_genes_per_complex_analysis"] > len(gene_set):
925
+ # continue
926
+ # # Collect all row indices in the pairwise data where either gene belongs to the complex.
927
+ # candidate_indices = bitarray(len(pairwise_df))
928
+ # for gene in gene_set:
929
+ # if gene in gene_to_pair_indices:
930
+ # candidate_indices[gene_to_pair_indices[gene]] = True
931
+ # if not candidate_indices.any():
932
+ # continue
933
+ # # Select only the relevant pairwise comparisons.
934
+ # selected_rows = np.unpackbits(candidate_indices).view(bool)[:len(pairwise_df)]
935
+ # sub_df = pairwise_df.iloc[selected_rows]
936
+ # # A prediction is 1 if both genes in the pair are in the complex; otherwise 0.
937
+ # predictions = (sub_df["gene1"].isin(gene_set) & sub_df["gene2"].isin(gene_set)).astype(int)
938
+ # if predictions.sum() == 0:
939
+ # continue
940
+ # # Compute cumulative true positives and derive precision and recall.
941
+ # true_positive_cumsum = predictions.cumsum()
942
+ # precision = true_positive_cumsum / (np.arange(len(predictions)) + 1)
943
+ # recall = true_positive_cumsum / true_positive_cumsum.iloc[-1]
944
+
945
+ # if len(recall) < 2 or recall.iloc[-1] == 0:
946
+ # continue
947
+ # auc_scores[idx] = metrics.auc(recall, precision)
948
+ # # Add the computed AUC scores to the terms DataFrame.
949
+ # terms["auc_score"] = pd.Series(auc_scores)
950
+ # terms.drop(columns=["hash"], inplace=True)
951
+ # dsave(terms, "pra_percomplex", dataset_name)
952
+ # log.done(f"Per-complex PRA completed.")
953
+ # return terms
954
+
955
+
956
+
957
+ # OLD
958
+ # def pra_percomplex(dataset_name, matrix, is_corr=False):
959
+ # log.started(f"*** Per-complex PRA started for {dataset_name} ***")
960
+ # config = dload("config")
961
+ # terms = dload("tmp", "terms")
962
+ # genes_present = dload("tmp", "genes_present_in_terms")
963
+ # sorting = dload("input", "sorting")
964
+ # sort_order = sorting.get(dataset_name, "high")
965
+
966
+ # if not is_corr:
967
+ # matrix = perform_corr(matrix, "numpy")
968
+ # matrix = filter_matrix_by_genes(matrix, genes_present)
969
+ # log.info(f"Matrix shape: {matrix.shape}")
970
+ # df = binary(matrix)
971
+ # log.info(f"Pair-wise shape: {df.shape}")
972
+ # df = quick_sort(df, ascending=(sort_order == "low"))
973
+ # # Precompute gene → row indices
974
+ # gene_to_rows = {}
975
+ # for i, (g1, g2) in enumerate(zip(df["gene1"], df["gene2"])):
976
+ # gene_to_rows.setdefault(g1, []).append(i)
977
+ # gene_to_rows.setdefault(g2, []).append(i)
978
+ # aucs = np.full(len(terms), np.nan)
979
+ # N = len(df)
980
+ # for idx, row in tqdm(terms.iterrows()):
981
+ # genes = set(row.used_genes)
982
+ # if len(genes) < config["min_complex_size_for_percomplex"]: # Skip small complexes
983
+ # continue
984
+ # # Get all row indices where either gene is in the complex
985
+ # candidate_idxs = set()
986
+ # for g in genes:
987
+ # candidate_idxs.update(gene_to_rows.get(g, []))
988
+ # candidate_idxs = sorted(candidate_idxs)
989
+ # if not candidate_idxs:
990
+ # continue
991
+ # # Use only relevant rows for prediction
992
+ # sub = df.loc[candidate_idxs]
993
+ # preds = (sub["gene1"].isin(genes) & sub["gene2"].isin(genes)).astype(int)
994
+ # if preds.sum() == 0:
995
+ # continue
996
+ # tp = preds.cumsum()
997
+ # prec = tp / (np.arange(len(preds)) + 1)
998
+ # recall = tp / tp.iloc[-1]
999
+ # if len(recall) < 2 or recall.iloc[-1] == 0:
1000
+ # continue
1001
+ # aucs[idx] = metrics.auc(recall, prec)
1002
+ # terms["auc_score"] = aucs
1003
+ # terms.drop(columns=["list", "set", "hash"], inplace=True)
1004
+ # dsave(terms, "pra_percomplex", dataset_name)
1005
+ # log.done(f"Per-complex PRA completed.")
1006
+ # return terms
1007
+
1008
+
1009
+
1010
+
1011
+
1012
+
1013
+
1014
+
1015
+ # Version without the greedy assignment step
1016
+ # def complex_contributions(name):
1017
+ # log.info(f"Computing complex contributions for dataset: {name}")
1018
+
1019
+ # pra = dload("pra", name)
1020
+ # terms = dload("tmp", "terms")
1021
+ # d = pra.query('prediction == 1').drop(columns=['gene1', 'gene2'])
1022
+ # results = {}
1023
+ # thresholds = [round(i, 2) for i in np.arange(1, 0.0001, -0.025)]
1024
+ # for cid in terms.ID.to_list():
1025
+ # arr = []
1026
+ # for threshold in thresholds:
1027
+ # r = d[d.complex_id == cid].query('precision >= @threshold')
1028
+ # arr.append(r.shape[0])
1029
+ # results[cid] = arr
1030
+
1031
+ # r = pd.DataFrame(results, index=thresholds).T
1032
+ # t = terms[['ID', 'Name']].set_index('ID')
1033
+ # r['Name'] = r.index.map(t.Name)
1034
+ # r = r[list(reversed(list(r.columns)))]
1035
+ # r = r.reset_index(drop=True)
1036
+ # dsave(r, "complex_contributions", name)
1037
+ # log.info(f"Complex contributions computation completed for dataset: {name}")
1038
+ # return r
1039
+
1040
+
1041
+
1042
+
1043
+
1044
+
1045
+ # # new
1046
+ # def complex_contributions(name):
1047
+ # log.info(f"Computing complex contributions using R-style greedy logic for dataset: {name}")
1048
+ # pra = dload("pra", name)
1049
+ # terms = dload("common", "terms")
1050
+
1051
+ # # Ensure pra is sorted by score descending
1052
+ # pra = pra.sort_values(by='score', ascending=False).reset_index(drop=True)
1053
+
1054
+ # # Compute cumulative TP and precision if not present
1055
+ # pra['cumTP'] = pra['prediction'].cumsum()
1056
+ # pra['rank'] = pra.index + 1
1057
+ # pra['precision'] = pra['cumTP'] / pra['rank']
1058
+
1059
+ # # R-style precision thresholds
1060
+ # prec_min = pra['precision'].min()
1061
+ # prec_max = pra['precision'].max()
1062
+ # precision_cutoffs = [round(prec_min, 3)]
1063
+ # cutoffs_range = np.arange(0.1, prec_max + 0.001, 0.025)
1064
+ # precision_cutoffs += [round(t, 3) for t in cutoffs_range if t > prec_min]
1065
+ # thresholds = sorted(set(precision_cutoffs)) # Ensure unique and sorted
1066
+
1067
+ # results = {}
1068
+ # for t in thresholds:
1069
+ # if pra['precision'].max() < t:
1070
+ # continue
1071
+ # cand = pra[pra['precision'] >= t]
1072
+ # if cand.empty:
1073
+ # continue
1074
+ # k = cand.index.max() # rightmost index where precision >= t
1075
+ # tp_target = pra.loc[k, 'cumTP']
1076
+ # # Find the smallest m where cumTP[m] >= tp_target
1077
+ # ind = pra[pra['cumTP'] >= tp_target].index.min()
1078
+ # if pd.isna(ind):
1079
+ # continue
1080
+ # # Select top (ind+1) rows
1081
+ # tmp = pra.iloc[0:ind + 1].copy()
1082
+ # # Filter for predicted positives (true == 1)
1083
+ # tmp = tmp[tmp['prediction'] == 1]
1084
+ # tmp = tmp[tmp["complex_id"].notnull()]
1085
+ # tmp["ID"] = tmp["complex_id"].apply(lambda ids: ";".join(str(int(i)) for i in ids if pd.notnull(i)))
1086
+ # # Now greedy logic
1087
+ # final_contrib = {}
1088
+ # while not tmp.empty:
1089
+ # all_ids = tmp["ID"].str.split(";").explode()
1090
+ # contrib = all_ids.value_counts()
1091
+ # if contrib.empty:
1092
+ # break
1093
+ # top_id = contrib.idxmax()
1094
+ # final_contrib[top_id] = contrib[top_id]
1095
+ # tmp = tmp[~tmp["ID"].str.contains(rf"\b{top_id}\b", regex=True)]
1096
+ # for cid, count in final_contrib.items():
1097
+ # if cid not in results:
1098
+ # results[cid] = [0] * len(thresholds)
1099
+ # results[cid][thresholds.index(t)] = count
1100
+
1101
+ # # Add back gold standard complexes with 0 contribution
1102
+ # gold_ids = set(terms.index.astype(str))
1103
+ # all_ids = set(results.keys())
1104
+ # missing_ids = gold_ids - all_ids
1105
+ # for cid in missing_ids:
1106
+ # results[cid] = [0] * len(thresholds)
1107
+
1108
+ # # Build result DataFrame
1109
+ # r = pd.DataFrame(results, index=thresholds).T
1110
+ # r['Name'] = r.index.astype(int).map(terms['Name'])
1111
+ # r = r[['Name'] + [c for c in r.columns if c != 'Name']] # Name as first col
1112
+ # r = r[(r.drop(columns="Name").sum(axis=1) > 0)]
1113
+ # # Move ID to first column, keep Name second, then precision columns in order
1114
+ # dsave(r, "complex_contributions", name)
1115
+ # log.info(f"Greedy R-style complex contribution completed for dataset: {name}")
1116
+ # return r
1117
+
1118
+
1119
+
1120
+ # def pra(dataset_name, matrix, is_corr=False):
1121
+ # log.info(f"******************** {dataset_name} ********************")
1122
+ # log.started(f"** Global Precision-Recall Analysis - {dataset_name} **")
1123
+ # config = dload("config")
1124
+
1125
+ # terms_data = dload("tmp", "terms")
1126
+ # if terms_data is None or not isinstance(terms_data, pd.DataFrame):
1127
+ # raise ValueError("Expected 'terms' to be a DataFrame, but got None or invalid type.")
1128
+ # terms = terms_data
1129
+ # genes_present = dload("tmp", "genes_present_in_terms")
1130
+ # sorting = dload("input", "sorting")
1131
+ # sort_order = sorting.get(dataset_name, "high")
1132
+
1133
+ # if not is_corr:
1134
+ # matrix = perform_corr(matrix, config.get("corr_function"))
1135
+
1136
+ # matrix = filter_matrix_by_genes(matrix, genes_present)
1137
+
1138
+ # log.info(f"Matrix shape: {matrix.shape}")
1139
+ # df = binary(matrix)
1140
+ # log.info(f"Pair-wise shape: {df.shape}")
1141
+ # df = quick_sort(df, ascending=(sort_order == "low"))
1142
+
1143
+ # gold_pair_to_complex = defaultdict(list)
1144
+ # for idx, row in terms.iterrows():
1145
+ # genes = row.used_genes
1146
+ # if len(genes) < 2:
1147
+ # continue
1148
+ # for i, g1 in enumerate(genes):
1149
+ # for g2 in genes[i + 1:]:
1150
+ # pair = tuple(sorted((g1, g2)))
1151
+ # gold_pair_to_complex[pair].append(idx)
1152
+
1153
+
1154
+ # # Label predictions and complex IDs
1155
+ # complex_ids = []
1156
+ # predictions = []
1157
+ # for g1, g2 in zip(df["gene1"], df["gene2"]):
1158
+ # pair = tuple(sorted((g1, g2)))
1159
+ # ids = gold_pair_to_complex.get(pair, [])
1160
+ # if ids:
1161
+ # predictions.append(1)
1162
+ # complex_ids.append(ids)
1163
+ # else:
1164
+ # predictions.append(0)
1165
+ # complex_ids.append([])
1166
+
1167
+ # df["prediction"] = predictions
1168
+ # df["complex_id"] = complex_ids
1169
+
1170
+ # if df["prediction"].sum() == 0:
1171
+ # log.info("No true positives found in dataset.")
1172
+ # pr_auc = np.nan
1173
+ # else:
1174
+ # tp = df["prediction"].cumsum()
1175
+ # df["tp"] = tp
1176
+ # precision = tp / (np.arange(len(df)) + 1)
1177
+ # recall = tp / tp.iloc[-1]
1178
+ # pr_auc = metrics.auc(recall, precision)
1179
+ # df["precision"] = precision
1180
+ # df["recall"] = recall
1181
+
1182
+ # log.info(f"PR-AUC: {pr_auc:.4f}, Number of true positives: {df['prediction'].sum()}")
1183
+ # dsave(df, "pra", dataset_name)
1184
+ # dsave(pr_auc, "pr_auc", dataset_name)
1185
+ # log.done(f"Global PRA completed for {dataset_name}")
1186
+ # return df, pr_auc
1187
+
1188
+
1189
+
1190
+ # def compute_pra(df):
1191
+ # log.info("Calculating precision-recall and AUC score.")
1192
+ # if df.empty:
1193
+ # log.warning("Empty DataFrame encountered in compute_pra. Returning empty DataFrame.")
1194
+ # return df
1195
+ # df["tp"] = df["prediction"].cumsum()
1196
+ # df.reset_index(drop=True, inplace=True)
1197
+ # df["precision"] = df["tp"] / (df.index + 1)
1198
+ # df["recall"] = df["tp"] / df["tp"].iloc[-1]
1199
+ # log.info("DONE: Calculating precision-recall AUC score.")
1200
+ # return df
1201
+
1202
+
1203
+ # def pra(dataset_name, matrix, is_corr=False):
1204
+ # log.info(f"PRA computation started for {dataset_name}.")
1205
+ # genes_present_in_terms = dload("tmp", "genes_present_in_terms")
1206
+ # #terms_hash_table = dload("tmp", "terms_hash_table")
1207
+ # sorting_prefs = dload("input", "sorting")
1208
+ # sort_order = sorting_prefs.get(dataset_name, "high")
1209
+ # if not is_corr: matrix = perform_corr(matrix, "numpy")
1210
+ # matrix = filter_matrix_by_genes(matrix, genes_present_in_terms)
1211
+ # stack = binary(matrix)
1212
+
1213
+ # log.info("Checking gene pairs against the gold standard.")
1214
+ # gene_pairs = list(zip(stack["gene1"], stack["gene2"]))
1215
+ # hashed_pairs = [hash(pair) for pair in gene_pairs]
1216
+ # stack["complex_id"] = [terms_hash_table.get(h, 0) for h in hashed_pairs]
1217
+ # stack["prediction"] = [1 if h in terms_hash_table else 0 for h in hashed_pairs]
1218
+
1219
+ # annotated = stack.copy()
1220
+ # if sort_order == "low":
1221
+ # ann_sorted = quick_sort(annotated, ascending=True)
1222
+ # else:
1223
+ # ann_sorted = quick_sort(annotated)
1224
+
1225
+ # pra = compute_pra(ann_sorted)
1226
+ # pr_auc = metrics.auc(pra.recall, pra.precision)
1227
+ # dsave(pra, "pra", dataset_name)
1228
+ # dsave(pr_auc, "pr_auc", dataset_name)
1229
+ # log.info(f"PRA computation completed for {dataset_name} (Sorting: {sort_order}).")
1230
+ # return pra, pr_auc
1231
+
1232
+
1233
+
1234
+
1235
+ # Newer version, but not yet separated into helper functions (building the gold standard, etc.)
1236
+
1237
+ # def pra(dataset_name, matrix, is_corr=False):
1238
+ # log.info(f"******************** {dataset_name} ********************")
1239
+ # log.started(f"** Global Precision-Recall Analysis - {dataset_name} **")
1240
+ # terms_data = dload("tmp", "terms")
1241
+ # if terms_data is None or not isinstance(terms_data, pd.DataFrame):
1242
+ # raise ValueError("Expected 'terms' to be a DataFrame, but got None or invalid type.")
1243
+ # terms = terms_data.reset_index(drop=True)
1244
+ # genes_present = dload("tmp", "genes_present_in_terms")
1245
+ # sorting = dload("input", "sorting")
1246
+ # sort_order = sorting.get(dataset_name, "high")
1247
+
1248
+ # if not is_corr:
1249
+ # matrix = perform_corr(matrix, "numpy")
1250
+
1251
+ # matrix = filter_matrix_by_genes(matrix, genes_present)
1252
+
1253
+ # log.info(f"Matrix shape: {matrix.shape}")
1254
+ # df = binary(matrix)
1255
+ # log.info(f"Pair-wise shape: {df.shape}")
1256
+ # df = quick_sort(df, ascending=(sort_order == "low"))
1257
+ # df = df.reset_index(drop=True)
1258
+
1259
+ # # Build gold standard: map pair → complex ID
1260
+ # gold_pair_to_complex = {}
1261
+ # for idx, row in terms.iterrows():
1262
+ # genes = row.used_genes
1263
+ # if len(genes) < 2:
1264
+ # continue
1265
+ # for i, g1 in enumerate(genes):
1266
+ # for g2 in genes[i + 1:]:
1267
+ # pair = tuple(sorted((g1, g2)))
1268
+ # gold_pair_to_complex[pair] = idx
1269
+
1270
+ # # Label predictions and complex IDs
1271
+ # complex_ids = []
1272
+ # predictions = []
1273
+ # for g1, g2 in zip(df["gene1"], df["gene2"]):
1274
+ # pair = tuple(sorted((g1, g2)))
1275
+ # if pair in gold_pair_to_complex:
1276
+ # predictions.append(1)
1277
+ # complex_ids.append(gold_pair_to_complex[pair])
1278
+ # else:
1279
+ # predictions.append(0)
1280
+ # complex_ids.append(0)
1281
+ # df["prediction"] = predictions
1282
+ # df["complex_id"] = complex_ids
1283
+ # if df["prediction"].sum() == 0:
1284
+ # log.info("No true positives found in dataset.")
1285
+ # pr_auc = np.nan
1286
+ # else:
1287
+ # tp = df["prediction"].cumsum()
1288
+ # df["tp"] = tp
1289
+ # precision = tp / (np.arange(len(df)) + 1)
1290
+ # recall = tp / tp.iloc[-1]
1291
+ # pr_auc = metrics.auc(recall, precision)
1292
+ # df["precision"] = precision
1293
+ # df["recall"] = recall
1294
+ # log.info(f"PR-AUC: {pr_auc:.4f}, Number of true positives: {df['prediction'].sum()}")
1295
+ # dsave(df, "pra", dataset_name)
1296
+ # dsave(pr_auc, "pr_auc", dataset_name)
1297
+ # log.done(f"Global PRA completed for {dataset_name}")
1298
+ # return df, pr_auc
1299
+