pythonflex 0.3.2__tar.gz → 0.3.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (27) hide show
  1. {pythonflex-0.3.2 → pythonflex-0.3.4}/PKG-INFO +6 -3
  2. {pythonflex-0.3.2 → pythonflex-0.3.4}/README.md +5 -2
  3. {pythonflex-0.3.2 → pythonflex-0.3.4}/pyproject.toml +1 -1
  4. {pythonflex-0.3.2 → pythonflex-0.3.4}/src/pythonflex/__init__.py +2 -2
  5. {pythonflex-0.3.2 → pythonflex-0.3.4}/src/pythonflex/analysis.py +83 -1
  6. {pythonflex-0.3.2 → pythonflex-0.3.4}/src/pythonflex/examples/basic_usage.py +15 -14
  7. pythonflex-0.3.4/src/pythonflex/examples/manuscript.py +112 -0
  8. {pythonflex-0.3.2 → pythonflex-0.3.4}/src/pythonflex/plotting.py +128 -7
  9. {pythonflex-0.3.2 → pythonflex-0.3.4}/src/pythonflex/preprocessing.py +188 -24
  10. {pythonflex-0.3.2 → pythonflex-0.3.4}/.gitignore +0 -0
  11. {pythonflex-0.3.2 → pythonflex-0.3.4}/.python-version +0 -0
  12. {pythonflex-0.3.2 → pythonflex-0.3.4}/src/pythonflex/data/__init__.py +0 -0
  13. {pythonflex-0.3.2 → pythonflex-0.3.4}/src/pythonflex/data/dataset/__init__.py +0 -0
  14. {pythonflex-0.3.2 → pythonflex-0.3.4}/src/pythonflex/data/dataset/liver_cell_lines_500_genes.csv +0 -0
  15. {pythonflex-0.3.2 → pythonflex-0.3.4}/src/pythonflex/data/dataset/melanoma_cell_lines_500_genes.csv +0 -0
  16. {pythonflex-0.3.2 → pythonflex-0.3.4}/src/pythonflex/data/dataset/neuroblastoma_cell_lines_500_genes.csv +0 -0
  17. {pythonflex-0.3.2 → pythonflex-0.3.4}/src/pythonflex/data/gold_standard/CORUM.parquet +0 -0
  18. {pythonflex-0.3.2 → pythonflex-0.3.4}/src/pythonflex/data/gold_standard/GOBP.parquet +0 -0
  19. {pythonflex-0.3.2 → pythonflex-0.3.4}/src/pythonflex/data/gold_standard/PATHWAY.parquet +0 -0
  20. {pythonflex-0.3.2 → pythonflex-0.3.4}/src/pythonflex/data/gold_standard/__init__.py +0 -0
  21. {pythonflex-0.3.2 → pythonflex-0.3.4}/src/pythonflex/data/gold_standard/corum.csv +0 -0
  22. {pythonflex-0.3.2 → pythonflex-0.3.4}/src/pythonflex/data/gold_standard/gobp.csv +0 -0
  23. {pythonflex-0.3.2 → pythonflex-0.3.4}/src/pythonflex/data/gold_standard/pathway.csv +0 -0
  24. {pythonflex-0.3.2 → pythonflex-0.3.4}/src/pythonflex/logging_config.py +0 -0
  25. {pythonflex-0.3.2 → pythonflex-0.3.4}/src/pythonflex/utils.py +0 -0
  26. {pythonflex-0.3.2 → pythonflex-0.3.4}/todo.txt +0 -0
  27. {pythonflex-0.3.2 → pythonflex-0.3.4}/uv.lock +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pythonflex
3
- Version: 0.3.2
3
+ Version: 0.3.4
4
4
  Summary: pythonFLEX is a benchmarking toolkit for evaluating CRISPR screen results against biological gold standards. The toolkit computes gene-level and complex-level performance metrics, helping researchers systematically assess the biological relevance and resolution of their CRISPR screening data.
5
5
  Author-email: Yasir Demirtaş <tyasird@hotmail.com>
6
6
  Classifier: License :: OSI Approved :: MIT License
