pythonflex 0.1.2__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pythonflex/__init__.py +18 -0
- pythonflex/analysis.py +1299 -0
- pythonflex/data/dataset/liver_cell_lines_500_genes.csv +501 -0
- pythonflex/data/dataset/melanoma_cell_lines_500_genes.csv +501 -0
- pythonflex/data/dataset/neuroblastoma_cell_lines_500_genes.csv +501 -0
- pythonflex/data/gold_standard/CORUM.parquet +0 -0
- pythonflex/data/gold_standard/GOBP.parquet +0 -0
- pythonflex/data/gold_standard/PATHWAY.parquet +0 -0
- pythonflex/data/gold_standard/corum.csv +2917 -0
- pythonflex/data/gold_standard/gobp.csv +4829 -0
- pythonflex/data/gold_standard/pathway.csv +1330 -0
- pythonflex/examples/basic_usage.py +108 -0
- pythonflex/examples/dataset_filtering.py +29 -0
- pythonflex/logging_config.py +56 -0
- pythonflex/plotting.py +510 -0
- pythonflex/preprocessing.py +221 -0
- pythonflex/utils.py +100 -0
- {pythonflex-0.1.2.dist-info → pythonflex-0.1.3.dist-info}/METADATA +1 -1
- pythonflex-0.1.3.dist-info/RECORD +21 -0
- pythonflex-0.1.2.dist-info/RECORD +0 -4
- {pythonflex-0.1.2.dist-info → pythonflex-0.1.3.dist-info}/WHEEL +0 -0
- {pythonflex-0.1.2.dist-info → pythonflex-0.1.3.dist-info}/entry_points.txt +0 -0
pythonflex/analysis.py
ADDED
|
@@ -0,0 +1,1299 @@
|
|
|
1
|
+
# Standard library imports
|
|
2
|
+
import gc
|
|
3
|
+
import os
|
|
4
|
+
import re
|
|
5
|
+
import shutil
|
|
6
|
+
import time
|
|
7
|
+
from collections import defaultdict, OrderedDict
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
|
|
10
|
+
# Third-party imports
|
|
11
|
+
from art import tprint
|
|
12
|
+
from bitarray import bitarray
|
|
13
|
+
from joblib import Parallel, delayed, dump, load
|
|
14
|
+
import matplotlib.pyplot as plt
|
|
15
|
+
from numba import njit, prange
|
|
16
|
+
import numpy as np
|
|
17
|
+
import pandas as pd
|
|
18
|
+
from sklearn import metrics
|
|
19
|
+
from tqdm import tqdm
|
|
20
|
+
|
|
21
|
+
# Local/application-specific imports
|
|
22
|
+
from .logging_config import log
|
|
23
|
+
from .preprocessing import filter_matrix_by_genes
|
|
24
|
+
from .utils import dsave, dload, _sanitize
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def deep_update(source, overrides):
    """Merge ``overrides`` into ``source`` in place and return ``source``.

    When a key holds a dict on both sides the two dicts are merged
    recursively; in every other case the value from ``overrides`` replaces
    (or creates) the entry in ``source``.
    """
    for name, incoming in overrides.items():
        merge_nested = (
            name in source
            and isinstance(incoming, dict)
            and isinstance(source[name], dict)
        )
        if not merge_nested:
            source[name] = incoming
            continue
        deep_update(source[name], incoming)
    return source
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def initialize(config=None):
    """Prepare a fresh run: merge the configuration, reset ``.tmp`` and set up output.

    Args:
        config: Optional dict of overrides merged recursively into the built-in
            defaults. ``None`` (the default) means "use the defaults only".

    Side effects visible in this function: sets the visible log levels, deletes
    an existing ``.tmp`` folder in the current working directory, stores the
    merged config through ``dsave``, updates matplotlib rcParams, creates the
    output folder and prints the banner.
    """
    default_config = {
        "min_genes_in_complex": 3,
        "min_genes_per_complex_analysis": 3,
        "output_folder": "output",
        "gold_standard": "CORUM",
        "color_map": "RdYlBu",
        "jaccard": True,
        "plotting": {
            "save_plot": True,
            "show_plot": True,
            "output_type": "png",
        },
        "preprocessing": {
            "normalize": False,
            "fill_na": False,
            "drop_na": False,
        },
        "corr_function": "numpy",
        "logging": {  # Added: Default logging config
            "visible_levels": ["DONE"]  # if needed #, "PROGRESS", "STARTED", "INFO"
        }
    }

    # Early merge to get user-overridden config (including logging.visible_levels).
    # ``default_config`` is rebuilt on every call, so merging into it never leaks
    # state between calls.
    if config is not None:
        config = deep_update(default_config, config)
    else:
        config = default_config

    # Extract visible_levels from the merged config and set logging visibility
    # immediately (before any logs).
    visible_levels = config.get("logging", {}).get("visible_levels", ["DONE"])
    log.set_visible_levels(visible_levels)

    log.info("******************************************************************")
    log.info("🧬 pyFLEX: Systematic CRISPR screen benchmarking framework")
    log.info("******************************************************************")
    log.started("Initialization")

    # Check and remove .tmp folder if it exists (clean slate to avoid overriding old results)
    tmp_folder = ".tmp"
    if os.path.exists(tmp_folder):
        log.info(f"Removing existing '{tmp_folder}' folder for a clean start.")
        shutil.rmtree(tmp_folder)
        log.done(f"'{tmp_folder}' folder removed successfully.")

    log.progress("Saving configuration settings.")

    dsave(config, "config")
    update_matploblib_config(config)
    output_folder = config.get("output_folder", "output")
    os.makedirs(output_folder, exist_ok=True)
    log.progress(f"Output folder '{output_folder}' ensured to exist.")
    log.done("Initialization completed. ")
    tprint("pyFLEX", font="standard")
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def update_matploblib_config(config=None):
    """Apply the package-wide matplotlib rcParams.

    Args:
        config: Optional configuration dict. Only the ``"color_map"`` key is
            read here; when ``config`` is omitted or lacks that key the
            package default ``"RdYlBu"`` is used instead of raising KeyError.
    """
    # Same fallback value as the default configuration built in ``initialize``.
    color_map = (config or {}).get("color_map", "RdYlBu")

    log.progress("Updating matplotlib settings.")
    plt.rcParams.update({
        "font.family": "DejaVu Sans",  # ← change if you prefer Arial, etc.
        "mathtext.fontset": "dejavusans",
        'font.size': 7,  # General font size
        'axes.titlesize': 10,  # Title size
        'axes.labelsize': 7,  # Axis labels (xlabel/ylabel)
        'legend.fontsize': 7,  # Legend text
        'xtick.labelsize': 6,  # X-axis tick labels
        'ytick.labelsize': 6,  # Y-axis tick labels
        'lines.linewidth': 1.5,  # Line width for plots
        'figure.dpi': 300,  # Figure resolution
        'figure.figsize': (8, 6),  # Default figure size
        'grid.linestyle': '--',  # Grid line style
        'grid.linewidth': 0.5,  # Grid line width
        'grid.alpha': 0.2,  # Grid transparency
        'axes.spines.right': False,  # Hide right spine
        'axes.spines.top': False,  # Hide top spine
        'image.cmap': color_map,  # Default colormap
        'axes.edgecolor': 'black',  # Axis edge color
        'axes.facecolor': 'none',  # Transparent axes background
        'text.usetex': False  # Ensure LaTeX is off
    })
    log.done("Matplotlib settings updated.")
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def pra(dataset_name, matrix, is_corr=False):
    """Run the global precision-recall analysis for one dataset.

    The matrix is correlated (unless ``is_corr`` is true), restricted to the
    genes present in the gold-standard terms, flattened to gene pairs and
    ranked by score. Pairs that co-occur in at least one gold-standard term
    count as positives. The annotated pair table and the PR-AUC are stored via
    ``dsave`` and also returned as ``(df, pr_auc)``.

    Raises:
        ValueError: if the stored ``terms`` object is not a DataFrame.
    """
    log.info(f"******************** {dataset_name} ********************")
    log.started(f"** Global Precision-Recall Analysis - {dataset_name} **")
    config = dload("config")

    terms = dload("common", "terms")
    # ``None`` is not a DataFrame either, so one isinstance check covers both cases.
    if not isinstance(terms, pd.DataFrame):
        raise ValueError("Expected 'terms' to be a DataFrame, but got None or invalid type.")
    genes_present = dload("common", "genes_present_in_terms")
    sort_order = dload("input", "sorting").get(dataset_name, "high")

    if not is_corr:
        matrix = perform_corr(matrix, config.get("corr_function"))

    matrix = filter_matrix_by_genes(matrix, genes_present)

    log.info(f"Matrix shape: {matrix.shape}")
    df = binary(matrix)
    log.info(f"Pair-wise shape: {df.shape}")
    # "low" means smaller scores rank first; anything else ranks high scores first.
    df = quick_sort(df, ascending=(sort_order == "low"))

    log.started("Building gene-to-pair indices")
    gold_pair_to_complex = _build_gold_pair_to_complex(terms)
    log.done("Gene-to-pair indices built.")

    log.started("Precomputing complex IDs")
    df = _precompute_complex_ids(df, gold_pair_to_complex)
    log.done("Complex IDs precomputed.")

    def _split_ids(joined):
        # "3;17" -> [3, 17]; an empty string means "no shared complex".
        return list(map(int, joined.split(";"))) if joined else []

    # A non-empty id string marks the pair as a true positive.
    df["prediction"] = df["complex_ids"].astype(bool).astype(int)
    df["complex_id"] = df["complex_ids"].apply(_split_ids)

    if df["prediction"].sum() == 0:
        log.info("No true positives found in dataset.")
        pr_auc = np.nan
    else:
        cumulative_tp = df["prediction"].cumsum()
        df["tp"] = cumulative_tp
        ranks = np.arange(len(df)) + 1
        precision = cumulative_tp / ranks
        recall = cumulative_tp / cumulative_tp.iloc[-1]
        pr_auc = metrics.auc(recall, precision)
        df["precision"] = precision
        df["recall"] = recall

    log.info(f"PR-AUC: {pr_auc:.4f}, Number of true positives: {df['prediction'].sum()}")
    dsave(df, "pra", dataset_name)
    dsave(pr_auc, "pr_auc", dataset_name)
    log.done(f"Global PRA completed for {dataset_name}")
    return df, pr_auc
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
# --------------------------------------------------------------------------
|
|
189
|
+
# helper functions for PRA per-complex analysis
|
|
190
|
+
# --------------------------------------------------------------------------
|
|
191
|
+
|
|
192
|
+
def _build_gene_to_pair_indices(pairwise_df):
|
|
193
|
+
indices = pairwise_df.index.values
|
|
194
|
+
genes = pd.concat([pairwise_df['gene1'], pairwise_df['gene2']], ignore_index=True)
|
|
195
|
+
stacked_indices = np.concatenate([indices, indices])
|
|
196
|
+
idx_series = pd.Series(stacked_indices, index=range(len(genes)))
|
|
197
|
+
gene_to_pair_indices = defaultdict(list)
|
|
198
|
+
for gene, group in idx_series.groupby(genes):
|
|
199
|
+
gene_to_pair_indices[gene] = group.values.tolist()
|
|
200
|
+
return gene_to_pair_indices
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def _build_gold_pair_to_complex(terms):
|
|
204
|
+
pair_map = defaultdict(set)
|
|
205
|
+
for comp_id, genes in zip(terms.index, terms['used_genes']):
|
|
206
|
+
genes = list(genes)
|
|
207
|
+
if len(genes) < 2: continue
|
|
208
|
+
for i in range(len(genes)):
|
|
209
|
+
for j in range(i+1, len(genes)):
|
|
210
|
+
g1, g2 = sorted([genes[i], genes[j]])
|
|
211
|
+
pair_map[(g1, g2)].add(comp_id)
|
|
212
|
+
return pair_map
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
def _precompute_complex_ids(pairwise_df, gold_pair_to_complex):
|
|
216
|
+
if not gold_pair_to_complex:
|
|
217
|
+
pairwise_df['complex_ids'] = ''
|
|
218
|
+
return pairwise_df
|
|
219
|
+
|
|
220
|
+
# Precompute pairs as tuples
|
|
221
|
+
g1 = pairwise_df['gene1']
|
|
222
|
+
g2 = pairwise_df['gene2']
|
|
223
|
+
pairs = [tuple(sorted((a, b))) for a, b in zip(g1, g2)]
|
|
224
|
+
pairwise_df['complex_ids'] = [
|
|
225
|
+
';'.join(map(str, sorted(gold_pair_to_complex[p])))
|
|
226
|
+
if p in gold_pair_to_complex else ''
|
|
227
|
+
for p in pairs
|
|
228
|
+
]
|
|
229
|
+
return pairwise_df
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def _dump_pairwise_memmap(df: pd.DataFrame, tag: str) -> Path:
    """Write ``df`` uncompressed with joblib under ``.tmp/mmap`` and return the path.

    The file name is derived from ``tag`` after passing it through
    ``_sanitize``; the directory is created when missing.
    """
    target_dir = Path(".tmp") / "mmap"  # keep scratch files grouped in .tmp/mmap/
    target_dir.mkdir(parents=True, exist_ok=True)
    file_name = f".pairwise_{_sanitize(tag)}.pkl"
    path = target_dir / file_name
    dump(df, path, compress=0)
    return path
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def _init_worker(memmap_path, gene_to_pair_indices):
    """Populate the module globals ``PAIRWISE_DF`` and ``GENE2IDX``.

    ``PAIRWISE_DF`` is loaded from ``memmap_path`` with joblib ``load``;
    ``GENE2IDX`` is bound to ``gene_to_pair_indices``. ``_process_chunk``
    reads both globals.
    """
    global PAIRWISE_DF
    global GENE2IDX
    shared_frame = load(memmap_path)
    PAIRWISE_DF = shared_frame
    GENE2IDX = gene_to_pair_indices
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def delete_memmap(memmap_path, log, wait_seconds=0.1):
    """Best-effort removal of a temporary file.

    Runs a garbage-collection pass and sleeps ``wait_seconds`` before calling
    ``os.remove``. Success is reported through ``log.info``; an ``OSError`` is
    reported through ``log.warning`` and not re-raised.

    Args:
        memmap_path: Path of the file to delete.
        log: Object providing ``info`` and ``warning`` methods.
        wait_seconds: Pause between the GC pass and the removal attempt.
    """
    gc.collect()
    time.sleep(wait_seconds)

    try:
        os.remove(memmap_path)
        log.info(f"Cleaned up temporary memmap file: {memmap_path}")
    except OSError as err:
        log.warning(f"Original error: {err}")
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
# --------------------------------------------------------------------------
|
|
264
|
+
# Process each chunk of terms
|
|
265
|
+
# --------------------------------------------------------------------------
|
|
266
|
+
def _process_chunk(chunk_terms, min_genes):
    """Compute a per-term PR-AUC for every term in ``chunk_terms``.

    Reads the module globals ``PAIRWISE_DF`` and ``GENE2IDX`` (set by
    ``_init_worker``). For each term with at least ``min_genes`` genes, the
    pair rows touching any of its genes are selected. Rows annotated only
    with other terms are dropped; rows with no annotation stay as negatives.
    A precision-recall curve is then built over the remaining rows in their
    existing order.

    Args:
        chunk_terms: DataFrame slice of terms; each row needs ``used_genes``.
        min_genes: Minimum number of distinct genes for a term to be scored.

    Returns:
        dict mapping the term index label to its PR-AUC. Terms with no
        candidate rows, no positives, or fewer than two usable rows are
        omitted.
    """
    pairwise_df = PAIRWISE_DF
    gene_to_pair_indices = GENE2IDX
    local_auc_scores = {}
    n_pairs = len(pairwise_df)

    for idx, row in chunk_terms.iterrows():
        gene_set = set(row.used_genes)
        if len(gene_set) < min_genes:
            continue

        # Explicitly zeroed mask: an integer-sized bitarray is not guaranteed
        # to start cleared on older bitarray releases, which could select
        # unrelated rows.
        selected = np.zeros(n_pairs, dtype=bool)
        for g in gene_set:
            if g in gene_to_pair_indices:
                selected[gene_to_pair_indices[g]] = True
        if not selected.any():
            continue

        sub_df = pairwise_df.iloc[selected]

        # Match the id as a whole ';'-separated token so "1" does not match "11".
        complex_id = str(idx)
        pattern = r'(?:^|;)' + re.escape(complex_id) + r'(?:;|$)'
        true_label = sub_df["complex_ids"].str.contains(pattern, regex=True).astype(int)
        # Keep unannotated pairs (negatives) and pairs of this term (positives).
        mask = (sub_df["complex_ids"] == "") | (true_label == 1)
        preds = true_label[mask]

        if preds.sum() == 0:
            continue

        tp_cum = preds.cumsum()
        precision = tp_cum / (np.arange(len(preds)) + 1)
        recall = tp_cum / tp_cum.iloc[-1]
        # metrics.auc needs at least two points.
        if len(recall) >= 2 and recall.iloc[-1] != 0:
            local_auc_scores[idx] = metrics.auc(recall, precision)

    return local_auc_scores
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def pra_percomplex(dataset_name, matrix, is_corr=False, chunk_size=200, n_jobs=8):
    """Compute a PR-AUC for every gold-standard term of one dataset.

    The ranked pair table is built the same way as in ``pra``, written to a
    temporary joblib file, and the terms are processed in chunks by joblib
    workers running ``_process_chunk``. The scores are attached to the terms
    table as ``auc_score``, the ``hash`` column is dropped, and the result is
    stored via ``dsave`` and returned.

    Args:
        dataset_name: Key used for the sort-order lookup and for saving.
        matrix: Input data matrix, or a correlation matrix if ``is_corr``.
        is_corr: Skip the correlation step when true.
        chunk_size: Number of terms handed to a worker per task.
        n_jobs: Number of joblib workers (previously fixed at 8).

    Raises:
        Exception: any error from the parallel step is logged and re-raised;
            the temporary file is removed in either case.
    """
    log.started(f"*** Per-complex PRA started - {dataset_name} ***")
    config = dload("config")
    terms = dload("common", "terms")
    genes_present = dload("common", "genes_present_in_terms")
    sorting = dload("input", "sorting")
    # Same fallback as ``pra``; only the value "low" changes the ordering.
    sort_order = sorting.get(dataset_name, "high")
    if not is_corr:
        matrix = perform_corr(matrix, config.get("corr_function"))
    matrix = filter_matrix_by_genes(matrix, genes_present)
    log.info(f"Matrix shape: {matrix.shape}")
    df = binary(matrix)
    log.info(f"Pair-wise shape: {df.shape}")
    df = quick_sort(df, ascending=(sort_order == "low"))
    pairwise_df = df.copy()
    pairwise_df['gene1'] = pairwise_df['gene1'].astype("category")
    pairwise_df['gene2'] = pairwise_df['gene2'].astype("category")

    # Use helper functions for precomputations
    log.started("Building gene-to-pair indices")
    gene_to_pair_indices = _build_gene_to_pair_indices(pairwise_df)
    log.done("Building gene-to-pair indices")

    log.started("Building gold pair to complex mapping")
    gold_pair_to_complex = _build_gold_pair_to_complex(terms)  # Now serial
    log.done("Building gold pair to complex mapping")

    log.started("Precomputing complex IDs")
    pairwise_df = _precompute_complex_ids(pairwise_df, gold_pair_to_complex)
    log.done("Precomputing complex IDs")

    log.info('Dumping pairwise_df to memmap')
    memmap_path = _dump_pairwise_memmap(pairwise_df, dataset_name)
    log.done('Dumping pairwise_df to memmap')

    # choose smaller chunks now that pickling cost is gone
    chunks = [terms.iloc[i:i + chunk_size] for i in range(0, len(terms), chunk_size)]
    min_genes = config["min_genes_per_complex_analysis"]

    # Initialize results variable
    results = None

    try:
        # Simplified parallel execution without progress callback interference
        log.started("Processing chunks in parallel")
        with tqdm(total=len(chunks), desc="Per-complex PRA") as pbar:
            results = Parallel(
                n_jobs=n_jobs,
                temp_folder=os.path.dirname(memmap_path),
                max_nbytes=None,
                mmap_mode="r",
                initializer=_init_worker,
                initargs=(memmap_path, gene_to_pair_indices),
                verbose=0  # Reduce joblib verbosity
            )(delayed(_process_chunk)(chunk, min_genes) for chunk in chunks)

            # Update progress bar once all tasks are complete
            pbar.update(len(chunks))

        log.done("Processing chunks in parallel")

    except Exception as e:
        # Cleanup is handled once, in ``finally``; just report and propagate.
        log.error(f"Error during parallel processing: {e}")
        raise

    finally:
        # Ensure cleanup happens regardless of success or failure
        try:
            if os.path.exists(memmap_path):
                os.remove(memmap_path)
                log.info(f"Cleaned up temporary memmap file: {memmap_path}")
        except OSError as e:
            log.warning(f"Failed to remove memmap file {memmap_path}: {e}")

    # Merge results with error handling
    auc_scores = {}
    if results:
        for res in results:
            if isinstance(res, dict):
                auc_scores.update(res)
            elif isinstance(res, tuple) and res[0] is None:
                log.error(res[1])  # Log the error message from the chunk
            else:
                log.error(f"Ignoring unexpected chunk result: {res}")

    # Add the computed AUC scores to the terms DataFrame.
    terms["auc_score"] = pd.Series(auc_scores)
    terms.drop(columns=["hash"], inplace=True)
    dsave(terms, "pra_percomplex", dataset_name)
    log.done("Per-complex PRA completed.")
    return terms
|
|
403
|
+
|
|
404
|
+
|
|
405
|
+
|
|
406
|
+
|
|
407
|
+
|
|
408
|
+
def complex_contributions(name):
    """Attribute true-positive pairs to complexes at a series of precision cut-offs.

    Loads the pair table stored under ``("pra", name)`` and the terms table
    stored under ``("common", "terms")``. For each precision threshold the
    positives ranked above the cut-off are assigned to complexes greedily: the
    complex covering the most still-uncovered pairs is picked, its pairs are
    marked covered, and the step repeats until nothing is left.

    Args:
        name: Dataset name used for loading the pair table and saving the result.

    Returns:
        DataFrame with a ``Name`` column followed by one ``Precision_<t>``
        column per threshold, holding the number of pairs credited to each
        complex. The same frame is stored via ``dsave``.

    NOTE(review): the table is always re-sorted by descending score here, even
    though ``pra`` may have ranked ascending for "low" datasets - confirm this
    is intended.
    """
    log.info(f"Computing complex contributions (Greedy) for dataset: {name}")
    pra = dload("pra", name)
    terms = dload("common", "terms")

    # Ensure pra is sorted by score descending (matches R's order by predicted descending)
    pra = pra.sort_values(by='score', ascending=False).reset_index(drop=True)

    # Compute cumulative TP and precision (matches R's TP.count = cumsum(true), Precision = TP / (1:n))
    pra['cumTP'] = pra['prediction'].cumsum()
    pra['rank'] = pra.index + 1
    pra['precision'] = pra['cumTP'] / pra['rank']

    # R-style precision thresholds (matches c( min, seq(0.1, max, 0.025) ) rounded)
    prec_min = pra['precision'].min()
    prec_max = pra['precision'].max()
    precision_cutoffs = [round(prec_min, 3)]
    cutoffs_range = np.arange(0.1, prec_max + 0.001, 0.025)
    precision_cutoffs += [round(t, 3) for t in cutoffs_range if t > prec_min]
    thresholds = sorted(set(precision_cutoffs))  # Ensure unique and sorted

    # Precompute positives for faster access
    pos_mask = pra['prediction'] == 1
    positives = pra[pos_mask].reset_index(drop=True)

    # Compute global unique ordered IDs for initial tie-breaking (appearance order from all positives)
    global_row_to_cids = []
    for ids in positives['complex_id']:
        # Accept either a ';'-joined string or an iterable of ids; normalise to str.
        if isinstance(ids, str):
            cleaned = [str(int(float(i.strip()))) for i in ids.split(';') if i.strip()]
        else:
            cleaned = [str(int(i)) for i in ids if pd.notnull(i)]
        global_row_to_cids.append(cleaned)
    all_global_ids = [cid for cids in global_row_to_cids for cid in cids]
    # dict.fromkeys keeps the first occurrence, giving appearance order without duplicates.
    global_unique_ordered = list(OrderedDict.fromkeys(all_global_ids))

    results = {}
    valid_thresholds = []  # Track valid like R's ind.valid.precision

    # Progress bar for the main loop (thresholds)
    with tqdm(total=len(thresholds), desc="Processing thresholds", unit="thresh") as pbar:
        for thresh_idx, t in enumerate(thresholds):
            # Check if valid (matches R's ind.valid.precision)
            # Rounding to 3 decimals above can push a threshold outside [min, max].
            if not (prec_min <= t <= prec_max):
                pbar.update(1)
                continue
            valid_thresholds.append(thresh_idx)  # Track for later sorting

            # Find rightmost k where precision >= t (matches R's cand.ind[length(cand.ind)])
            cand_mask = pra['precision'] >= t
            if not cand_mask.any():
                pbar.update(1)
                continue
            k = pra.index[cand_mask].max()
            tp_target = pra.loc[k, 'cumTP']
            if tp_target <= 0:
                pbar.update(1)
                continue

            # Find first ind where cumTP == tp_target (matches R's tmp.ind[1])
            matching_inds = pra[pra['cumTP'] == tp_target].index
            if matching_inds.empty:
                pbar.update(1)
                continue
            ind = matching_inds.min()  # First (smallest) like R

            # Get top (ind+1) rows, filter to prediction==1 and non-null complex_id
            tmp = pra.iloc[0:ind + 1]
            tmp = tmp[(tmp['prediction'] == 1) & tmp['complex_id'].notnull()].reset_index(drop=True)
            if tmp.empty:
                pbar.update(1)
                continue

            # Build row_to_cids as list of lists (str for consistency, matches R strsplit)
            row_to_cids = []
            for ids in tmp['complex_id']:
                if isinstance(ids, str):
                    cleaned = [str(int(float(i.strip()))) for i in ids.split(';') if i.strip()]
                else:
                    cleaned = [str(int(i)) for i in ids if pd.notnull(i)]
                row_to_cids.append(cleaned)

            # Inverted index: complex id -> positions of the rows that mention it.
            N = len(row_to_cids)
            cid_to_rows = defaultdict(list)
            for row_idx in range(N):
                for cid in row_to_cids[row_idx]:
                    cid_to_rows[cid].append(row_idx)

            # current_size[cid] = number of not-yet-covered rows mentioning cid.
            current_size = {cid: len(lst) for cid, lst in cid_to_rows.items()}
            covered = np.zeros(N, dtype=bool)
            remaining_rows = list(range(N))  # Track remaining for tie-breaking
            final_contrib = {}
            is_first = True  # Flag for initial greedy step

            while current_size:
                # Redundant with the loop condition; kept as in the original.
                if not current_size:
                    break
                max_contrib = max(current_size.values())
                candidates = [cid for cid, cnt in current_size.items() if cnt == max_contrib]

                if len(candidates) == 1:
                    top_cid = candidates[0]
                else:
                    if is_first:
                        # Initial tie-break: first in global appearance order (matches R's global matrix row order)
                        positions = {cid: global_unique_ordered.index(cid) for cid in candidates if cid in global_unique_ordered}
                        top_cid = min(positions, key=positions.get)
                    else:
                        # Subsequent: first in local remaining appearance order
                        all_ids = [cid for ri in remaining_rows for cid in row_to_cids[ri]]
                        unique_ordered = list(OrderedDict.fromkeys(all_ids))
                        positions = {cid: unique_ordered.index(cid) for cid in candidates if cid in unique_ordered}
                        top_cid = min(positions, key=positions.get)  # Earliest appearance

                contrib = current_size[top_cid]
                if contrib <= 0:
                    current_size.pop(top_cid, None)
                    continue

                # Cover the remaining rows for top_cid
                for row_idx in cid_to_rows[top_cid]:
                    if not covered[row_idx]:
                        covered[row_idx] = True
                        # Every complex on this row loses one uncovered row; drop it at zero.
                        for cid in row_to_cids[row_idx]:
                            if cid in current_size:
                                current_size[cid] -= 1
                                if current_size[cid] <= 0:
                                    current_size.pop(cid, None)

                # Update remaining_rows (remove covered)
                remaining_rows = [ri for ri in remaining_rows if not covered[ri]]

                final_contrib[top_cid] = contrib
                is_first = False  # Only first time is special

            # Store for this threshold
            for cid, count in final_contrib.items():
                if cid not in results:
                    results[cid] = [0] * len(thresholds)
                results[cid][thresh_idx] = count

            pbar.update(1)  # Update progress after processing threshold

    # Build result DataFrame (index=cid as str)
    r = pd.DataFrame(results, index=thresholds).T
    r.index = r.index.astype(str)

    # Filter to non-zero first (matches R's nonzero.cont.ind)
    r = r[r.sum(axis=1) > 0]

    # Intersect with terms IDs, preserving terms order
    gold_ids = set(r.index)
    common_ids = [str(id) for id in terms.index if str(id) in gold_ids]
    r = r.loc[common_ids]

    # Map Names and insert as first column
    t = pd.Series(terms['Name'].values, index=terms.index.astype(str))
    r.insert(0, 'Name', r.index.map(t))

    # Set all column names: Name + Precision_*
    precision_cols = [f"Precision_{t}" for t in thresholds]
    r.columns = ['Name'] + precision_cols

    # Sort by the last valid precision column descending, stable sort (matches R's stable order)
    if valid_thresholds:
        last_valid_col = f"Precision_{thresholds[valid_thresholds[-1]]}"
        r = r.sort_values(by=last_valid_col, ascending=False, kind='stable')

    # De-duplicate by Name, keeping first (matches R's !duplicated(Name) after function)
    r = r[~r['Name'].duplicated(keep='first')].reset_index(drop=True)

    dsave(r, "complex_contributions", name)
    log.info(f"Complex contribution (Greedy) completed for dataset: {name}")
    return r
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
|
|
585
|
+
|
|
586
|
+
# --------------------------------------------------------------------------
|
|
587
|
+
# Helpers
|
|
588
|
+
# --------------------------------------------------------------------------
|
|
589
|
+
|
|
590
|
+
def perform_corr(df, corr_func):
    """Correlate the rows of ``df`` with each other and blank the diagonal.

    Args:
        df: DataFrame whose rows are the items to correlate.
        corr_func: Backend name - ``"numpy"`` (masked corrcoef that ignores
            invalid values), ``"pandas"`` (``df.T.corr()``) or ``"numba"``
            (``fast_corr``).

    Returns:
        Square DataFrame of correlations with NaN on the diagonal.

    Raises:
        ValueError: if ``corr_func`` is not one of the supported backends.
    """
    if corr_func not in {"numpy", "pandas", "numba"}:
        # The message lists every accepted backend, including 'numba'.
        raise ValueError("corr_func must be 'numpy', 'pandas' or 'numba'")

    log.started(f"Performing correlation using '{corr_func}' method.")

    if corr_func == "numpy":
        M = np.ma.masked_invalid(df.values)
        corr = np.ma.corrcoef(M)
        arr = corr.filled(np.nan)
        df_corr = pd.DataFrame(arr, index=df.index, columns=df.index)
        np.fill_diagonal(df_corr.values, np.nan)
        log.done("Correlation.")
        return df_corr

    elif corr_func == "numba":
        corr = fast_corr(df)
        np.fill_diagonal(corr.values, np.nan)
        log.done("Correlation using Numba.")
        return corr

    else:
        # Compute correlations and modify diagonal in-place
        corr = df.T.corr()
        np.fill_diagonal(corr.values, np.nan)
        # Report completion here too, matching the other two backends.
        log.done("Correlation using pandas.")
        return corr