@@ -114,6 +114,7 @@ default_config = {
114
114
  "gold_standard": "GOBP",
115
115
  "color_map": "RdYlBu",
116
116
  "jaccard": True,
117
+ "jaccard_threshold": 1.0, # set e.g. 0.90 to remove highly similar terms
117
118
  "plotting": {
118
119
  "save_plot": True,
119
120
  "output_type": "png",
@@ -124,7 +125,7 @@ default_config = {
124
125
  },
125
126
  "corr_function": "numpy",
126
127
  "logging": {
127
- "visible_levels": ["DONE","STARTED"] # "PROGRESS", "STARTED", ,"INFO","WARNING"
128
+ "visible_levels": ["DONE","INFO", "WARNING"] # "PROGRESS", "STARTED", ,"INFO","WARNING"
128
129
  }
129
130
  }
130
131
 
@@ -149,8 +150,10 @@ flex.plot_percomplex_scatter()
149
150
  flex.plot_percomplex_scatter_bysize()
150
151
  flex.plot_significant_complexes()
151
152
  flex.plot_complex_contributions()
153
+ flex.plot_mpr_tp_multi(show_filters="all")
154
+ flex.plot_mpr_complexes_multi(show_filters="all")
155
+ flex.plot_mpr_complexes_auc_scores("all")
152
156
 
153
- # Save Result CSVspyflex.save_results_to_csv()
154
157
  flex.save_results_to_csv()
155
158
 
156
159
 
@@ -83,6 +83,7 @@ default_config = {
83
83
  "gold_standard": "GOBP",
84
84
  "color_map": "RdYlBu",
85
85
  "jaccard": True,
86
+ "jaccard_threshold": 1.0, # set e.g. 0.90 to remove highly similar terms
86
87
  "plotting": {
87
88
  "save_plot": True,
88
89
  "output_type": "png",
@@ -93,7 +94,7 @@ default_config = {
93
94
  },
94
95
  "corr_function": "numpy",
95
96
  "logging": {
96
- "visible_levels": ["DONE","STARTED"] # "PROGRESS", "STARTED", ,"INFO","WARNING"
97
+ "visible_levels": ["DONE","INFO", "WARNING"] # "PROGRESS", "STARTED", ,"INFO","WARNING"
97
98
  }
98
99
  }
99
100
 
@@ -118,8 +119,10 @@ flex.plot_percomplex_scatter()
118
119
  flex.plot_percomplex_scatter_bysize()
119
120
  flex.plot_significant_complexes()
120
121
  flex.plot_complex_contributions()
122
+ flex.plot_mpr_tp_multi(show_filters="all")
123
+ flex.plot_mpr_complexes_multi(show_filters="all")
124
+ flex.plot_mpr_complexes_auc_scores("all")
121
125
 
122
- # Save Result CSVspyflex.save_results_to_csv()
123
126
  flex.save_results_to_csv()
124
127
 
125
128
 
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "pythonflex"
3
- version = "0.3.2"
3
+ version = "0.3.4"
4
4
  description = "pythonFLEX is a benchmarking toolkit for evaluating CRISPR screen results against biological gold standards. The toolkit computes gene-level and complex-level performance metrics, helping researchers systematically assess the biological relevance and resolution of their CRISPR screening data."
5
5
  readme = "README.md"
6
6
  authors = [
@@ -5,7 +5,7 @@ from .analysis import initialize, pra, pra_percomplex, fast_corr, perform_corr,
5
5
  from .plotting import (
6
6
  adjust_text_positions, plot_precision_recall_curve, plot_aggregated_pra, plot_iqr_pra, plot_all_runs_pra, plot_percomplex_scatter,
7
7
  plot_percomplex_scatter_bysize, plot_complex_contributions, plot_significant_complexes, plot_auc_scores,
8
- plot_mpr_tp, plot_mpr_complexes, plot_mpr_tp_multi, plot_mpr_complexes_multi
8
+ plot_mpr_tp, plot_mpr_complexes, plot_mpr_tp_multi, plot_mpr_complexes_multi, plot_mpr_complexes_auc_scores
9
9
  )
10
10
 
11
11
  __all__ = [ "log", "get_example_data_path", "fast_corr",
@@ -14,7 +14,7 @@ __all__ = [ "log", "get_example_data_path", "fast_corr",
14
14
  "perform_corr", "is_symmetric", "binary", "has_mirror_of_first_pair", "convert_full_to_half_matrix",
15
15
  "drop_mirror_pairs", "quick_sort", "complex_contributions", "adjust_text_positions", "plot_precision_recall_curve",
16
16
  "plot_aggregated_pra", "plot_iqr_pra", "plot_all_runs_pra", "plot_percomplex_scatter", "plot_percomplex_scatter_bysize", "plot_complex_contributions",
17
- "plot_significant_complexes", "plot_auc_scores", "save_results_to_csv", "update_matploblib_config",
17
+ "plot_significant_complexes", "plot_auc_scores", "plot_mpr_complexes_auc_scores", "save_results_to_csv", "update_matploblib_config",
18
18
  "mpr_prepare", "plot_mpr_tp", "plot_mpr_complexes",
19
19
  "plot_mpr_tp_multi", "plot_mpr_complexes_multi"
20
20
  ]
@@ -43,6 +43,7 @@ def initialize(config={}):
43
43
  "gold_standard": "CORUM",
44
44
  "color_map": "RdYlBu",
45
45
  "jaccard": True,
46
+ "jaccard_threshold": 1.0,
46
47
  "use_common_genes": True,
47
48
  "plotting": {
48
49
  "save_plot": True,
@@ -844,7 +845,7 @@ def quick_sort(df, ascending=False):
844
845
  log.done("Pair-wise matrix sorting.")
845
846
  return sorted_df
846
847
 
847
- def save_results_to_csv(categories = ["complex_contributions", "pr_auc", "pra_percomplex"]):
848
+ def save_results_to_csv(categories = ["complex_contributions", "pr_auc", "pra_percomplex", "mpr_complexes_auc"]):
848
849
 
849
850
  config = dload("config") # Load config to get output folder
850
851
  output_folder = Path(config.get("output_folder", "output"))
@@ -856,6 +857,18 @@ def save_results_to_csv(categories = ["complex_contributions", "pr_auc", "pra_pe
856
857
  if data is None:
857
858
  log.warning(f"No data found for category '{category}'. Skipping save.")
858
859
  continue
860
+
861
+ if category == "mpr_complexes_auc" and isinstance(data, dict):
862
+ # Dict[dataset_name -> Dict[filter_key -> auc]]
863
+ try:
864
+ df = pd.DataFrame.from_dict(data, orient="index")
865
+ df.index.name = "Dataset"
866
+ csv_path = output_folder / f"{category}.csv"
867
+ df.to_csv(csv_path, index=True)
868
+ log.info(f"Saved '{category}' to {csv_path}")
869
+ except Exception as e:
870
+ log.warning(f"Failed to convert and save '{category}': {e}")
871
+ continue
859
872
 
860
873
  if category == "pr_auc" and isinstance(data, dict):
861
874
  # Special handling: Convert dict to DataFrame (assuming keys are indices, values are data)
@@ -1312,6 +1325,64 @@ def _mpr_module_coverage(contrib_df, terms, tp_th=1, percent_th=0.1):
1312
1325
  return coverage
1313
1326
 
1314
1327
 
1328
+ def _mpr_complexes_auc(
1329
+ coverage: np.ndarray,
1330
+ precision_cutoffs: np.ndarray,
1331
+ max_complexes: float = 200.0,
1332
+ ) -> float:
1333
+ """Compute AUC for the Fig. 1F-style mPR curve (#complexes vs precision).
1334
+
1335
+ The plot uses:
1336
+ x = #covered complexes (capped at `max_complexes`, shown on a log axis)
1337
+ y = precision cutoff
1338
+
1339
+ We compute a normalized AUC by integrating precision over the *normalized*
1340
+ coverage axis:
1341
+ AUC = \int y \, d(x/max_complexes)
1342
+
1343
+ This yields a score in [0, 1] (or NaN if insufficient data).
1344
+ """
1345
+ cov = np.asarray(coverage, dtype=float)
1346
+ prec = np.asarray(precision_cutoffs, dtype=float)
1347
+
1348
+ if cov.size == 0 or prec.size == 0:
1349
+ return 0.0
1350
+
1351
+ # Match plot_mpr_complexes_multi(): only count cov>0 (log-x cannot show 0)
1352
+ mask = (
1353
+ np.isfinite(cov)
1354
+ & np.isfinite(prec)
1355
+ & (cov > 0)
1356
+ & (cov <= max_complexes)
1357
+ & (prec >= 0)
1358
+ & (prec <= 1.0)
1359
+ )
1360
+ if not np.any(mask):
1361
+ return 0.0
1362
+
1363
+ x_cov = cov[mask]
1364
+ y = prec[mask]
1365
+
1366
+ # x-axis is log-scaled in the plot; normalize so cov=1 -> 0, cov=max_complexes -> 1
1367
+ # (This matches the plot's tick hack where 1 is labeled as "0".)
1368
+ x = np.log10(x_cov) / np.log10(float(max_complexes))
1369
+
1370
+ # Sort by x and collapse duplicate x values by taking max y (upper envelope)
1371
+ order = np.argsort(x)
1372
+ x = x[order]
1373
+ y = y[order]
1374
+
1375
+ x_unique = np.unique(x)
1376
+ if x_unique.size != x.size:
1377
+ y = np.array([float(np.nanmax(y[x == xv])) for xv in x_unique], dtype=float)
1378
+ x = x_unique
1379
+
1380
+ if x.size < 2:
1381
+ return 0.0
1382
+
1383
+ return float(np.trapz(y, x))
1384
+
1385
+
1315
1386
 
1316
1387
 
1317
1388
 
@@ -1379,6 +1450,7 @@ def mpr_prepare(
1379
1450
 
1380
1451
  tp_curves = {}
1381
1452
  coverage_curves = {}
1453
+ complexes_auc = {}
1382
1454
  precision_cutoffs = None
1383
1455
 
1384
1456
  for label, removed in filter_sets.items():
@@ -1393,6 +1465,7 @@ def mpr_prepare(
1393
1465
  "precision": np.array([], dtype=float),
1394
1466
  }
1395
1467
  coverage_curves[label] = np.zeros(0, dtype=float)
1468
+ complexes_auc[label] = float("nan")
1396
1469
  continue
1397
1470
 
1398
1471
  tp_cum = true.cumsum()
@@ -1417,11 +1490,17 @@ def mpr_prepare(
1417
1490
  percent_th=percent_th,
1418
1491
  )
1419
1492
  coverage_curves[label] = cov
1493
+ complexes_auc[label] = _mpr_complexes_auc(
1494
+ cov,
1495
+ precision_cutoffs,
1496
+ max_complexes=200.0,
1497
+ )
1420
1498
 
1421
1499
  mpr_data = {
1422
1500
  "precision_cutoffs": precision_cutoffs,
1423
1501
  "tp_curves": tp_curves,
1424
1502
  "coverage_curves": coverage_curves,
1503
+ "complexes_auc": complexes_auc,
1425
1504
  "filters": {
1426
1505
  "no_mtRibo_ETCI": sorted(mtRibo_ids),
1427
1506
  "no_small_highAUPRC": sorted(small_hi_ids),
@@ -1435,6 +1514,9 @@ def mpr_prepare(
1435
1514
 
1436
1515
  dsave(mpr_data, "mpr", name)
1437
1516
 
1517
+ # Convenience: store AUCs as their own category for easy export / plotting.
1518
+ dsave(complexes_auc, "mpr_complexes_auc", name)
1519
+
1438
1520
 
1439
1521
 
1440
1522
  ### OLD FUNCTIONS
@@ -31,7 +31,8 @@ default_config = {
31
31
  "output_folder": "CORUM",
32
32
  "gold_standard": "CORUM",
33
33
  "color_map": "BuGn",
34
- "jaccard": False,
34
+ "jaccard": True,
35
+ "jaccard_threshold": 1,
35
36
  "use_common_genes": False, # Set to False for individual dataset-gold standard intersections
36
37
  "plotting": {
37
38
  "save_plot": True,
@@ -61,26 +62,26 @@ for name, dataset in data.items():
61
62
  fpc = flex.pra_percomplex(name, dataset, is_corr=False)
62
63
  cc = flex.complex_contributions(name)
63
64
  flex.mpr_prepare(name)
65
+
64
66
 
65
67
 
66
68
 
67
69
 
68
70
  #%%
69
71
  # Generate plots
70
- flex.plot_precision_recall_curve()
71
- flex.plot_auc_scores()
72
- flex.plot_significant_complexes()
73
- flex.plot_percomplex_scatter(n_top=20)
74
- flex.plot_percomplex_scatter_bysize()
75
- flex.plot_complex_contributions()
76
- ##
77
- flex.plot_mpr_tp_multi()
78
- flex.plot_mpr_complexes_multi()
72
+ # flex.plot_precision_recall_curve()
73
+ # flex.plot_auc_scores()
74
+ # flex.plot_significant_complexes()
75
+ # flex.plot_percomplex_scatter(n_top=20)
76
+ # flex.plot_percomplex_scatter_bysize()
77
+ # flex.plot_complex_contributions()
78
+ # flex.plot_mpr_tp_multi(show_filters="all")
79
+ # flex.plot_mpr_complexes_multi(show_filters="all")
80
+ # flex.plot_mpr_complexes_auc_scores("all")
81
+
82
+
79
83
 
80
84
  #%%
81
85
  # Save results to CSV
82
- flex.save_results_to_csv()
86
+ # flex.save_results_to_csv()
83
87
 
84
- # %%
85
- flex.plot_mpr_complexes_multi(show_filters="no_mtRibo_ETCI")
86
- # %%
@@ -0,0 +1,112 @@
1
+ """
2
+ Basic usage example of the pythonFLEX package.
3
+ Demonstrates initialization, data loading, analysis, and plotting.
4
+ """
5
+ #%%
6
+ import pythonflex as flex
7
+ import pandas as pd
8
+
9
+ gene_effect = pd.read_csv('C:/Users/yd/Desktop/projects/_datasets/depmap/25Q2/gene_effect.csv', index_col=0)
10
+
11
+ skin = pd.read_csv('C:/Users/yd/Desktop/projects/_datasets/depmap/25Q2/subset/skin_cell_lines.csv', index_col=0)
12
+
13
+ soft = pd.read_csv('C:/Users/yd/Desktop/projects/_datasets/depmap/25Q2/subset/soft_tissue_cell_lines.csv', index_col=0)
14
+
15
+
16
+ cholesky = pd.read_csv('C:/Users/yd/Desktop/projects/_datasets/depmap/25Q2/25Q2_chronos_whitened_Cholesky.csv', index_col=0).T
17
+
18
+ # inputs = {
19
+ # "All Screens": {
20
+ # "path": gene_effect,
21
+ # "sort": "high",
22
+ # "color": "#000000"
23
+ # },
24
+ # "Skin": {
25
+ # "path": skin,
26
+ # "sort": "high",
27
+ # "color": "#FF0000"
28
+ # },
29
+ # "Soft Tissue": {
30
+ # "path": soft,
31
+ # "sort": "high",
32
+ # "color": "#FFFF00"
33
+ # },
34
+ # }
35
+
36
+
37
+ inputs = {
38
+ "DM All Screens": {
39
+ "path": gene_effect,
40
+ "sort": "high",
41
+ "color": "#000000"
42
+ },
43
+ "DM Cholesky Whitening": {
44
+ "path": cholesky,
45
+ "sort": "high",
46
+ "color": "#FF0000"
47
+ },
48
+
49
+ }
50
+
51
+
52
+
53
+
54
+ default_config = {
55
+ "min_genes_in_complex": 2,
56
+ "min_genes_per_complex_analysis": 3,
57
+ "output_folder": "CORUM_DMvsCholesky",
58
+ "gold_standard": "CORUM",
59
+ "color_map": "BuGn",
60
+ "jaccard": False,
61
+ "jaccard_threshold": 1.0,
62
+ "use_common_genes": False, # Set to False for individual dataset-gold standard intersections
63
+ "plotting": {
64
+ "save_plot": True,
65
+ "output_type": "pdf",
66
+ },
67
+ "preprocessing": {
68
+ "fill_na": True,
69
+ "normalize": False,
70
+ },
71
+ "corr_function": "numpy",
72
+ "logging": {
73
+ "visible_levels": ["DONE"]
74
+ # "PROGRESS", "STARTED", ,"INFO","WARNING"
75
+ }
76
+ }
77
+
78
+ # Initialize logger, config, and output folder
79
+ flex.initialize(default_config)
80
+
81
+ # Load datasets and gold standard terms
82
+ data, _ = flex.load_datasets(inputs)
83
+ terms, genes_in_terms = flex.load_gold_standard()
84
+
85
+ # Run analysis
86
+ for name, dataset in data.items():
87
+ pra = flex.pra(name, dataset, is_corr=False)
88
+ fpc = flex.pra_percomplex(name, dataset, is_corr=False)
89
+ cc = flex.complex_contributions(name)
90
+ flex.mpr_prepare(name)
91
+
92
+
93
+
94
+
95
+ #%%
96
+ # Generate plots
97
+ flex.plot_precision_recall_curve()
98
+ flex.plot_auc_scores()
99
+ flex.plot_significant_complexes()
100
+ flex.plot_percomplex_scatter(n_top=20)
101
+ flex.plot_percomplex_scatter_bysize()
102
+ flex.plot_complex_contributions()
103
+ ##
104
+ #%%
105
+ flex.plot_mpr_tp_multi(show_filters="all")
106
+ flex.plot_mpr_complexes_multi(show_filters="all")
107
+
108
+ # Save results to CSV
109
+ flex.save_results_to_csv()
110
+
111
+ # %%
112
+ # %%
@@ -1220,6 +1220,107 @@ def plot_auc_scores():
1220
1220
  plt.close(fig)
1221
1221
  return pra_dict
1222
1222
 
1223
+
1224
+ def plot_mpr_complexes_auc_scores(filter_key: str = "all"):
1225
+ """Plot AUC scores for the mPR complexes curve (Fig 1F-style).
1226
+
1227
+ Requires `mpr_prepare()` to have been run for each dataset.
1228
+
1229
+ Parameters
1230
+ ----------
1231
+ filter_key : str
1232
+ One of: "all", "no_mtRibo_ETCI", "no_small_highAUPRC".
1233
+
1234
+ Returns
1235
+ -------
1236
+ pd.Series
1237
+ AUC values indexed by dataset name (sorted descending).
1238
+ """
1239
+ config = dload("config")
1240
+ plot_config = config["plotting"]
1241
+ mpr_auc_dict = dload("mpr_complexes_auc")
1242
+ input_colors = dload("input", "colors")
1243
+
1244
+ if input_colors:
1245
+ input_colors = {_sanitize(k): v for k, v in input_colors.items()}
1246
+
1247
+ if not isinstance(mpr_auc_dict, dict) or not mpr_auc_dict:
1248
+ log.warning(
1249
+ "No mPR complexes AUC data found. Run mpr_prepare() first (it stores 'mpr_complexes_auc')."
1250
+ )
1251
+ return pd.Series(dtype=float)
1252
+
1253
+ # Build Series: dataset -> auc
1254
+ auc_by_dataset = {}
1255
+ for dataset, per_filter in mpr_auc_dict.items():
1256
+ if not isinstance(per_filter, dict):
1257
+ continue
1258
+ val = per_filter.get(filter_key)
1259
+ if val is None:
1260
+ continue
1261
+ try:
1262
+ auc_by_dataset[dataset] = float(val)
1263
+ except (TypeError, ValueError):
1264
+ continue
1265
+
1266
+ if not auc_by_dataset:
1267
+ log.warning(
1268
+ f"No mPR complexes AUC scores found for filter '{filter_key}'. Available filters: {list(FILTER_STYLES.keys())}"
1269
+ )
1270
+ return pd.Series(dtype=float)
1271
+
1272
+ s = pd.Series(auc_by_dataset).sort_values(ascending=False)
1273
+ datasets = list(s.index)
1274
+ auc_scores = list(s.values)
1275
+
1276
+ fig, ax = plt.subplots()
1277
+
1278
+ # Color logic (match other bar plots)
1279
+ cmap_name = config.get("color_map", "tab10")
1280
+ try:
1281
+ cmap = get_cmap(cmap_name)
1282
+ except ValueError:
1283
+ cmap = get_cmap("tab10")
1284
+
1285
+ num_datasets = len(datasets)
1286
+ if num_datasets <= 10 and cmap_name == "tab10":
1287
+ default_colors = [cmap(i) for i in range(num_datasets)]
1288
+ else:
1289
+ default_colors = [cmap(float(i) / max(num_datasets - 1, 1)) for i in range(num_datasets)]
1290
+
1291
+ final_colors = []
1292
+ for i, dataset in enumerate(datasets):
1293
+ color = input_colors.get(dataset) if input_colors else None
1294
+ if color is None:
1295
+ color = default_colors[i]
1296
+ final_colors.append(color)
1297
+
1298
+ ax.bar(datasets, auc_scores, color=final_colors, edgecolor="black")
1299
+
1300
+ ymax = max([v for v in auc_scores if np.isfinite(v)], default=0.0)
1301
+ ax.set_ylim(0, ymax + 0.01)
1302
+ ax.set_ylabel("mPR complexes AUC")
1303
+ plt.xticks(rotation=45, ha="right")
1304
+
1305
+ # Styling consistent with other plots
1306
+ ax.grid(visible=False, which="both", axis="both")
1307
+ ax.set_axisbelow(False)
1308
+ ax.spines["top"].set_visible(False)
1309
+ ax.spines["right"].set_visible(False)
1310
+
1311
+ if plot_config.get("save_plot", False):
1312
+ output_type = plot_config.get("output_type", "pdf")
1313
+ output_folder = Path(config["output_folder"])
1314
+ output_folder.mkdir(parents=True, exist_ok=True)
1315
+ output_path = output_folder / f"mpr_complexes_auc_{filter_key}.{output_type}"
1316
+ plt.savefig(output_path, bbox_inches="tight", format=output_type)
1317
+
1318
+ if plot_config.get("show_plot", True):
1319
+ plt.show()
1320
+
1321
+ plt.close(fig)
1322
+ return s
1323
+
1223
1324
  # -----------------------------------------------------------------------------
1224
1325
  # mPR plots (Fig. 1E and Fig. 1F)
1225
1326
  # -----------------------------------------------------------------------------
@@ -1620,6 +1721,8 @@ def plot_mpr_complexes_multi(
1620
1721
  outname=None,
1621
1722
  linewidth=1.8,
1622
1723
  show_filters=("all", "no_mtRibo_ETCI", "no_small_highAUPRC"),
1724
+ show_markers="auto",
1725
+ marker_size=20,
1623
1726
  ):
1624
1727
  """
1625
1728
  Plot module-level PR (#complexes vs precision) for multiple datasets.
@@ -1644,6 +1747,11 @@ def plot_mpr_complexes_multi(
1644
1747
  Line width for all curves
1645
1748
  show_filters : tuple of str
1646
1749
  Which filters to show. Default is all three.
1750
+ show_markers : bool or "auto"
1751
+ If True, draw markers on curves to make short curves visible.
1752
+ If "auto" (default), markers are drawn only for curves with <= 10 points.
1753
+ marker_size : int
1754
+ Scatter marker size (points^2) when markers are shown.
1647
1755
 
1648
1756
  Returns
1649
1757
  -------
@@ -1730,13 +1838,26 @@ def plot_mpr_complexes_multi(
1730
1838
  prec_plot = precision_cutoffs[mask]
1731
1839
 
1732
1840
  style = FILTER_STYLES.get(filter_key, {})
1733
- ax.plot(
1734
- cov_plot,
1735
- prec_plot,
1736
- color=color,
1737
- linestyle=style.get("linestyle", "-"),
1738
- linewidth=linewidth,
1739
- )
1841
+
1842
+ # Decide marker visibility
1843
+ if show_markers == "auto":
1844
+ use_markers = (cov_plot.size <= 10)
1845
+ else:
1846
+ use_markers = bool(show_markers)
1847
+
1848
+ if cov_plot.size == 1:
1849
+ # A single point is effectively invisible as a line; draw a marker.
1850
+ ax.scatter(cov_plot, prec_plot, color=color, s=marker_size, zorder=3)
1851
+ else:
1852
+ ax.plot(
1853
+ cov_plot,
1854
+ prec_plot,
1855
+ color=color,
1856
+ linestyle=style.get("linestyle", "-"),
1857
+ linewidth=linewidth,
1858
+ marker=("o" if use_markers else None),
1859
+ markersize=(3 if use_markers else None),
1860
+ )
1740
1861
 
1741
1862
  # Configure axes
1742
1863
  ax.set_xscale("log")
@@ -189,7 +189,18 @@ def load_gold_standard():
189
189
  use_common_genes = config.get("use_common_genes", True)
190
190
 
191
191
  gold_standard_source = config['gold_standard']
192
- log.started(f"Loading gold standard: {gold_standard_source}, Min complex size: {config['min_genes_in_complex']}, Jaccard filtering: {config['jaccard']}, use_common_genes: {use_common_genes}")
192
+ jaccard_enabled = bool(config.get("jaccard", False))
193
+ jaccard_threshold_raw = config.get("jaccard_threshold", 1.0)
194
+ try:
195
+ jaccard_threshold = float(jaccard_threshold_raw) # type: ignore[arg-type]
196
+ except (TypeError, ValueError):
197
+ raise ValueError(
198
+ f"config['jaccard_threshold'] must be a number in (0, 1], got {jaccard_threshold_raw!r}"
199
+ )
200
+ log.done(
201
+ f"Loading gold standard: {gold_standard_source}, Min complex size: {config['min_genes_in_complex']}, "
202
+ f"Jaccard filtering: {jaccard_enabled} (threshold={jaccard_threshold}), use_common_genes: {use_common_genes}"
203
+ )
193
204
 
194
205
  # Define gold standard file paths for predefined sources
195
206
  gold_standard_files = {
@@ -217,34 +228,44 @@ def load_gold_standard():
217
228
 
218
229
  # Store raw gold standard for later per-dataset filtering
219
230
  terms["all_genes"] = terms["Genes"].apply(lambda x: list(set(x.split(";"))))
220
- log.info(f"Gold standard loaded with {len(terms)} terms")
231
+ log.done(f"Gold standard loaded with {len(terms)} terms")
221
232
 
222
233
  # Basic filtering by minimum complex size (before gene filtering)
223
234
  terms["n_all_genes"] = terms["all_genes"].apply(len)
224
235
  terms = terms[terms["n_all_genes"] >= config['min_genes_in_complex']]
225
- log.info(f"After min_genes_in_complex filtering: {len(terms)} terms")
226
-
227
- if config['jaccard']:
228
- log.info("Applying Jaccard filtering. Remove terms with identical gene sets.")
229
- # Use all genes for jaccard filtering
230
- terms["gene_set"] = terms["all_genes"].map(lambda x: frozenset(x))
231
- grouped = terms.groupby("gene_set", sort=False)
232
- duplicate_clusters = []
233
- for _, group in grouped:
234
- if len(group) > 1:
235
- duplicate_clusters.append(group["ID"].values if "ID" in group.columns else group.index.values)
236
-
237
- keep_ids = set(terms["ID"] if "ID" in terms.columns else terms.index)
238
- for cluster in duplicate_clusters:
239
- sorted_ids = sorted(cluster)
240
- keep_ids.difference_update(sorted_ids[1:])
241
-
242
- if "ID" in terms.columns:
243
- terms = terms[terms["ID"].isin(keep_ids)].copy()
236
+ log.done(f"After min_genes_in_complex filtering: {len(terms)} terms")
237
+
238
+ if jaccard_enabled:
239
+ if not (0.0 < jaccard_threshold <= 1.0):
240
+ raise ValueError(f"config['jaccard_threshold'] must be in (0, 1], got {jaccard_threshold}")
241
+
242
+ if jaccard_threshold >= 1.0:
243
+ log.done("Applying Jaccard filtering (threshold=1.0). Removing terms with identical gene sets.")
244
+ # Use all genes for jaccard filtering
245
+ terms["gene_set"] = terms["all_genes"].map(lambda x: frozenset(x))
246
+ grouped = terms.groupby("gene_set", sort=False)
247
+ duplicate_clusters = []
248
+ for _, group in grouped:
249
+ if len(group) > 1:
250
+ duplicate_clusters.append(group["ID"].values if "ID" in group.columns else group.index.values)
251
+
252
+ keep_ids = set(terms["ID"] if "ID" in terms.columns else terms.index)
253
+ for cluster in duplicate_clusters:
254
+ sorted_ids = sorted(cluster)
255
+ keep_ids.difference_update(sorted_ids[1:])
256
+
257
+ if "ID" in terms.columns:
258
+ terms = terms[terms["ID"].isin(keep_ids)].copy()
259
+ else:
260
+ terms = terms[terms.index.isin(keep_ids)].copy()
261
+ terms.drop(columns=["gene_set"], inplace=True, errors="ignore")
262
+ log.done(f"After Jaccard filtering: {len(terms)} terms")
244
263
  else:
245
- terms = terms[terms.index.isin(keep_ids)].copy()
246
- terms.drop(columns=["gene_set"], inplace=True, errors='ignore')
247
- log.info(f"After Jaccard filtering: {len(terms)} terms")
264
+ log.done(
265
+ f"Applying Jaccard filtering (threshold={jaccard_threshold}). Removing highly similar terms."
266
+ )
267
+ terms = _filter_terms_by_jaccard_threshold(terms, threshold=jaccard_threshold, genes_col="all_genes")
268
+ log.done(f"After Jaccard filtering: {len(terms)} terms")
248
269
 
249
270
  # if there is column called "ID", set it as index
250
271
  if "ID" in terms.columns:
@@ -255,6 +276,149 @@ def load_gold_standard():
255
276
  return terms, None # Return None for genes_present_in_terms - will be computed per dataset
256
277
 
257
278
 
279
+ class _UnionFind:
280
+ def __init__(self, n: int):
281
+ self.parent = list(range(n))
282
+ self.rank = [0] * n
283
+
284
+ def find(self, x: int) -> int:
285
+ while self.parent[x] != x:
286
+ self.parent[x] = self.parent[self.parent[x]]
287
+ x = self.parent[x]
288
+ return x
289
+
290
+ def union(self, a: int, b: int) -> None:
291
+ ra, rb = self.find(a), self.find(b)
292
+ if ra == rb:
293
+ return
294
+ if self.rank[ra] < self.rank[rb]:
295
+ self.parent[ra] = rb
296
+ elif self.rank[ra] > self.rank[rb]:
297
+ self.parent[rb] = ra
298
+ else:
299
+ self.parent[rb] = ra
300
+ self.rank[ra] += 1
301
+
302
+
303
+ def _safe_id_sort_key(val):
304
+ """Sort key that prefers numeric ordering when IDs look like ints."""
305
+ try:
306
+ return (0, int(val))
307
+ except Exception:
308
+ return (1, str(val))
309
+
310
+
311
+ def _jaccard_similarity(a: set, b: set) -> float:
312
+ if not a and not b:
313
+ return 1.0
314
+ if not a or not b:
315
+ return 0.0
316
+ inter = len(a.intersection(b))
317
+ if inter == 0:
318
+ return 0.0
319
+ union = len(a) + len(b) - inter
320
+ return inter / union
321
+
322
+
323
+ def _filter_terms_by_jaccard_threshold(terms: pd.DataFrame, threshold: float, genes_col: str = "all_genes") -> pd.DataFrame:
324
+ """Remove near-duplicate terms whose gene sets have Jaccard similarity >= threshold.
325
+
326
+ Keeps one representative per similarity-connected component (smallest ID).
327
+ This uses an exact Jaccard similarity join with prefix-filter candidate generation.
328
+ """
329
+ if not (0.0 < threshold < 1.0):
330
+ # threshold == 1.0 handled elsewhere; invalid values rejected earlier
331
+ return terms
332
+
333
+ # Build IDs and gene sets
334
+ id_col = "ID" if "ID" in terms.columns else None
335
+ term_ids = (terms["ID"].tolist() if id_col else terms.index.tolist())
336
+ gene_sets = []
337
+ for genes in terms[genes_col].tolist():
338
+ gene_sets.append(set(genes))
339
+
340
+ sizes = [len(s) for s in gene_sets]
341
+ if len(gene_sets) <= 1:
342
+ return terms
343
+
344
+ # Global token frequency for ordering (rare tokens first)
345
+ from collections import Counter, defaultdict
346
+ freq = Counter()
347
+ for s in gene_sets:
348
+ freq.update(s)
349
+
350
+ def sort_tokens(s: set):
351
+ return sorted(s, key=lambda tok: (freq.get(tok, 0), str(tok)))
352
+
353
+ # Process smaller sets first (helps size filtering and keeps index smaller)
354
+ order = sorted(range(len(gene_sets)), key=lambda i: (sizes[i], _safe_id_sort_key(term_ids[i])))
355
+ ordered_tokens = [sort_tokens(gene_sets[i]) for i in range(len(gene_sets))]
356
+
357
+ # Inverted index over prefix tokens
358
+ inv_index = defaultdict(list) # token -> list of processed term indices (original idx)
359
+
360
+ uf = _UnionFind(len(gene_sets))
361
+
362
+ # Precompute prefix lengths
363
+ import math
364
+ prefix_len = []
365
+ for i in range(len(gene_sets)):
366
+ m = sizes[i]
367
+ # PPJoin prefix length for Jaccard threshold
368
+ p = m - math.ceil(threshold * m) + 1
369
+ if p < 0:
370
+ p = 0
371
+ if p > m:
372
+ p = m
373
+ prefix_len.append(p)
374
+
375
+ # Candidate generation + exact verification
376
+ for idx_pos, i in enumerate(order):
377
+ tokens_i = ordered_tokens[i]
378
+ p_i = prefix_len[i]
379
+
380
+ # Count shared prefix tokens with previously indexed sets
381
+ candidate_overlap_lb = defaultdict(int)
382
+ for tok in tokens_i[:p_i]:
383
+ for j in inv_index.get(tok, []):
384
+ # size filter: if too different in size, cannot meet Jaccard threshold
385
+ if sizes[j] < threshold * sizes[i]:
386
+ continue
387
+ if sizes[j] > sizes[i] / threshold:
388
+ continue
389
+ candidate_overlap_lb[j] += 1
390
+ inv_index[tok].append(i)
391
+
392
+ if not candidate_overlap_lb:
393
+ continue
394
+
395
+ set_i = gene_sets[i]
396
+ for j in candidate_overlap_lb.keys():
397
+ # Exact verification
398
+ sim = _jaccard_similarity(set_i, gene_sets[j])
399
+ if sim >= threshold:
400
+ uf.union(i, j)
401
+
402
+ # Choose representative (smallest ID) for each connected component
403
+ components = {}
404
+ for i in range(len(gene_sets)):
405
+ root = uf.find(i)
406
+ components.setdefault(root, []).append(i)
407
+
408
+ keep_original_indices = set()
409
+ for members in components.values():
410
+ # Keep smallest ID among members
411
+ keep = min(members, key=lambda k: _safe_id_sort_key(term_ids[k]))
412
+ keep_original_indices.add(keep)
413
+
414
+ if id_col:
415
+ keep_ids = {term_ids[i] for i in keep_original_indices}
416
+ return terms[terms["ID"].isin(keep_ids)].copy()
417
+ else:
418
+ keep_index = {term_ids[i] for i in keep_original_indices}
419
+ return terms[terms.index.isin(keep_index)].copy()
420
+
421
+
258
422
 
259
423
 
260
424
 
File without changes
File without changes
File without changes
File without changes