|
|
616
|
+
|
|
617
|
+
|
|
618
|
+
|
|
619
|
+
@njit(parallel=True)
def _pairwise_corr_kernel(data):
    """Pairwise-complete Pearson correlation between the columns of ``data``.

    Defined at module level so the compiled kernel is reused across calls
    instead of being rebuilt every time ``fast_corr`` runs.
    """
    m, n = data.shape
    corr = np.full((n, n), np.nan, dtype=np.float64)
    # Compute off-diagonal (upper triangle, parallel over i)
    for i in prange(n):
        for j in range(i + 1, n):
            sum_x = 0.0
            sum_y = 0.0
            sum_xx = 0.0
            sum_yy = 0.0
            sum_xy = 0.0
            count = 0
            for k in range(m):
                x = data[k, i]
                y = data[k, j]
                # Only observations present in both columns contribute.
                if not np.isnan(x) and not np.isnan(y):
                    sum_x += x
                    sum_y += y
                    sum_xx += x * x
                    sum_yy += y * y
                    sum_xy += x * y
                    count += 1
            if count >= 2:
                # Sample variance/covariance (div by count-1)
                var_x = (sum_xx - (sum_x ** 2) / count) / (count - 1)
                var_y = (sum_yy - (sum_y ** 2) / count) / (count - 1)
                cov = (sum_xy - (sum_x * sum_y) / count) / (count - 1)
                denom = np.sqrt(var_x * var_y)
                if denom > 0:  # Avoid div-by-zero (e.g., constant cols -> nan)
                    r = cov / denom
                else:
                    r = np.nan
            else:
                r = np.nan
            corr[i, j] = r
            corr[j, i] = r  # Symmetric
    # Compute diagonal in parallel
    for i in prange(n):
        sum_x = 0.0
        sum_xx = 0.0
        count = 0
        for k in range(m):
            x = data[k, i]
            if not np.isnan(x):
                sum_x += x
                sum_xx += x * x
                count += 1
        if count >= 2:
            var_x = (sum_xx - (sum_x ** 2) / count) / (count - 1)
            if var_x > 0:
                corr[i, i] = 1.0
            else:
                corr[i, i] = np.nan  # Constant column
        else:
            corr[i, i] = np.nan
    return corr


def fast_corr(df):
    """Numba-accelerated row-by-row correlation of ``df``.

    Only numeric columns are used. The values are transposed so that each
    row of ``df`` becomes one column for the kernel, and NaNs are skipped
    pairwise.

    Returns:
        Square DataFrame indexed and labelled by ``df``'s index. The diagonal
        is 1.0 for rows with positive variance and NaN otherwise.
    """
    df_numeric = df.select_dtypes(include=np.number)
    data = df_numeric.to_numpy().T
    corr_matrix = _pairwise_corr_kernel(data)
    corr_df = pd.DataFrame(corr_matrix, index=df_numeric.index, columns=df_numeric.index)
    return corr_df
|
|
683
|
+
|
|
684
|
+
|
|
685
|
+
|
|
686
|
+
def is_symmetric(df):
    """Return True when ``df`` is a square matrix equal to its transpose.

    NaNs in matching positions count as equal. Anything that is not a square
    2-D matrix is reported as not symmetric instead of raising a broadcasting
    error (or silently broadcasting, as a 1 x n input would).
    """
    values = np.asarray(df)
    if values.ndim != 2 or values.shape[0] != values.shape[1]:
        return False
    return np.allclose(values, values.T, equal_nan=True)
|
|
688
|
+
|
|
689
|
+
|
|
690
|
+
def binary(corr):
    """Flatten a gene-by-gene matrix into a ``gene1`` / ``gene2`` / ``score`` table.

    A symmetric matrix is first reduced to its upper triangle so each pair
    appears once. If the first pair still has a mirrored twin after stacking,
    mirrored duplicates are dropped.
    """
    log.started("Converting correlation matrix to pair-wise format.")
    if is_symmetric(corr):
        corr = convert_full_to_half_matrix(corr)

    long_form = corr.stack()
    long_form = long_form.rename_axis(index=['gene1', 'gene2'])
    stack = long_form.reset_index().rename(columns={0: 'score'})

    if has_mirror_of_first_pair(stack):
        log.info("Mirror pairs detected. Dropping them to ensure unique gene pairs.")
        stack = drop_mirror_pairs(stack)
    log.done("Pair-wise conversion.")
    return stack
|
|
702
|
+
|
|
703
|
+
|
|
704
|
+
def has_mirror_of_first_pair(df):
    """Tell whether the first ``(gene1, gene2)`` row reappears later as ``(gene2, gene1)``.

    Only the first row is probed, as a cheap check for whether the table
    lists both orientations of a pair. An empty table has no pairs and
    therefore returns False.
    """
    if len(df) == 0:
        return False
    g1 = df['gene1'].iloc[0]
    g2 = df['gene2'].iloc[0]
    mirrored = (df['gene1'] == g2) & (df['gene2'] == g1)
    # Skip row 0 so a self-pair (g1 == g2) does not match itself.
    return bool(mirrored.iloc[1:].any())
|
|
708
|
+
|
|
709
|
+
|
|
710
|
+
def convert_full_to_half_matrix(df):
    """Return a copy of a symmetric matrix with only the strict upper triangle kept.

    The lower triangle and the diagonal are set to NaN, so a later ``stack``
    yields each pair once.

    Raises:
        ValueError: if ``df`` is not symmetric.
    """
    if not is_symmetric(df):
        raise ValueError("Matrix must be symmetric to convert to half matrix.")

    log.started("Converting full correlation matrix to upper triangle (half-matrix) format.")
    arr = df.values.copy()
    # Integer/bool arrays cannot hold NaN; promote them so the masking below works.
    if arr.dtype.kind in "biu":
        arr = arr.astype(float)
    arr[np.tril_indices_from(arr)] = np.nan  # zero-based lower triangle + diagonal → NaN
    log.done("Matrix conversion.")
    return pd.DataFrame(arr, index=df.index, columns=df.columns)
|
|
719
|
+
|
|
720
|
+
|
|
721
|
+
def drop_mirror_pairs(df):
    """Remove rows that repeat an unordered gene pair.

    Within every row the two gene names are sorted, so ``(A, B)`` and
    ``(B, A)`` become identical; only the first occurrence of each pair is
    kept. The returned frame holds the sorted gene names.

    Args:
        df: pandas DataFrame with ``gene1`` and ``gene2`` columns.

    Returns:
        pandas.DataFrame: A new frame with one row per unordered pair. The
        caller's DataFrame is left untouched.
    """
    log.started("Dropping mirror pairs to ensure unique gene pairs (Optimized).")
    gene_pairs = np.sort(df[["gene1", "gene2"]].to_numpy(), axis=1)
    # assign() returns a new frame, so the normalised names are not written
    # back into the caller's DataFrame.
    normalised = df.assign(gene1=gene_pairs[:, 0], gene2=gene_pairs[:, 1])
    is_repeat = normalised.duplicated(subset=["gene1", "gene2"], keep="first")
    result = normalised.loc[~is_repeat]
    log.done("Mirror pairs are dropped.")
    return result
|
|
728
|
+
|
|
729
|
+
|
|
730
|
+
def quick_sort(df, ascending=False):
    """Order a pair-wise table by its ``score`` column.

    Args:
        df: pandas DataFrame containing a numeric ``score`` column.
        ascending: Sort from lowest to highest score when truthy; the default
            sorts from highest to lowest.

    Returns:
        pandas.DataFrame: The reordered rows with a fresh 0..n-1 index.
    """
    log.started(f"Pair-wise matrix is sorting based on the 'score' column: ascending:{ascending}")
    # Negating the scores turns numpy's ascending argsort into a descending one.
    direction = -1 if not ascending else 1
    sort_keys = direction * df["score"].values
    row_positions = np.argsort(sort_keys)
    reordered = df.iloc[row_positions]
    sorted_df = reordered.reset_index(drop=True)
    log.done("Pair-wise matrix sorting.")
    return sorted_df
|
|
736
|
+
|
|
737
|
+
|
|
738
|
+
|
|
739
|
+
|
|
740
|
+
def save_results_to_csv(categories=("complex_contributions", "pr_auc", "pra_percomplex")):
    """Export stored analysis results to ``<output_folder>/csv``.

    Each category is loaded with ``dload`` and written according to its type:

    * ``pr_auc`` stored as a dict is written as one tab-delimited
      ``pr_auc.txt`` with a ``Dataset`` index and an ``AUC`` column.
    * Any other dict is treated as a mapping to DataFrames; every DataFrame
      value is written to ``<category>_<key>.csv`` and other values are
      skipped with a warning.
    * A single DataFrame is written to ``<category>.csv``.

    Categories with no stored data or an unsupported type are skipped with a
    warning.

    Args:
        categories: Iterable of category names to export.
    """
    config = dload("config")  # Load config to get output folder
    if config is None:
        # No stored configuration: fall back to the default output folder
        # rather than failing on ``None.get``.
        config = {}
    output_folder = Path(config.get("output_folder", "output")) / "csv"
    output_folder.mkdir(parents=True, exist_ok=True)  # Ensure output folder exists

    for category in categories:
        data = dload(category)  # Load the data for this category
        if data is None:
            log.warning(f"No data found for category '{category}'. Skipping save.")
            continue

        if category == "pr_auc" and isinstance(data, dict):
            # Special handling: dict keys become the 'Dataset' index and the
            # values fill a single 'AUC' column.
            try:
                df = pd.DataFrame.from_dict(data, orient='index', columns=['AUC'])
                df.index.name = 'Dataset'
                txt_path = output_folder / f"{category}.txt"
                df.to_csv(txt_path, sep='\t', index=True)  # Save as tab-delimited TXT
                log.info(f"Saved '{category}' as tabular TXT to {txt_path}")
            except Exception as e:
                # Best effort: a failed export must not stop the other categories.
                log.warning(f"Failed to convert and save '{category}' as TXT: {e}")
            continue  # Skip to next category after handling pr_auc

        if isinstance(data, dict):
            # A dict of datasets: save each DataFrame as a separate CSV
            for key, df in data.items():
                if isinstance(df, pd.DataFrame):
                    csv_path = output_folder / f"{category}_{key}.csv"
                    df.to_csv(csv_path, index=False)
                    log.info(f"Saved '{category}_{key}' to {csv_path}")
                else:
                    log.warning(f"Skipping non-DataFrame item '{key}' in '{category}'.")
        elif isinstance(data, pd.DataFrame):
            # A single DataFrame: save it directly
            csv_path = output_folder / f"{category}.csv"
            data.to_csv(csv_path, index=False)
            log.info(f"Saved '{category}' to {csv_path}")
        else:
            log.warning(f"Unsupported data type for '{category}'. Expected DataFrame or dict of DataFrames. Skipping.")

    log.done("Results saved to CSV files in the output folder.")
|
|
784
|
+
|
|
785
|
+
|
|
786
|
+
|
|
787
|
+
|
|
788
|
+
|
|
789
|
+
|
|
790
|
+
### OLD FUNCTIONS
|
|
791
|
+
|
|
792
|
+
|
|
793
|
+
# new, but without parallelization
|
|
794
|
+
|
|
795
|
+
# def pra_percomplex(dataset_name, matrix, is_corr=False):
|
|
796
|
+
# log.started(f"*** Per-complex PRA started - {dataset_name} ***")
|
|
797
|
+
# config = dload("config")
|
|
798
|
+
# terms = dload("tmp", "terms")
|
|
799
|
+
# genes_present = dload("tmp", "genes_present_in_terms")
|
|
800
|
+
# sorting = dload("input", "sorting")
|
|
801
|
+
# sort_order = sorting.get(dataset_name, "high")
|
|
802
|
+
# if not is_corr:
|
|
803
|
+
# matrix = perform_corr(matrix, config.get("corr_function"))
|
|
804
|
+
# matrix = filter_matrix_by_genes(matrix, genes_present)
|
|
805
|
+
# log.info(f"Matrix shape: {matrix.shape}")
|
|
806
|
+
# df = binary(matrix)
|
|
807
|
+
# log.info(f"Pair-wise shape: {df.shape}")
|
|
808
|
+
# df = quick_sort(df, ascending=(sort_order == "low"))
|
|
809
|
+
# pairwise_df = df.copy()
|
|
810
|
+
# pairwise_df['gene1'] = pairwise_df['gene1'].astype("category")
|
|
811
|
+
# pairwise_df['gene2'] = pairwise_df['gene2'].astype("category")
|
|
812
|
+
|
|
813
|
+
# # Precompute a mapping from each gene to the row indices in the pairwise DataFrame where it appears.
|
|
814
|
+
# gene_to_pair_indices = {}
|
|
815
|
+
# for i, (gene_a, gene_b) in enumerate(zip(pairwise_df["gene1"], pairwise_df["gene2"])):
|
|
816
|
+
# gene_to_pair_indices.setdefault(gene_a, []).append(i)
|
|
817
|
+
# gene_to_pair_indices.setdefault(gene_b, []).append(i)
|
|
818
|
+
# log.done
|
|
819
|
+
|
|
820
|
+
# # Build gold_pair_to_complex using sets for efficiency
|
|
821
|
+
# gold_pair_to_complex = defaultdict(set)
|
|
822
|
+
# for idx, row in terms.iterrows():
|
|
823
|
+
# genes = row.used_genes
|
|
824
|
+
# if len(genes) < 2:
|
|
825
|
+
# continue
|
|
826
|
+
# for i, g1 in enumerate(genes):
|
|
827
|
+
# for g2 in genes[i + 1:]:
|
|
828
|
+
# pair = tuple(sorted((g1, g2)))
|
|
829
|
+
# gold_pair_to_complex[pair].add(idx)
|
|
830
|
+
|
|
831
|
+
# # Precompute complex_ids as semicolon-separated strings in pairwise_df
|
|
832
|
+
# pairs = [tuple(sorted((g1, g2))) for g1, g2 in zip(pairwise_df["gene1"], pairwise_df["gene2"])]
|
|
833
|
+
# pairwise_df['complex_ids'] = [';'.join(map(str, sorted(gold_pair_to_complex.get(pair, set())))) for pair in pairs]
|
|
834
|
+
|
|
835
|
+
# # Initialize AUC scores
|
|
836
|
+
# auc_scores = {}
|
|
837
|
+
# # Loop over each gene complex
|
|
838
|
+
# for idx, row in tqdm(terms.iterrows()):
|
|
839
|
+
# gene_set = set(row.used_genes)
|
|
840
|
+
# if config["min_genes_per_complex_analysis"] > len(gene_set):
|
|
841
|
+
# continue
|
|
842
|
+
# # Collect all row indices in the pairwise data where either gene belongs to the complex.
|
|
843
|
+
# candidate_indices = bitarray(len(pairwise_df))
|
|
844
|
+
# for gene in gene_set:
|
|
845
|
+
# if gene in gene_to_pair_indices:
|
|
846
|
+
# candidate_indices[gene_to_pair_indices[gene]] = True
|
|
847
|
+
|
|
848
|
+
# if not candidate_indices.any():
|
|
849
|
+
# continue
|
|
850
|
+
|
|
851
|
+
# # Select only the relevant pairwise comparisons.
|
|
852
|
+
# selected_rows = np.unpackbits(candidate_indices).view(bool)[:len(pairwise_df)]
|
|
853
|
+
# sub_df = pairwise_df.iloc[selected_rows]
|
|
854
|
+
|
|
855
|
+
# # Get current complex ID (assuming idx is the ID; adjust if row['ID'] is different)
|
|
856
|
+
# complex_id = str(idx) # Or str(row['ID']) if available
|
|
857
|
+
|
|
858
|
+
# # Create true_label: 1 if complex_id in complex_ids (vectorized with str.contains)
|
|
859
|
+
# #true_label = sub_df['complex_ids'].str.contains(complex_id, regex=False).astype(int)
|
|
860
|
+
|
|
861
|
+
# # Inside the loop, for each complex:
|
|
862
|
+
# # Inside the loop:
|
|
863
|
+
# complex_id = str(idx)
|
|
864
|
+
# # Use (?:^|;) and (?:;|$) to avoid capturing groups
|
|
865
|
+
# pattern = r'(?:^|;)' + re.escape(complex_id) + r'(?:;|$)'
|
|
866
|
+
# true_label = sub_df['complex_ids'].str.contains(pattern, regex=True).astype(int)
|
|
867
|
+
# # Filter to keep verified negatives (complex_ids == "") or positives for this complex (true_label == 1)
|
|
868
|
+
# complex_mask = (sub_df['complex_ids'] == "") | (true_label == 1)
|
|
869
|
+
|
|
870
|
+
# # Use the masked true labels for AUPRC (avoids SettingWithCopyWarning)
|
|
871
|
+
# predictions = true_label[complex_mask]
|
|
872
|
+
|
|
873
|
+
# if predictions.sum() == 0:
|
|
874
|
+
# continue
|
|
875
|
+
# # Compute cumulative true positives and derive precision and recall.
|
|
876
|
+
# true_positive_cumsum = predictions.cumsum()
|
|
877
|
+
# precision = true_positive_cumsum / (np.arange(len(predictions)) + 1)
|
|
878
|
+
# recall = true_positive_cumsum / true_positive_cumsum.iloc[-1]
|
|
879
|
+
|
|
880
|
+
# if len(recall) < 2 or recall.iloc[-1] == 0:
|
|
881
|
+
# continue
|
|
882
|
+
# auc_scores[idx] = metrics.auc(recall, precision)
|
|
883
|
+
|
|
884
|
+
# # Add the computed AUC scores to the terms DataFrame.
|
|
885
|
+
# terms["auc_score"] = pd.Series(auc_scores)
|
|
886
|
+
# terms.drop(columns=["hash"], inplace=True)
|
|
887
|
+
# dsave(terms, "pra_percomplex", dataset_name)
|
|
888
|
+
# log.done(f"Per-complex PRA completed.")
|
|
889
|
+
# return terms
|
|
890
|
+
|
|
891
|
+
|
|
892
|
+
|
|
893
|
+
# it works quickly, but maps only one complex to each pair
|
|
894
|
+
|
|
895
|
+
# def pra_percomplex_old_type_filtering(dataset_name, matrix, is_corr=False):
|
|
896
|
+
# log.started(f"*** Per-complex PRA started - {dataset_name} ***")
|
|
897
|
+
# config = dload("config")
|
|
898
|
+
# terms = dload("tmp", "terms")
|
|
899
|
+
# genes_present = dload("tmp", "genes_present_in_terms")
|
|
900
|
+
# sorting = dload("input", "sorting")
|
|
901
|
+
# sort_order = sorting.get(dataset_name, "high")
|
|
902
|
+
# if not is_corr:
|
|
903
|
+
# matrix = perform_corr(matrix, config.get("corr_function"))
|
|
904
|
+
# matrix = filter_matrix_by_genes(matrix, genes_present)
|
|
905
|
+
# log.info(f"Matrix shape: {matrix.shape}")
|
|
906
|
+
# df = binary(matrix)
|
|
907
|
+
# log.info(f"Pair-wise shape: {df.shape}")
|
|
908
|
+
# df = quick_sort(df, ascending=(sort_order == "low"))
|
|
909
|
+
# pairwise_df = df.copy()
|
|
910
|
+
# pairwise_df['gene1'] = pairwise_df['gene1'].astype("category")
|
|
911
|
+
# pairwise_df['gene2'] = pairwise_df['gene2'].astype("category")
|
|
912
|
+
# # Precompute a mapping from each gene to the row indices in the pairwise DataFrame where it appears.
|
|
913
|
+
# gene_to_pair_indices = {}
|
|
914
|
+
# for i, (gene_a, gene_b) in enumerate(zip(pairwise_df["gene1"], pairwise_df["gene2"])):
|
|
915
|
+
# gene_to_pair_indices.setdefault(gene_a, []).append(i)
|
|
916
|
+
# gene_to_pair_indices.setdefault(gene_b, []).append(i)
|
|
917
|
+
# # Initialize AUC scores (one for each complex) with NaNs.
|
|
918
|
+
# #auc_scores = np.full(len(terms), np.nan)
|
|
919
|
+
# auc_scores = {}
|
|
920
|
+
# # Loop over each gene complex
|
|
921
|
+
# for idx, row in tqdm(terms.iterrows()):
|
|
922
|
+
# gene_set = set(row.used_genes)
|
|
923
|
+
|
|
924
|
+
# if config["min_genes_per_complex_analysis"] > len(gene_set):
|
|
925
|
+
# continue
|
|
926
|
+
# # Collect all row indices in the pairwise data where either gene belongs to the complex.
|
|
927
|
+
# candidate_indices = bitarray(len(pairwise_df))
|
|
928
|
+
# for gene in gene_set:
|
|
929
|
+
# if gene in gene_to_pair_indices:
|
|
930
|
+
# candidate_indices[gene_to_pair_indices[gene]] = True
|
|
931
|
+
# if not candidate_indices.any():
|
|
932
|
+
# continue
|
|
933
|
+
# # Select only the relevant pairwise comparisons.
|
|
934
|
+
# selected_rows = np.unpackbits(candidate_indices).view(bool)[:len(pairwise_df)]
|
|
935
|
+
# sub_df = pairwise_df.iloc[selected_rows]
|
|
936
|
+
# # A prediction is 1 if both genes in the pair are in the complex; otherwise 0.
|
|
937
|
+
# predictions = (sub_df["gene1"].isin(gene_set) & sub_df["gene2"].isin(gene_set)).astype(int)
|
|
938
|
+
# if predictions.sum() == 0:
|
|
939
|
+
# continue
|
|
940
|
+
# # Compute cumulative true positives and derive precision and recall.
|
|
941
|
+
# true_positive_cumsum = predictions.cumsum()
|
|
942
|
+
# precision = true_positive_cumsum / (np.arange(len(predictions)) + 1)
|
|
943
|
+
# recall = true_positive_cumsum / true_positive_cumsum.iloc[-1]
|
|
944
|
+
|
|
945
|
+
# if len(recall) < 2 or recall.iloc[-1] == 0:
|
|
946
|
+
# continue
|
|
947
|
+
# auc_scores[idx] = metrics.auc(recall, precision)
|
|
948
|
+
# # Add the computed AUC scores to the terms DataFrame.
|
|
949
|
+
# terms["auc_score"] = pd.Series(auc_scores)
|
|
950
|
+
# terms.drop(columns=["hash"], inplace=True)
|
|
951
|
+
# dsave(terms, "pra_percomplex", dataset_name)
|
|
952
|
+
# log.done(f"Per-complex PRA completed.")
|
|
953
|
+
# return terms
|
|
954
|
+
|
|
955
|
+
|
|
956
|
+
|
|
957
|
+
# OLD
|
|
958
|
+
# def pra_percomplex(dataset_name, matrix, is_corr=False):
|
|
959
|
+
# log.started(f"*** Per-complex PRA started for {dataset_name} ***")
|
|
960
|
+
# config = dload("config")
|
|
961
|
+
# terms = dload("tmp", "terms")
|
|
962
|
+
# genes_present = dload("tmp", "genes_present_in_terms")
|
|
963
|
+
# sorting = dload("input", "sorting")
|
|
964
|
+
# sort_order = sorting.get(dataset_name, "high")
|
|
965
|
+
|
|
966
|
+
# if not is_corr:
|
|
967
|
+
# matrix = perform_corr(matrix, "numpy")
|
|
968
|
+
# matrix = filter_matrix_by_genes(matrix, genes_present)
|
|
969
|
+
# log.info(f"Matrix shape: {matrix.shape}")
|
|
970
|
+
# df = binary(matrix)
|
|
971
|
+
# log.info(f"Pair-wise shape: {df.shape}")
|
|
972
|
+
# df = quick_sort(df, ascending=(sort_order == "low"))
|
|
973
|
+
# # Precompute gene → row indices
|
|
974
|
+
# gene_to_rows = {}
|
|
975
|
+
# for i, (g1, g2) in enumerate(zip(df["gene1"], df["gene2"])):
|
|
976
|
+
# gene_to_rows.setdefault(g1, []).append(i)
|
|
977
|
+
# gene_to_rows.setdefault(g2, []).append(i)
|
|
978
|
+
# aucs = np.full(len(terms), np.nan)
|
|
979
|
+
# N = len(df)
|
|
980
|
+
# for idx, row in tqdm(terms.iterrows()):
|
|
981
|
+
# genes = set(row.used_genes)
|
|
982
|
+
# if len(genes) < config["min_complex_size_for_percomplex"]: # Skip small complexes
|
|
983
|
+
# continue
|
|
984
|
+
# # Get all row indices where either gene is in the complex
|
|
985
|
+
# candidate_idxs = set()
|
|
986
|
+
# for g in genes:
|
|
987
|
+
# candidate_idxs.update(gene_to_rows.get(g, []))
|
|
988
|
+
# candidate_idxs = sorted(candidate_idxs)
|
|
989
|
+
# if not candidate_idxs:
|
|
990
|
+
# continue
|
|
991
|
+
# # Use only relevant rows for prediction
|
|
992
|
+
# sub = df.loc[candidate_idxs]
|
|
993
|
+
# preds = (sub["gene1"].isin(genes) & sub["gene2"].isin(genes)).astype(int)
|
|
994
|
+
# if preds.sum() == 0:
|
|
995
|
+
# continue
|
|
996
|
+
# tp = preds.cumsum()
|
|
997
|
+
# prec = tp / (np.arange(len(preds)) + 1)
|
|
998
|
+
# recall = tp / tp.iloc[-1]
|
|
999
|
+
# if len(recall) < 2 or recall.iloc[-1] == 0:
|
|
1000
|
+
# continue
|
|
1001
|
+
# aucs[idx] = metrics.auc(recall, prec)
|
|
1002
|
+
# terms["auc_score"] = aucs
|
|
1003
|
+
# terms.drop(columns=["list", "set", "hash"], inplace=True)
|
|
1004
|
+
# dsave(terms, "pra_percomplex", dataset_name)
|
|
1005
|
+
# log.done(f"Per-complex PRA completed.")
|
|
1006
|
+
# return terms
|
|
1007
|
+
|
|
1008
|
+
|
|
1009
|
+
|
|
1010
|
+
|
|
1011
|
+
|
|
1012
|
+
|
|
1013
|
+
|
|
1014
|
+
|
|
1015
|
+
# without greedy
|
|
1016
|
+
# def complex_contributions(name):
|
|
1017
|
+
# log.info(f"Computing complex contributions for dataset: {name}")
|
|
1018
|
+
|
|
1019
|
+
# pra = dload("pra", name)
|
|
1020
|
+
# terms = dload("tmp", "terms")
|
|
1021
|
+
# d = pra.query('prediction == 1').drop(columns=['gene1', 'gene2'])
|
|
1022
|
+
# results = {}
|
|
1023
|
+
# thresholds = [round(i, 2) for i in np.arange(1, 0.0001, -0.025)]
|
|
1024
|
+
# for cid in terms.ID.to_list():
|
|
1025
|
+
# arr = []
|
|
1026
|
+
# for threshold in thresholds:
|
|
1027
|
+
# r = d[d.complex_id == cid].query('precision >= @threshold')
|
|
1028
|
+
# arr.append(r.shape[0])
|
|
1029
|
+
# results[cid] = arr
|
|
1030
|
+
|
|
1031
|
+
# r = pd.DataFrame(results, index=thresholds).T
|
|
1032
|
+
# t = terms[['ID', 'Name']].set_index('ID')
|
|
1033
|
+
# r['Name'] = r.index.map(t.Name)
|
|
1034
|
+
# r = r[list(reversed(list(r.columns)))]
|
|
1035
|
+
# r = r.reset_index(drop=True)
|
|
1036
|
+
# dsave(r, "complex_contributions", name)
|
|
1037
|
+
# log.info(f"Complex contributions computation completed for dataset: {name}")
|
|
1038
|
+
# return r
|
|
1039
|
+
|
|
1040
|
+
|
|
1041
|
+
|
|
1042
|
+
|
|
1043
|
+
|
|
1044
|
+
|
|
1045
|
+
# # new
|
|
1046
|
+
# def complex_contributions(name):
|
|
1047
|
+
# log.info(f"Computing complex contributions using R-style greedy logic for dataset: {name}")
|
|
1048
|
+
# pra = dload("pra", name)
|
|
1049
|
+
# terms = dload("common", "terms")
|
|
1050
|
+
|
|
1051
|
+
# # Ensure pra is sorted by score descending
|
|
1052
|
+
# pra = pra.sort_values(by='score', ascending=False).reset_index(drop=True)
|
|
1053
|
+
|
|
1054
|
+
# # Compute cumulative TP and precision if not present
|
|
1055
|
+
# pra['cumTP'] = pra['prediction'].cumsum()
|
|
1056
|
+
# pra['rank'] = pra.index + 1
|
|
1057
|
+
# pra['precision'] = pra['cumTP'] / pra['rank']
|
|
1058
|
+
|
|
1059
|
+
# # R-style precision thresholds
|
|
1060
|
+
# prec_min = pra['precision'].min()
|
|
1061
|
+
# prec_max = pra['precision'].max()
|
|
1062
|
+
# precision_cutoffs = [round(prec_min, 3)]
|
|
1063
|
+
# cutoffs_range = np.arange(0.1, prec_max + 0.001, 0.025)
|
|
1064
|
+
# precision_cutoffs += [round(t, 3) for t in cutoffs_range if t > prec_min]
|
|
1065
|
+
# thresholds = sorted(set(precision_cutoffs)) # Ensure unique and sorted
|
|
1066
|
+
|
|
1067
|
+
# results = {}
|
|
1068
|
+
# for t in thresholds:
|
|
1069
|
+
# if pra['precision'].max() < t:
|
|
1070
|
+
# continue
|
|
1071
|
+
# cand = pra[pra['precision'] >= t]
|
|
1072
|
+
# if cand.empty:
|
|
1073
|
+
# continue
|
|
1074
|
+
# k = cand.index.max() # rightmost index where precision >= t
|
|
1075
|
+
# tp_target = pra.loc[k, 'cumTP']
|
|
1076
|
+
# # Find the smallest m where cumTP[m] >= tp_target
|
|
1077
|
+
# ind = pra[pra['cumTP'] >= tp_target].index.min()
|
|
1078
|
+
# if pd.isna(ind):
|
|
1079
|
+
# continue
|
|
1080
|
+
# # Select top (ind+1) rows
|
|
1081
|
+
# tmp = pra.iloc[0:ind + 1].copy()
|
|
1082
|
+
# # Filter for predicted positives (true == 1)
|
|
1083
|
+
# tmp = tmp[tmp['prediction'] == 1]
|
|
1084
|
+
# tmp = tmp[tmp["complex_id"].notnull()]
|
|
1085
|
+
# tmp["ID"] = tmp["complex_id"].apply(lambda ids: ";".join(str(int(i)) for i in ids if pd.notnull(i)))
|
|
1086
|
+
# # Now greedy logic
|
|
1087
|
+
# final_contrib = {}
|
|
1088
|
+
# while not tmp.empty:
|
|
1089
|
+
# all_ids = tmp["ID"].str.split(";").explode()
|
|
1090
|
+
# contrib = all_ids.value_counts()
|
|
1091
|
+
# if contrib.empty:
|
|
1092
|
+
# break
|
|
1093
|
+
# top_id = contrib.idxmax()
|
|
1094
|
+
# final_contrib[top_id] = contrib[top_id]
|
|
1095
|
+
# tmp = tmp[~tmp["ID"].str.contains(rf"\b{top_id}\b", regex=True)]
|
|
1096
|
+
# for cid, count in final_contrib.items():
|
|
1097
|
+
# if cid not in results:
|
|
1098
|
+
# results[cid] = [0] * len(thresholds)
|
|
1099
|
+
# results[cid][thresholds.index(t)] = count
|
|
1100
|
+
|
|
1101
|
+
# # Add back gold standard complexes with 0 contribution
|
|
1102
|
+
# gold_ids = set(terms.index.astype(str))
|
|
1103
|
+
# all_ids = set(results.keys())
|
|
1104
|
+
# missing_ids = gold_ids - all_ids
|
|
1105
|
+
# for cid in missing_ids:
|
|
1106
|
+
# results[cid] = [0] * len(thresholds)
|
|
1107
|
+
|
|
1108
|
+
# # Build result DataFrame
|
|
1109
|
+
# r = pd.DataFrame(results, index=thresholds).T
|
|
1110
|
+
# r['Name'] = r.index.astype(int).map(terms['Name'])
|
|
1111
|
+
# r = r[['Name'] + [c for c in r.columns if c != 'Name']] # Name as first col
|
|
1112
|
+
# r = r[(r.drop(columns="Name").sum(axis=1) > 0)]
|
|
1113
|
+
# # Move ID to first column, keep Name second, then precision columns in order
|
|
1114
|
+
# dsave(r, "complex_contributions", name)
|
|
1115
|
+
# log.info(f"Greedy R-style complex contribution completed for dataset: {name}")
|
|
1116
|
+
# return r
|
|
1117
|
+
|
|
1118
|
+
|
|
1119
|
+
|
|
1120
|
+
# def pra(dataset_name, matrix, is_corr=False):
|
|
1121
|
+
# log.info(f"******************** {dataset_name} ********************")
|
|
1122
|
+
# log.started(f"** Global Precision-Recall Analysis - {dataset_name} **")
|
|
1123
|
+
# config = dload("config")
|
|
1124
|
+
|
|
1125
|
+
# terms_data = dload("tmp", "terms")
|
|
1126
|
+
# if terms_data is None or not isinstance(terms_data, pd.DataFrame):
|
|
1127
|
+
# raise ValueError("Expected 'terms' to be a DataFrame, but got None or invalid type.")
|
|
1128
|
+
# terms = terms_data
|
|
1129
|
+
# genes_present = dload("tmp", "genes_present_in_terms")
|
|
1130
|
+
# sorting = dload("input", "sorting")
|
|
1131
|
+
# sort_order = sorting.get(dataset_name, "high")
|
|
1132
|
+
|
|
1133
|
+
# if not is_corr:
|
|
1134
|
+
# matrix = perform_corr(matrix, config.get("corr_function"))
|
|
1135
|
+
|
|
1136
|
+
# matrix = filter_matrix_by_genes(matrix, genes_present)
|
|
1137
|
+
|
|
1138
|
+
# log.info(f"Matrix shape: {matrix.shape}")
|
|
1139
|
+
# df = binary(matrix)
|
|
1140
|
+
# log.info(f"Pair-wise shape: {df.shape}")
|
|
1141
|
+
# df = quick_sort(df, ascending=(sort_order == "low"))
|
|
1142
|
+
|
|
1143
|
+
# gold_pair_to_complex = defaultdict(list)
|
|
1144
|
+
# for idx, row in terms.iterrows():
|
|
1145
|
+
# genes = row.used_genes
|
|
1146
|
+
# if len(genes) < 2:
|
|
1147
|
+
# continue
|
|
1148
|
+
# for i, g1 in enumerate(genes):
|
|
1149
|
+
# for g2 in genes[i + 1:]:
|
|
1150
|
+
# pair = tuple(sorted((g1, g2)))
|
|
1151
|
+
# gold_pair_to_complex[pair].append(idx)
|
|
1152
|
+
|
|
1153
|
+
|
|
1154
|
+
# # Label predictions and complex IDs
|
|
1155
|
+
# complex_ids = []
|
|
1156
|
+
# predictions = []
|
|
1157
|
+
# for g1, g2 in zip(df["gene1"], df["gene2"]):
|
|
1158
|
+
# pair = tuple(sorted((g1, g2)))
|
|
1159
|
+
# ids = gold_pair_to_complex.get(pair, [])
|
|
1160
|
+
# if ids:
|
|
1161
|
+
# predictions.append(1)
|
|
1162
|
+
# complex_ids.append(ids)
|
|
1163
|
+
# else:
|
|
1164
|
+
# predictions.append(0)
|
|
1165
|
+
# complex_ids.append([])
|
|
1166
|
+
|
|
1167
|
+
# df["prediction"] = predictions
|
|
1168
|
+
# df["complex_id"] = complex_ids
|
|
1169
|
+
|
|
1170
|
+
# if df["prediction"].sum() == 0:
|
|
1171
|
+
# log.info("No true positives found in dataset.")
|
|
1172
|
+
# pr_auc = np.nan
|
|
1173
|
+
# else:
|
|
1174
|
+
# tp = df["prediction"].cumsum()
|
|
1175
|
+
# df["tp"] = tp
|
|
1176
|
+
# precision = tp / (np.arange(len(df)) + 1)
|
|
1177
|
+
# recall = tp / tp.iloc[-1]
|
|
1178
|
+
# pr_auc = metrics.auc(recall, precision)
|
|
1179
|
+
# df["precision"] = precision
|
|
1180
|
+
# df["recall"] = recall
|
|
1181
|
+
|
|
1182
|
+
# log.info(f"PR-AUC: {pr_auc:.4f}, Number of true positives: {df['prediction'].sum()}")
|
|
1183
|
+
# dsave(df, "pra", dataset_name)
|
|
1184
|
+
# dsave(pr_auc, "pr_auc", dataset_name)
|
|
1185
|
+
# log.done(f"Global PRA completed for {dataset_name}")
|
|
1186
|
+
# return df, pr_auc
|
|
1187
|
+
|
|
1188
|
+
|
|
1189
|
+
|
|
1190
|
+
# def compute_pra(df):
|
|
1191
|
+
# log.info("Calculating precision-recall and AUC score.")
|
|
1192
|
+
# if df.empty:
|
|
1193
|
+
# log.warning("Empty DataFrame encountered in compute_pra. Returning empty DataFrame.")
|
|
1194
|
+
# return df
|
|
1195
|
+
# df["tp"] = df["prediction"].cumsum()
|
|
1196
|
+
# df.reset_index(drop=True, inplace=True)
|
|
1197
|
+
# df["precision"] = df["tp"] / (df.index + 1)
|
|
1198
|
+
# df["recall"] = df["tp"] / df["tp"].iloc[-1]
|
|
1199
|
+
# log.info("DONE: Calculating precision-recall AUC score.")
|
|
1200
|
+
# return df
|
|
1201
|
+
|
|
1202
|
+
|
|
1203
|
+
# def pra(dataset_name, matrix, is_corr=False):
|
|
1204
|
+
# log.info(f"PRA computation started for {dataset_name}.")
|
|
1205
|
+
# genes_present_in_terms = dload("tmp", "genes_present_in_terms")
|
|
1206
|
+
# #terms_hash_table = dload("tmp", "terms_hash_table")
|
|
1207
|
+
# sorting_prefs = dload("input", "sorting")
|
|
1208
|
+
# sort_order = sorting_prefs.get(dataset_name, "high")
|
|
1209
|
+
# if not is_corr: matrix = perform_corr(matrix, "numpy")
|
|
1210
|
+
# matrix = filter_matrix_by_genes(matrix, genes_present_in_terms)
|
|
1211
|
+
# stack = binary(matrix)
|
|
1212
|
+
|
|
1213
|
+
# log.info("Checking gene pairs against the gold standard.")
|
|
1214
|
+
# gene_pairs = list(zip(stack["gene1"], stack["gene2"]))
|
|
1215
|
+
# hashed_pairs = [hash(pair) for pair in gene_pairs]
|
|
1216
|
+
# stack["complex_id"] = [terms_hash_table.get(h, 0) for h in hashed_pairs]
|
|
1217
|
+
# stack["prediction"] = [1 if h in terms_hash_table else 0 for h in hashed_pairs]
|
|
1218
|
+
|
|
1219
|
+
# annotated = stack.copy()
|
|
1220
|
+
# if sort_order == "low":
|
|
1221
|
+
# ann_sorted = quick_sort(annotated, ascending=True)
|
|
1222
|
+
# else:
|
|
1223
|
+
# ann_sorted = quick_sort(annotated)
|
|
1224
|
+
|
|
1225
|
+
# pra = compute_pra(ann_sorted)
|
|
1226
|
+
# pr_auc = metrics.auc(pra.recall, pra.precision)
|
|
1227
|
+
# dsave(pra, "pra", dataset_name)
|
|
1228
|
+
# dsave(pr_auc, "pr_auc", dataset_name)
|
|
1229
|
+
# log.info(f"PRA computation completed for {dataset_name} (Sorting: {sort_order}).")
|
|
1230
|
+
# return pra, pr_auc
|
|
1231
|
+
|
|
1232
|
+
|
|
1233
|
+
|
|
1234
|
+
|
|
1235
|
+
# new, but not separated into functions (building the gold standard, etc.)
|
|
1236
|
+
|
|
1237
|
+
# def pra(dataset_name, matrix, is_corr=False):
|
|
1238
|
+
# log.info(f"******************** {dataset_name} ********************")
|
|
1239
|
+
# log.started(f"** Global Precision-Recall Analysis - {dataset_name} **")
|
|
1240
|
+
# terms_data = dload("tmp", "terms")
|
|
1241
|
+
# if terms_data is None or not isinstance(terms_data, pd.DataFrame):
|
|
1242
|
+
# raise ValueError("Expected 'terms' to be a DataFrame, but got None or invalid type.")
|
|
1243
|
+
# terms = terms_data.reset_index(drop=True)
|
|
1244
|
+
# genes_present = dload("tmp", "genes_present_in_terms")
|
|
1245
|
+
# sorting = dload("input", "sorting")
|
|
1246
|
+
# sort_order = sorting.get(dataset_name, "high")
|
|
1247
|
+
|
|
1248
|
+
# if not is_corr:
|
|
1249
|
+
# matrix = perform_corr(matrix, "numpy")
|
|
1250
|
+
|
|
1251
|
+
# matrix = filter_matrix_by_genes(matrix, genes_present)
|
|
1252
|
+
|
|
1253
|
+
# log.info(f"Matrix shape: {matrix.shape}")
|
|
1254
|
+
# df = binary(matrix)
|
|
1255
|
+
# log.info(f"Pair-wise shape: {df.shape}")
|
|
1256
|
+
# df = quick_sort(df, ascending=(sort_order == "low"))
|
|
1257
|
+
# df = df.reset_index(drop=True)
|
|
1258
|
+
|
|
1259
|
+
# # Build gold standard: map pair → complex ID
|
|
1260
|
+
# gold_pair_to_complex = {}
|
|
1261
|
+
# for idx, row in terms.iterrows():
|
|
1262
|
+
# genes = row.used_genes
|
|
1263
|
+
# if len(genes) < 2:
|
|
1264
|
+
# continue
|
|
1265
|
+
# for i, g1 in enumerate(genes):
|
|
1266
|
+
# for g2 in genes[i + 1:]:
|
|
1267
|
+
# pair = tuple(sorted((g1, g2)))
|
|
1268
|
+
# gold_pair_to_complex[pair] = idx
|
|
1269
|
+
|
|
1270
|
+
# # Label predictions and complex IDs
|
|
1271
|
+
# complex_ids = []
|
|
1272
|
+
# predictions = []
|
|
1273
|
+
# for g1, g2 in zip(df["gene1"], df["gene2"]):
|
|
1274
|
+
# pair = tuple(sorted((g1, g2)))
|
|
1275
|
+
# if pair in gold_pair_to_complex:
|
|
1276
|
+
# predictions.append(1)
|
|
1277
|
+
# complex_ids.append(gold_pair_to_complex[pair])
|
|
1278
|
+
# else:
|
|
1279
|
+
# predictions.append(0)
|
|
1280
|
+
# complex_ids.append(0)
|
|
1281
|
+
# df["prediction"] = predictions
|
|
1282
|
+
# df["complex_id"] = complex_ids
|
|
1283
|
+
# if df["prediction"].sum() == 0:
|
|
1284
|
+
# log.info("No true positives found in dataset.")
|
|
1285
|
+
# pr_auc = np.nan
|
|
1286
|
+
# else:
|
|
1287
|
+
# tp = df["prediction"].cumsum()
|
|
1288
|
+
# df["tp"] = tp
|
|
1289
|
+
# precision = tp / (np.arange(len(df)) + 1)
|
|
1290
|
+
# recall = tp / tp.iloc[-1]
|
|
1291
|
+
# pr_auc = metrics.auc(recall, precision)
|
|
1292
|
+
# df["precision"] = precision
|
|
1293
|
+
# df["recall"] = recall
|
|
1294
|
+
# log.info(f"PR-AUC: {pr_auc:.4f}, Number of true positives: {df['prediction'].sum()}")
|
|
1295
|
+
# dsave(df, "pra", dataset_name)
|
|
1296
|
+
# dsave(pr_auc, "pr_auc", dataset_name)
|
|
1297
|
+
# log.done(f"Global PRA completed for {dataset_name}")
|
|
1298
|
+
# return df, pr_auc
|
|
1299
|
+
|