pythonflex 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pythonflex/__init__.py CHANGED
@@ -3,9 +3,9 @@ from .utils import dsave, dload
3
3
  from .preprocessing import get_example_data_path, load_datasets, get_common_genes, filter_matrix_by_genes, load_gold_standard, filter_duplicate_terms
4
4
  from .analysis import initialize, pra, pra_percomplex, fast_corr, perform_corr, is_symmetric, binary, has_mirror_of_first_pair, convert_full_to_half_matrix, drop_mirror_pairs, quick_sort, complex_contributions, save_results_to_csv, update_matploblib_config, mpr_prepare
5
5
  from .plotting import (
6
- adjust_text_positions, plot_precision_recall_curve, plot_percomplex_scatter,
6
+ adjust_text_positions, plot_precision_recall_curve, plot_aggregated_pra, plot_iqr_pra, plot_all_runs_pra, plot_percomplex_scatter,
7
7
  plot_percomplex_scatter_bysize, plot_complex_contributions, plot_significant_complexes, plot_auc_scores,
8
- plot_mpr_tp, plot_mpr_complexes, plot_mpr_tp_multi, plot_mpr_complexes_multi
8
+ plot_mpr_tp, plot_mpr_complexes, plot_mpr_tp_multi, plot_mpr_complexes_multi, plot_mpr_complexes_auc_scores
9
9
  )
10
10
 
11
11
  __all__ = [ "log", "get_example_data_path", "fast_corr",
@@ -13,8 +13,8 @@ __all__ = [ "log", "get_example_data_path", "fast_corr",
13
13
  "filter_matrix_by_genes", "load_gold_standard", "filter_duplicate_terms", "pra", "pra_percomplex",
14
14
  "perform_corr", "is_symmetric", "binary", "has_mirror_of_first_pair", "convert_full_to_half_matrix",
15
15
  "drop_mirror_pairs", "quick_sort", "complex_contributions", "adjust_text_positions", "plot_precision_recall_curve",
16
- "plot_percomplex_scatter", "plot_percomplex_scatter_bysize", "plot_complex_contributions",
17
- "plot_significant_complexes", "plot_auc_scores", "save_results_to_csv", "update_matploblib_config",
16
+ "plot_aggregated_pra", "plot_iqr_pra", "plot_all_runs_pra", "plot_percomplex_scatter", "plot_percomplex_scatter_bysize", "plot_complex_contributions",
17
+ "plot_significant_complexes", "plot_auc_scores", "plot_mpr_complexes_auc_scores", "save_results_to_csv", "update_matploblib_config",
18
18
  "mpr_prepare", "plot_mpr_tp", "plot_mpr_complexes",
19
19
  "plot_mpr_tp_multi", "plot_mpr_complexes_multi"
20
20
  ]
pythonflex/analysis.py CHANGED
@@ -844,7 +844,7 @@ def quick_sort(df, ascending=False):
844
844
  log.done("Pair-wise matrix sorting.")
845
845
  return sorted_df
846
846
 
847
- def save_results_to_csv(categories = ["complex_contributions", "pr_auc", "pra_percomplex"]):
847
+ def save_results_to_csv(categories = ["complex_contributions", "pr_auc", "pra_percomplex", "mpr_complexes_auc"]):
848
848
 
849
849
  config = dload("config") # Load config to get output folder
850
850
  output_folder = Path(config.get("output_folder", "output"))
@@ -856,6 +856,18 @@ def save_results_to_csv(categories = ["complex_contributions", "pr_auc", "pra_pe
856
856
  if data is None:
857
857
  log.warning(f"No data found for category '{category}'. Skipping save.")
858
858
  continue
859
+
860
+ if category == "mpr_complexes_auc" and isinstance(data, dict):
861
+ # Dict[dataset_name -> Dict[filter_key -> auc]]
862
+ try:
863
+ df = pd.DataFrame.from_dict(data, orient="index")
864
+ df.index.name = "Dataset"
865
+ csv_path = output_folder / f"{category}.csv"
866
+ df.to_csv(csv_path, index=True)
867
+ log.info(f"Saved '{category}' to {csv_path}")
868
+ except Exception as e:
869
+ log.warning(f"Failed to convert and save '{category}': {e}")
870
+ continue
859
871
 
860
872
  if category == "pr_auc" and isinstance(data, dict):
861
873
  # Special handling: Convert dict to DataFrame (assuming keys are indices, values are data)
@@ -1312,6 +1324,64 @@ def _mpr_module_coverage(contrib_df, terms, tp_th=1, percent_th=0.1):
1312
1324
  return coverage
1313
1325
 
1314
1326
 
1327
+ def _mpr_complexes_auc(
1328
+ coverage: np.ndarray,
1329
+ precision_cutoffs: np.ndarray,
1330
+ max_complexes: float = 200.0,
1331
+ ) -> float:
1332
+ """Compute AUC for the Fig. 1F-style mPR curve (#complexes vs precision).
1333
+
1334
+ The plot uses:
1335
+ x = #covered complexes (capped at `max_complexes`, shown on a log axis)
1336
+ y = precision cutoff
1337
+
1338
+ We compute a normalized AUC by integrating precision over the *normalized*
1339
+ coverage axis:
1340
+ AUC = \int y \, d(x/max_complexes)
1341
+
1342
+ This yields a score in [0, 1] (or NaN if insufficient data).
1343
+ """
1344
+ cov = np.asarray(coverage, dtype=float)
1345
+ prec = np.asarray(precision_cutoffs, dtype=float)
1346
+
1347
+ if cov.size == 0 or prec.size == 0:
1348
+ return 0.0
1349
+
1350
+ # Match plot_mpr_complexes_multi(): only count cov>0 (log-x cannot show 0)
1351
+ mask = (
1352
+ np.isfinite(cov)
1353
+ & np.isfinite(prec)
1354
+ & (cov > 0)
1355
+ & (cov <= max_complexes)
1356
+ & (prec >= 0)
1357
+ & (prec <= 1.0)
1358
+ )
1359
+ if not np.any(mask):
1360
+ return 0.0
1361
+
1362
+ x_cov = cov[mask]
1363
+ y = prec[mask]
1364
+
1365
+ # x-axis is log-scaled in the plot; normalize so cov=1 -> 0, cov=max_complexes -> 1
1366
+ # (This matches the plot's tick hack where 1 is labeled as "0".)
1367
+ x = np.log10(x_cov) / np.log10(float(max_complexes))
1368
+
1369
+ # Sort by x and collapse duplicate x values by taking max y (upper envelope)
1370
+ order = np.argsort(x)
1371
+ x = x[order]
1372
+ y = y[order]
1373
+
1374
+ x_unique = np.unique(x)
1375
+ if x_unique.size != x.size:
1376
+ y = np.array([float(np.nanmax(y[x == xv])) for xv in x_unique], dtype=float)
1377
+ x = x_unique
1378
+
1379
+ if x.size < 2:
1380
+ return 0.0
1381
+
1382
+ return float(np.trapz(y, x))
1383
+
1384
+
1315
1385
 
1316
1386
 
1317
1387
 
@@ -1379,6 +1449,7 @@ def mpr_prepare(
1379
1449
 
1380
1450
  tp_curves = {}
1381
1451
  coverage_curves = {}
1452
+ complexes_auc = {}
1382
1453
  precision_cutoffs = None
1383
1454
 
1384
1455
  for label, removed in filter_sets.items():
@@ -1393,6 +1464,7 @@ def mpr_prepare(
1393
1464
  "precision": np.array([], dtype=float),
1394
1465
  }
1395
1466
  coverage_curves[label] = np.zeros(0, dtype=float)
1467
+ complexes_auc[label] = float("nan")
1396
1468
  continue
1397
1469
 
1398
1470
  tp_cum = true.cumsum()
@@ -1417,11 +1489,17 @@ def mpr_prepare(
1417
1489
  percent_th=percent_th,
1418
1490
  )
1419
1491
  coverage_curves[label] = cov
1492
+ complexes_auc[label] = _mpr_complexes_auc(
1493
+ cov,
1494
+ precision_cutoffs,
1495
+ max_complexes=200.0,
1496
+ )
1420
1497
 
1421
1498
  mpr_data = {
1422
1499
  "precision_cutoffs": precision_cutoffs,
1423
1500
  "tp_curves": tp_curves,
1424
1501
  "coverage_curves": coverage_curves,
1502
+ "complexes_auc": complexes_auc,
1425
1503
  "filters": {
1426
1504
  "no_mtRibo_ETCI": sorted(mtRibo_ids),
1427
1505
  "no_small_highAUPRC": sorted(small_hi_ids),
@@ -1435,6 +1513,9 @@ def mpr_prepare(
1435
1513
 
1436
1514
  dsave(mpr_data, "mpr", name)
1437
1515
 
1516
+ # Convenience: store AUCs as their own category for easy export / plotting.
1517
+ dsave(complexes_auc, "mpr_complexes_auc", name)
1518
+
1438
1519
 
1439
1520
 
1440
1521
  ### OLD FUNCTIONS
@@ -8,32 +8,34 @@ import pythonflex as flex
8
8
  inputs = {
9
9
  "Melanoma (63 Screens)": {
10
10
  "path": flex.get_example_data_path("melanoma_cell_lines_500_genes.csv"),
11
- "sort": "high"
11
+ "sort": "high",
12
+ "color": "#FF0000"
12
13
  },
13
14
  "Liver (24 Screens)": {
14
15
  "path": flex.get_example_data_path("liver_cell_lines_500_genes.csv"),
15
- "sort": "high"
16
+ "sort": "high",
17
+ "color": "#FFDD00"
16
18
  },
17
19
  "Neuroblastoma (37 Screens)": {
18
20
  "path": flex.get_example_data_path("neuroblastoma_cell_lines_500_genes.csv"),
19
- "sort": "high"
21
+ "sort": "high",
22
+ "color": "#FFDDDD"
20
23
  },
21
24
  }
22
25
 
23
26
 
24
27
 
25
- #%%
26
28
  default_config = {
27
29
  "min_genes_in_complex": 0,
28
30
  "min_genes_per_complex_analysis": 3,
29
- "output_folder": "output",
31
+ "output_folder": "CORUM",
30
32
  "gold_standard": "CORUM",
31
- "color_map": "RdYlBu",
32
- "jaccard": True,
33
+ "color_map": "BuGn",
34
+ "jaccard": False,
33
35
  "use_common_genes": False, # Set to False for individual dataset-gold standard intersections
34
36
  "plotting": {
35
37
  "save_plot": True,
36
- "output_type": "pdf",
38
+ "output_type": "png",
37
39
  },
38
40
  "preprocessing": {
39
41
  "fill_na": True,
@@ -41,7 +43,8 @@ default_config = {
41
43
  },
42
44
  "corr_function": "numpy",
43
45
  "logging": {
44
- "visible_levels": ["DONE","STARTED"] # "PROGRESS", "STARTED", ,"INFO","WARNING"
46
+ "visible_levels": ["DONE"]
47
+ # "PROGRESS", "STARTED", ,"INFO","WARNING"
45
48
  }
46
49
  }
47
50
 
@@ -52,26 +55,33 @@ flex.initialize(default_config)
52
55
  data, _ = flex.load_datasets(inputs)
53
56
  terms, genes_in_terms = flex.load_gold_standard()
54
57
 
55
-
56
- #%%
57
58
  # Run analysis
58
59
  for name, dataset in data.items():
59
60
  pra = flex.pra(name, dataset, is_corr=False)
60
61
  fpc = flex.pra_percomplex(name, dataset, is_corr=False)
61
62
  cc = flex.complex_contributions(name)
62
-
63
+ flex.mpr_prepare(name)
64
+
65
+
63
66
 
64
67
 
65
68
  #%%
66
69
  # Generate plots
67
- flex.plot_auc_scores()
68
- flex.plot_precision_recall_curve()
69
- flex.plot_percomplex_scatter(n_top=20)
70
- flex.plot_percomplex_scatter_bysize()
71
- flex.plot_significant_complexes()
72
- flex.plot_complex_contributions()
73
-
70
+ # flex.plot_precision_recall_curve()
71
+ # flex.plot_auc_scores()
72
+ # flex.plot_significant_complexes()
73
+ # flex.plot_percomplex_scatter(n_top=20)
74
+ # flex.plot_percomplex_scatter_bysize()
75
+ # flex.plot_complex_contributions()
76
+ #%%
77
+ #flex.plot_mpr_tp_multi(show_filters="all")
78
+ flex.plot_mpr_complexes_multi(show_filters="all")
74
79
 
75
80
  #%%
76
81
  # Save results to CSV
77
82
  flex.save_results_to_csv()
83
+
84
+
85
+ # %%
86
+ flex.plot_mpr_complexes_auc_scores("all")
87
+ # %%
@@ -0,0 +1,111 @@
1
"""
Basic usage example of the pythonFLEX package.
Demonstrates initialization, data loading, analysis, and plotting.
"""
#%%
import pythonflex as flex
import pandas as pd

# --- Input matrices (DepMap 25Q2, local absolute paths) ---------------------
gene_effect = pd.read_csv('C:/Users/yd/Desktop/projects/_datasets/depmap/25Q2/gene_effect.csv', index_col=0)

skin = pd.read_csv('C:/Users/yd/Desktop/projects/_datasets/depmap/25Q2/subset/skin_cell_lines.csv', index_col=0)

soft = pd.read_csv('C:/Users/yd/Desktop/projects/_datasets/depmap/25Q2/subset/soft_tissue_cell_lines.csv', index_col=0)

# Whitened matrix is stored transposed; .T restores genes-as-rows orientation
# (assumed from the other inputs — TODO confirm against the file layout).
cholesky = pd.read_csv('C:/Users/yd/Desktop/projects/_datasets/depmap/25Q2/25Q2_chronos_whitened_Cholesky.csv', index_col=0).T

# Alternative input set (tissue subsets), kept for reference:
# inputs = {
#     "All Screens": {"path": gene_effect, "sort": "high", "color": "#000000"},
#     "Skin":        {"path": skin,        "sort": "high", "color": "#FF0000"},
#     "Soft Tissue": {"path": soft,        "sort": "high", "color": "#FFFF00"},
# }

# Datasets to benchmark against each other.
inputs = {
    "DM All Screens": {
        "path": gene_effect,
        "sort": "high",
        "color": "#000000"
    },
    "DM Cholesky Whitening": {
        "path": cholesky,
        "sort": "high",
        "color": "#FF0000"
    },

}


# Analysis configuration.
default_config = {
    "min_genes_in_complex": 2,
    "min_genes_per_complex_analysis": 3,
    "output_folder": "CORUM_DMvsCholesky",
    "gold_standard": "CORUM",
    "color_map": "BuGn",
    "jaccard": False,
    "use_common_genes": False, # Set to False for individual dataset-gold standard intersections
    "plotting": {
        "save_plot": True,
        "output_type": "pdf",
    },
    "preprocessing": {
        "fill_na": True,
        "normalize": False,
    },
    "corr_function": "numpy",
    "logging": {
        "visible_levels": ["DONE"]
        # other levels: "PROGRESS", "STARTED", "INFO", "WARNING"
    }
}

# Initialize logger, config, and output folder
flex.initialize(default_config)

# Load datasets and gold standard terms
data, _ = flex.load_datasets(inputs)
terms, genes_in_terms = flex.load_gold_standard()

# Run the per-dataset analysis pipeline
for dataset_name, matrix in data.items():
    pra = flex.pra(dataset_name, matrix, is_corr=False)
    fpc = flex.pra_percomplex(dataset_name, matrix, is_corr=False)
    cc = flex.complex_contributions(dataset_name)
    flex.mpr_prepare(dataset_name)


#%%
# Generate plots
flex.plot_precision_recall_curve()
flex.plot_auc_scores()
flex.plot_significant_complexes()
flex.plot_percomplex_scatter(n_top=20)
flex.plot_percomplex_scatter_bysize()
flex.plot_complex_contributions()

#%%
flex.plot_mpr_tp_multi(show_filters="all")
flex.plot_mpr_complexes_multi(show_filters="all")

# Save results to CSV
flex.save_results_to_csv()

# %%
pythonflex/plotting.py CHANGED
@@ -114,6 +114,171 @@ def plot_precision_recall_curve(line_width=2.0, hide_minor_ticks=True):
114
114
  plt.show()
115
115
  plt.close(fig)
116
116
 
117
def plot_aggregated_pra(agg_df, line_width=2.0, hide_minor_ticks=True):
    """
    Plot an aggregated Precision-Recall curve: mean line plus min-max band.

    Parameters
    ----------
    agg_df : pd.DataFrame
        Indexed by 'tp'; must provide 'mean', 'min' and 'max' precision columns.
    line_width : float
        Width of the mean-precision line.
    hide_minor_ticks : bool
        Suppress minor ticks/labels on the logarithmic x-axis.
    """
    config = dload("config")
    plot_config = config["plotting"]

    # Wider figure so the external legend fits without squashing the axes.
    fig, ax = plt.subplots(figsize=(6, 4))
    plt.subplots_adjust(right=0.7)

    ax.set_xscale("log")

    if hide_minor_ticks:
        # Clear minor ticks on the log axis for a cleaner look.
        ax.xaxis.set_minor_locator(NullLocator())
        ax.xaxis.set_minor_formatter(NullFormatter())

    # Drop very low TP counts, consistent with plot_precision_recall_curve.
    curve = agg_df[agg_df.index > 10]

    # Shaded min-max envelope behind the mean curve.
    ax.fill_between(
        curve.index, curve['min'], curve['max'],
        color='gray', alpha=0.3, label='Range (Min-Max)',
    )
    ax.plot(
        curve.index, curve['mean'],
        c="black", label="Mean Precision", linewidth=line_width, alpha=0.9,
    )

    ax.set(title="",
           xlabel="Number of True Positives (TP)",
           ylabel="Precision")
    ax.legend(loc="upper left", bbox_to_anchor=(1.05, 1), frameon=False)
    ax.set_ylim(0, 1)

    # Nature style: no grid, open top/right spines.
    ax.grid(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    if plot_config["save_plot"]:
        output_type = plot_config["output_type"]
        output_path = Path(config["output_folder"]) / f"aggregated_precision_recall_curve.{output_type}"
        fig.savefig(output_path, bbox_inches="tight", format=output_type)

    if plot_config.get("show_plot", True):
        plt.show()
    plt.close(fig)
171
+
172
def plot_iqr_pra(agg_df, line_width=2.0, hide_minor_ticks=True):
    """
    Plot an aggregated Precision-Recall curve: mean line plus IQR (25-75%) band.

    Parameters
    ----------
    agg_df : pd.DataFrame
        Indexed by 'tp'; must provide 'mean', '25%' and '75%' precision columns.
    line_width : float
        Width of the mean-precision line.
    hide_minor_ticks : bool
        Suppress minor ticks/labels on the logarithmic x-axis.
    """
    config = dload("config")
    plot_config = config["plotting"]

    # Wider figure so the external legend fits without squashing the axes.
    fig, ax = plt.subplots(figsize=(6, 4))
    plt.subplots_adjust(right=0.7)

    ax.set_xscale("log")

    if hide_minor_ticks:
        # Clear minor ticks on the log axis for a cleaner look.
        ax.xaxis.set_minor_locator(NullLocator())
        ax.xaxis.set_minor_formatter(NullFormatter())

    # Drop very low TP counts, consistent with plot_precision_recall_curve.
    curve = agg_df[agg_df.index > 10]

    # Shaded interquartile band behind the mean curve.
    ax.fill_between(
        curve.index, curve['25%'], curve['75%'],
        color='gray', alpha=0.3, label='IQR (25-75%)',
    )
    ax.plot(
        curve.index, curve['mean'],
        c="black", label="Mean Precision", linewidth=line_width, alpha=0.9,
    )

    ax.set(title="Precision-Recall (IQR)",
           xlabel="Number of True Positives (TP)",
           ylabel="Precision")
    ax.legend(loc="upper left", bbox_to_anchor=(1.05, 1), frameon=False)
    ax.set_ylim(0, 1)

    # Nature style: no grid, open top/right spines.
    ax.grid(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    if plot_config["save_plot"]:
        output_type = plot_config["output_type"]
        output_path = Path(config["output_folder"]) / f"aggregated_iqr_precision_recall_curve.{output_type}"
        fig.savefig(output_path, bbox_inches="tight", format=output_type)

    if plot_config.get("show_plot", True):
        plt.show()
    plt.close(fig)
226
+
227
def plot_all_runs_pra(pra_list, mean_df=None, line_width=2.0, hide_minor_ticks=True):
    """
    Plot all individual Precision-Recall curves faintly, with an optional mean line.

    Parameters
    ----------
    pra_list : list
        List of DataFrames (each with 'tp' and 'precision' columns) OR list
        of Series whose index is the TP count and whose values are precision.
    mean_df : pd.DataFrame, optional
        DataFrame with a 'mean' column indexed by tp.
    line_width : float
        Width of the mean line (individual runs are drawn thin and faint).
    hide_minor_ticks : bool
        Suppress minor ticks/labels on the logarithmic x-axis.
    """
    config = dload("config")
    plot_config = config["plotting"]

    fig, ax = plt.subplots(figsize=(6, 4))
    plt.subplots_adjust(right=0.7)

    ax.set_xscale("log")

    if hide_minor_ticks:
        ax.xaxis.set_minor_locator(NullLocator())
        ax.xaxis.set_minor_formatter(NullFormatter())

    # Plot individual runs as faint gray lines.
    for i, run in enumerate(pra_list):
        # Bug fix: a Series has no .columns attribute, so probe for it first
        # instead of unconditionally evaluating `'tp' in run.columns`
        # (which raised AttributeError for the documented Series input).
        has_columns = hasattr(run, "columns")

        # Filter out very low TP counts, consistent with the other PR plots.
        if has_columns and 'tp' in run.columns:
            run_filtered = run[run['tp'] > 10]
            x = run_filtered['tp']
        else:
            run_filtered = run[run.index > 10]
            x = run_filtered.index

        if has_columns and 'precision' in run_filtered.columns:
            y = run_filtered['precision']
        else:
            y = run_filtered.values

        # Only label the first line to avoid cluttering the legend.
        lbl = "Individual Runs" if i == 0 else None
        ax.plot(x, y, c="gray", linewidth=0.5, alpha=0.3, label=lbl)

    # Overlay the mean curve if provided.
    if mean_df is not None:
        mean_df = mean_df[mean_df.index > 10]
        ax.plot(mean_df.index, mean_df['mean'], c="black", label="Mean Precision", linewidth=line_width, alpha=0.9)

    ax.set(title="Precision-Recall (All Runs)",
           xlabel="Number of True Positives (TP)",
           ylabel="Precision")
    ax.legend(loc="upper left", bbox_to_anchor=(1.05, 1), frameon=False)
    ax.set_ylim(0, 1)

    # Nature style: no grid, open top/right spines.
    ax.grid(False)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    if plot_config["save_plot"]:
        output_type = plot_config["output_type"]
        output_path = Path(config["output_folder"]) / f"aggregated_all_runs_precision_recall_curve.{output_type}"
        fig.savefig(output_path, bbox_inches="tight", format=output_type)

    if plot_config.get("show_plot", True):
        plt.show()
    plt.close(fig)
281
+
117
282
  def plot_percomplex_scatter(n_top=10, sig_color='#B71A2A', nonsig_color='#DBDDDD', label_color='black', border_color='black', border_width=1.0, show_text_background=True):
118
283
  config = dload("config")
119
284
  plot_config = config["plotting"]
@@ -1056,13 +1221,110 @@ def plot_auc_scores():
1056
1221
  return pra_dict
1057
1222
 
1058
1223
 
1224
def plot_mpr_complexes_auc_scores(filter_key: str = "all"):
    """Bar plot of per-dataset AUC scores for the mPR complexes curve (Fig 1F-style).

    Requires `mpr_prepare()` to have been run for each dataset (it stores
    the 'mpr_complexes_auc' category).

    Parameters
    ----------
    filter_key : str
        One of: "all", "no_mtRibo_ETCI", "no_small_highAUPRC".

    Returns
    -------
    pd.Series
        AUC values indexed by dataset name (sorted descending).
    """
    config = dload("config")
    plot_config = config["plotting"]
    mpr_auc_dict = dload("mpr_complexes_auc")
    input_colors = dload("input", "colors")

    # Per-dataset color overrides are keyed by sanitized dataset names.
    if input_colors:
        input_colors = {_sanitize(key): val for key, val in input_colors.items()}

    if not isinstance(mpr_auc_dict, dict) or not mpr_auc_dict:
        log.warning(
            "No mPR complexes AUC data found. Run mpr_prepare() first (it stores 'mpr_complexes_auc')."
        )
        return pd.Series(dtype=float)

    # Collect one AUC per dataset for the requested filter, skipping
    # malformed entries and values that cannot be cast to float.
    scores = {}
    for ds_name, per_filter in mpr_auc_dict.items():
        if not isinstance(per_filter, dict):
            continue
        raw = per_filter.get(filter_key)
        if raw is None:
            continue
        try:
            scores[ds_name] = float(raw)
        except (TypeError, ValueError):
            continue

    if not scores:
        log.warning(
            f"No mPR complexes AUC scores found for filter '{filter_key}'. Available filters: {list(FILTER_STYLES.keys())}"
        )
        return pd.Series(dtype=float)

    series = pd.Series(scores).sort_values(ascending=False)
    dataset_names = list(series.index)
    auc_values = list(series.values)

    fig, ax = plt.subplots()

    # Color logic mirrors the other bar plots: configured colormap with a
    # "tab10" fallback; user-supplied colors override per dataset.
    cmap_name = config.get("color_map", "tab10")
    try:
        cmap = get_cmap(cmap_name)
    except ValueError:
        cmap = get_cmap("tab10")

    n = len(dataset_names)
    if n <= 10 and cmap_name == "tab10":
        fallback_colors = [cmap(i) for i in range(n)]
    else:
        fallback_colors = [cmap(float(i) / max(n - 1, 1)) for i in range(n)]

    bar_colors = []
    for i, ds_name in enumerate(dataset_names):
        override = input_colors.get(ds_name) if input_colors else None
        bar_colors.append(override if override is not None else fallback_colors[i])

    ax.bar(dataset_names, auc_values, color=bar_colors, edgecolor="black")

    # Headroom above the tallest finite bar (0 if none are finite).
    finite_max = max([v for v in auc_values if np.isfinite(v)], default=0.0)
    ax.set_ylim(0, finite_max + 0.01)
    ax.set_ylabel("mPR complexes AUC")
    plt.xticks(rotation=45, ha="right")

    # Styling consistent with the other plots.
    ax.grid(visible=False, which="both", axis="both")
    ax.set_axisbelow(False)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    if plot_config.get("save_plot", False):
        output_type = plot_config.get("output_type", "pdf")
        output_folder = Path(config["output_folder"])
        output_folder.mkdir(parents=True, exist_ok=True)
        output_path = output_folder / f"mpr_complexes_auc_{filter_key}.{output_type}"
        plt.savefig(output_path, bbox_inches="tight", format=output_type)

    if plot_config.get("show_plot", True):
        plt.show()

    plt.close(fig)
    return series
1060
1323
 
1061
1324
  # -----------------------------------------------------------------------------
1062
1325
  # mPR plots (Fig. 1E and Fig. 1F)
1063
1326
  # -----------------------------------------------------------------------------
1064
1327
 
1065
-
1066
1328
  def plot_mpr_complexes(name, ax=None, save=True, outname=None):
1067
1329
  """
1068
1330
  Fig. 1F-style module-level PR:
@@ -1213,7 +1475,6 @@ def plot_mpr_tp(name, ax=None, save=True, outname=None):
1213
1475
 
1214
1476
  return ax
1215
1477
 
1216
-
1217
1478
  """
1218
1479
  Multi-dataset mPR plotting functions.
1219
1480
 
@@ -1234,7 +1495,6 @@ from pathlib import Path
1234
1495
  from .utils import dload
1235
1496
  from .logging_config import log
1236
1497
 
1237
-
1238
1498
  # Default color palette (colorblind-friendly)
1239
1499
  DEFAULT_COLORS = [
1240
1500
  "#4E79A7", # blue
@@ -1257,6 +1517,21 @@ FILTER_STYLES = {
1257
1517
  }
1258
1518
 
1259
1519
 
1520
+ def _normalize_show_filters(show_filters):
1521
+ """Normalize show_filters to an ordered tuple of filter keys.
1522
+
1523
+ Common footgun: passing a single string (e.g. "no_mtRibo_ETCI") is iterable,
1524
+ which would otherwise be treated as a sequence of characters.
1525
+ """
1526
+ if show_filters is None:
1527
+ return tuple(FILTER_STYLES.keys())
1528
+ if isinstance(show_filters, str):
1529
+ return (show_filters,)
1530
+ try:
1531
+ return tuple(show_filters)
1532
+ except TypeError:
1533
+ return (show_filters,)
1534
+
1260
1535
  def plot_mpr_tp_multi(
1261
1536
  dataset_names=None,
1262
1537
  colors=None,
@@ -1297,6 +1572,8 @@ def plot_mpr_tp_multi(
1297
1572
  config = dload("config")
1298
1573
  plot_config = config["plotting"]
1299
1574
  input_colors = dload("input", "colors")
1575
+
1576
+ show_filters = _normalize_show_filters(show_filters)
1300
1577
 
1301
1578
  # Sanitize color keys
1302
1579
  if input_colors:
@@ -1421,14 +1698,21 @@ def plot_mpr_tp_multi(
1421
1698
 
1422
1699
  # Save
1423
1700
  if save:
1701
+ output_type = plot_config.get("output_type", "pdf")
1424
1702
  if outname is None:
1425
- outname = "mpr_tp_multi.pdf"
1703
+ outname = f"mpr_tp_multi.{output_type}"
1704
+
1705
+ # Check if outname is just a filename or a full path
1706
+ outpath = Path(outname)
1707
+ if len(outpath.parts) == 1:
1708
+ # Just a filename, prepend configured output folder
1709
+ outpath = Path(config["output_folder"]) / outname
1710
+
1426
1711
  fig.tight_layout()
1427
- fig.savefig(outname, bbox_inches="tight")
1712
+ fig.savefig(outpath, bbox_inches="tight", format=output_type)
1428
1713
 
1429
1714
  return ax
1430
1715
 
1431
-
1432
1716
  def plot_mpr_complexes_multi(
1433
1717
  dataset_names=None,
1434
1718
  colors=None,
@@ -1437,6 +1721,8 @@ def plot_mpr_complexes_multi(
1437
1721
  outname=None,
1438
1722
  linewidth=1.8,
1439
1723
  show_filters=("all", "no_mtRibo_ETCI", "no_small_highAUPRC"),
1724
+ show_markers="auto",
1725
+ marker_size=20,
1440
1726
  ):
1441
1727
  """
1442
1728
  Plot module-level PR (#complexes vs precision) for multiple datasets.
@@ -1461,6 +1747,11 @@ def plot_mpr_complexes_multi(
1461
1747
  Line width for all curves
1462
1748
  show_filters : tuple of str
1463
1749
  Which filters to show. Default is all three.
1750
+ show_markers : bool or "auto"
1751
+ If True, draw markers on curves to make short curves visible.
1752
+ If "auto" (default), markers are drawn only for curves with <= 10 points.
1753
+ marker_size : int
1754
+ Scatter marker size (points^2) when markers are shown.
1464
1755
 
1465
1756
  Returns
1466
1757
  -------
@@ -1469,6 +1760,8 @@ def plot_mpr_complexes_multi(
1469
1760
  config = dload("config")
1470
1761
  plot_config = config["plotting"]
1471
1762
  input_colors = dload("input", "colors")
1763
+
1764
+ show_filters = _normalize_show_filters(show_filters)
1472
1765
 
1473
1766
  # Sanitize color keys
1474
1767
  if input_colors:
@@ -1545,13 +1838,26 @@ def plot_mpr_complexes_multi(
1545
1838
  prec_plot = precision_cutoffs[mask]
1546
1839
 
1547
1840
  style = FILTER_STYLES.get(filter_key, {})
1548
- ax.plot(
1549
- cov_plot,
1550
- prec_plot,
1551
- color=color,
1552
- linestyle=style.get("linestyle", "-"),
1553
- linewidth=linewidth,
1554
- )
1841
+
1842
+ # Decide marker visibility
1843
+ if show_markers == "auto":
1844
+ use_markers = (cov_plot.size <= 10)
1845
+ else:
1846
+ use_markers = bool(show_markers)
1847
+
1848
+ if cov_plot.size == 1:
1849
+ # A single point is effectively invisible as a line; draw a marker.
1850
+ ax.scatter(cov_plot, prec_plot, color=color, s=marker_size, zorder=3)
1851
+ else:
1852
+ ax.plot(
1853
+ cov_plot,
1854
+ prec_plot,
1855
+ color=color,
1856
+ linestyle=style.get("linestyle", "-"),
1857
+ linewidth=linewidth,
1858
+ marker=("o" if use_markers else None),
1859
+ markersize=(3 if use_markers else None),
1860
+ )
1555
1861
 
1556
1862
  # Configure axes
1557
1863
  ax.set_xscale("log")
@@ -1575,18 +1881,26 @@ def plot_mpr_complexes_multi(
1575
1881
 
1576
1882
  # Save
1577
1883
  if save:
1884
+ output_type = plot_config.get("output_type", "pdf")
1578
1885
  if outname is None:
1579
- outname = "mpr_complexes_multi.pdf"
1886
+ outname = f"mpr_complexes_multi.{output_type}"
1887
+
1888
+ # Check if outname is just a filename or a full path
1889
+ outpath = Path(outname)
1890
+ if len(outpath.parts) == 1:
1891
+ # Just a filename, prepend configured output folder
1892
+ outpath = Path(config["output_folder"]) / outname
1893
+
1580
1894
  fig.tight_layout()
1581
- fig.savefig(outname, bbox_inches="tight")
1895
+ fig.savefig(outpath, bbox_inches="tight", format=output_type)
1582
1896
 
1583
1897
  return ax
1584
1898
 
1585
-
1586
1899
  def _add_vertical_legend(ax, dataset_names, colors, show_filters, linewidth):
1587
1900
  """
1588
1901
  Add vertically stacked legends: Dataset on top, Filter below.
1589
1902
  """
1903
+ show_filters = _normalize_show_filters(show_filters)
1590
1904
  # Legend 1: Datasets (colors) - solid lines
1591
1905
  dataset_handles = []
1592
1906
  for i, name in enumerate(dataset_names):
@@ -1632,11 +1946,11 @@ def _add_vertical_legend(ax, dataset_names, colors, show_filters, linewidth):
1632
1946
  bbox_to_anchor=(1.05, 1.0 - len(dataset_names) * 0.06 - 0.1)
1633
1947
  )
1634
1948
 
1635
-
1636
1949
  def _add_dual_legend(ax, dataset_names, colors, show_filters, linewidth):
1637
1950
  """
1638
1951
  Add two legends: one for datasets (colors), one for filters (line styles).
1639
1952
  """
1953
+ show_filters = _normalize_show_filters(show_filters)
1640
1954
  # Legend 1: Datasets (colors) - solid lines
1641
1955
  dataset_handles = []
1642
1956
  for i, name in enumerate(dataset_names):
@@ -1682,7 +1996,6 @@ def _add_dual_legend(ax, dataset_names, colors, show_filters, linewidth):
1682
1996
  title_fontsize=8,
1683
1997
  )
1684
1998
 
1685
-
1686
1999
  # ============================================================================
1687
2000
  # Single dataset functions are now obsolete
1688
2001
  # ============================================================================
@@ -13,28 +13,36 @@ from pathlib import Path
13
13
 
14
14
 
15
15
def return_package_dir():
    """Return the directory of the installed ``pythonflex`` package as a string.

    Handles three cases:
      * editable (``pip install -e``) installs, located via the PEP 610
        ``direct_url.json`` metadata (assumes a ``src/`` project layout);
      * regular installs, located via ``importlib.resources.files``;
      * anything else (metadata missing or unreadable), falling back to
        the directory containing this module.
    """
    try:
        # Look up the installed distribution metadata.
        dist = distribution('pythonflex')

        # PEP 610: direct_url.json exists only for direct (e.g. editable) installs.
        try:
            direct_url_text = dist.read_text('direct_url.json')
        except FileNotFoundError:
            direct_url_text = None

        if direct_url_text:
            direct_url = json.loads(direct_url_text)
            if direct_url.get('dir_info', {}).get('editable'):
                # Editable install: derive the project root from the file:// URL.
                project_url = direct_url['url']
                # Bug fix: the previous removeprefix('file:///') also stripped
                # the leading '/' of POSIX absolute paths (file:///home/x ->
                # 'home/x', a relative path). Strip only 'file://', then drop
                # the spurious slash before a Windows drive letter.
                project_root = project_url.removeprefix('file://')
                if os.name == 'nt':
                    # file:///C:/... -> '/C:/...' -> 'C:/...'
                    project_root = project_root.lstrip('/')
                project_root = project_root.replace('/', os.sep)
                # Assuming src layout: project_root/src/pythonflex
                package_dir = os.path.join(project_root, 'src', 'pythonflex')
            else:
                # Direct but non-editable install.
                package_dir = str(files('pythonflex'))
        else:
            # No direct_url, assume non-editable
            package_dir = str(files('pythonflex'))

    except Exception:  # PackageNotFoundError or other issues
        # Fallback to local directory relative to this file
        # precise location: src/pythonflex/preprocessing.py -> package dir is parent
        package_dir = str(Path(__file__).parent)

    return package_dir
40
48
 
@@ -190,7 +198,6 @@ def load_gold_standard():
190
198
  "PATHWAY": "gold_standard/PATHWAY.parquet"
191
199
  }
192
200
 
193
-
194
201
  if gold_standard_source in gold_standard_files:
195
202
  # Load predefined gold standard from package resources
196
203
  filename = gold_standard_files[gold_standard_source]
@@ -1,8 +1,14 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pythonflex
3
- Version: 0.3.1
3
+ Version: 0.3.3
4
4
  Summary: pythonFLEX is a benchmarking toolkit for evaluating CRISPR screen results against biological gold standards. The toolkit computes gene-level and complex-level performance metrics, helping researchers systematically assess the biological relevance and resolution of their CRISPR screening data.
5
5
  Author-email: Yasir Demirtaş <tyasird@hotmail.com>
6
+ Classifier: License :: OSI Approved :: MIT License
7
+ Classifier: Operating System :: OS Independent
8
+ Classifier: Programming Language :: Python :: 3
9
+ Classifier: Programming Language :: Python :: 3.9
10
+ Classifier: Programming Language :: Python :: 3.10
11
+ Classifier: Programming Language :: Python :: 3.11
6
12
  Requires-Python: >=3.9
7
13
  Requires-Dist: adjusttext
8
14
  Requires-Dist: art
@@ -1,8 +1,8 @@
1
- pythonflex/__init__.py,sha256=UKu_QgAZsWgERWedUA7drG4kIQ8zKJLSLc8OYHHNJSM,1570
2
- pythonflex/analysis.py,sha256=n8gIidtRk9_DxoO6Z4g1MSH0rYsPfQAKdzPtEguZqQY,75067
1
+ pythonflex/__init__.py,sha256=MoDbdVhclK_PF_u9vzN4ntWX6hTRAKfvkTiDisIci5o,1748
2
+ pythonflex/analysis.py,sha256=gKJ4cYA_TWYe521nAXizMqChd36A90TWfDf595fw_0M,77760
3
3
  pythonflex/logging_config.py,sha256=iqRKK18zvtfV_-bYHWrXtSZywiUtYxoHkw0ZnVORQBQ,2015
4
- pythonflex/plotting.py,sha256=7S8IibsyEaK26YKv6FXShMix_15vCUQnZIxD7VyJwmQ,64036
5
- pythonflex/preprocessing.py,sha256=5cV8zNbrgCslidrMpMjGr-7HzTZgVligWVEsUQu3Stw,10999
4
+ pythonflex/plotting.py,sha256=AOzgyhJX5bPMoGs2ih2zbA30Dm-OoWpk8MNBC-9OQ94,75981
5
+ pythonflex/preprocessing.py,sha256=jIeyB2SPPac-svtjB-zGe3vIyOSVB-SxYIFyNFFiCsY,11440
6
6
  pythonflex/utils.py,sha256=7toGnKbA_TKBtHz1HLk7ckWM0bjuFw_Byhp6ZUJaNs4,3694
7
7
  pythonflex/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
8
  pythonflex/data/dataset/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -16,13 +16,9 @@ pythonflex/data/gold_standard/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5N
16
16
  pythonflex/data/gold_standard/corum.csv,sha256=2rZeyr2Ghm7f-gFxCZnhPtxI2jxRoiZMUEH2EJwAgsI,208889
17
17
  pythonflex/data/gold_standard/gobp.csv,sha256=TO9yfx9mO8WkXvWfSB-pFId9T8xYfqdZpshAXC0Fyj8,1739167
18
18
  pythonflex/data/gold_standard/pathway.csv,sha256=J3HKVLUZ_Oxucmn_14ieYp3Wr2lcKtp0nIl4_8_K2Yc,489424
19
- pythonflex/examples/basic_usage.py,sha256=LniAq5Al5meNfcqlniYIRpOYRTce0BvGhZpw4P6_djc,1994
20
- pythonflex/examples/comparison.py,sha256=Gaakp4xk8EWd_Sdmm9I9QHxk5DyQwpLUfHlQKn1l7WU,2084
21
- pythonflex/examples/dataset_filtering.py,sha256=7PCKCZWYLZUn3XAStGTCaVGbY9F0gqjT0ote_G6WPho,1238
22
- pythonflex/examples/diag.py,sha256=9sKfMTn8_em6IJOAX2hE1DRJs7-qrRuWyXWfQUwSK5c,3815
23
- pythonflex/examples/test.py,sha256=B8-JE5AU7be5loSr6Qv2rOviXXe1NRCYpaEGfGjaow0,2388
24
- pythonflex/examples/test2.py,sha256=nbjd3A9R2R-Cf4P9jdgclysoZbQVC2Cmzt4Npbsxw6w,184
25
- pythonflex-0.3.1.dist-info/METADATA,sha256=l6hqrsmT0tRkaDgg4g5KwpRC-tG-Yj7e8sSWuE6uD54,3928
26
- pythonflex-0.3.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
27
- pythonflex-0.3.1.dist-info/entry_points.txt,sha256=37liK1baI_CRVDivpjsn8JDClL9_YeTTuSMAZ3Ty7oE,47
28
- pythonflex-0.3.1.dist-info/RECORD,,
19
+ pythonflex/examples/basic_usage.py,sha256=dizQXYPJWjW7-2d2G29a8qYCBRIsKhrvxOxyXtudK30,2265
20
+ pythonflex/examples/manuscript.py,sha256=V28vIBFmrxGsE_YhvouRFiLKWC9CorbOx9Ed3B2L8bQ,2810
21
+ pythonflex-0.3.3.dist-info/METADATA,sha256=l5CnF5hX_qgnhMEHnTQbK9ZrBJrIRKzbYeCVCC7Wv1M,4226
22
+ pythonflex-0.3.3.dist-info/WHEEL,sha256=QccIxa26bgl1E6uMy58deGWi-0aeIkkangHcxk2kWfw,87
23
+ pythonflex-0.3.3.dist-info/entry_points.txt,sha256=37liK1baI_CRVDivpjsn8JDClL9_YeTTuSMAZ3Ty7oE,47
24
+ pythonflex-0.3.3.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: hatchling 1.28.0
2
+ Generator: hatchling 1.29.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
@@ -1,78 +0,0 @@
1
- """
2
- Basic usage example of the pythonFLEX package.
3
- Demonstrates initialization, data loading, analysis, and plotting.
4
- """
5
- #%%
6
- import pythonflex as flex
7
- import pandas as pd
8
-
9
- depmap = pd.read_csv('../../../../_datasets/depmap/25Q2/gene_effect.csv', index_col=0)
10
- white = pd.read_csv('../../../../_datasets/depmap/25Q2/25Q2_chronos_whitened_PCA.csv', index_col=0).T
11
-
12
- inputs = {
13
- "25Q2": {
14
- "path": depmap,
15
- "sort": "high",
16
- "color": "#fff000" # Black
17
- },
18
-
19
- "25Q2 white": {
20
- "path": white,
21
- "sort": "high",
22
- "color": "#ff0000" # Orange
23
- },
24
- }
25
-
26
- default_config = {
27
- "min_genes_in_complex": 0,
28
- "min_genes_per_complex_analysis": 3,
29
- "output_folder": "CORUM_25Q2_comparison2",
30
- "gold_standard": "CORUM",
31
- "color_map": "BuGn",
32
- "jaccard": False,
33
- "use_common_genes": False, # Set to False for individual dataset-gold standard intersections
34
- "plotting": {
35
- "save_plot": True,
36
- "output_type": "png",
37
- },
38
- "preprocessing": {
39
- "fill_na": True,
40
- "normalize": False,
41
- },
42
- "corr_function": "numpy",
43
- "logging": {
44
- "visible_levels": ["DONE"] # "PROGRESS", "STARTED", ,"INFO","WARNING"
45
- }
46
- }
47
-
48
- # Initialize logger, config, and output folder
49
- flex.initialize(default_config)
50
-
51
- # Load datasets and gold standard terms
52
- data, _ = flex.load_datasets(inputs)
53
- terms, genes_in_terms = flex.load_gold_standard()
54
-
55
- # Run analysis
56
- for name, dataset in data.items():
57
- pra = flex.pra(name, dataset, is_corr=False)
58
- fpc = flex.pra_percomplex(name, dataset, is_corr=False)
59
- flex.mpr_prepare(name) # Add this line
60
- cc = flex.complex_contributions(name)
61
-
62
-
63
-
64
-
65
- #%%
66
- # Generate plots
67
- flex.plot_precision_recall_curve()
68
- flex.plot_auc_scores()
69
- flex.plot_significant_complexes()
70
- flex.plot_percomplex_scatter(n_top=20)
71
- flex.plot_percomplex_scatter_bysize()
72
- flex.plot_complex_contributions()
73
- flex.plot_mpr_tp_multi()
74
- flex.plot_mpr_complexes_multi()
75
- # Save results to CSV
76
- # flex.save_results_to_csv()
77
-
78
- # %%
@@ -1,42 +0,0 @@
1
-
2
- # %%
3
- import pandas as pd
4
-
5
- df = pd.read_csv("../../../../datasets/depmap/24Q4/CRISPRGeneEffect.csv",index_col=0)
6
- model = pd.read_csv("../../../../datasets/depmap/24Q4/Model.csv",index_col=0)
7
-
8
- df.columns = df.columns.str.split(" \\(").str[0]
9
- df = df.T
10
-
11
- #%%
12
-
13
- # %%
14
- # get ModelID of selected disease for example OncotreePrimaryDisease==Melanoma
15
- melanoma = model[model.OncotreePrimaryDisease=="Melanoma"].index.unique().values
16
- liver = model[model.OncotreeLineage=="Liver"].index.unique().values
17
- neuroblastoma = model[model.OncotreePrimaryDisease=="Neuroblastoma"].index.unique().values
18
-
19
- # %%
20
- # mel.index is model ids, filter that ids in the columns of df
21
- mel_df = df.loc[:,df.columns.isin(melanoma)]
22
- liver_df = df.loc[:,df.columns.isin(liver)]
23
- neuro_df = df.loc[:,df.columns.isin(neuroblastoma)]
24
-
25
-
26
- # %%
27
- mel_df.to_csv("melanoma.csv")
28
- liver_df.to_csv("liver.csv")
29
- neuro_df.to_csv("neuroblastoma.csv")
30
- df.to_csv("depmap_geneeffect_all_cellines.csv")
31
-
32
-
33
- # %%
34
- import pandas as pd
35
- df = pd.read_csv('../../../../_datasets/depmap/19Q2/Achilles_gene_effect.csv', index_col=0)
36
- df.columns = df.columns.str.split(" \\(").str[0]
37
- df = df.T
38
-
39
- # %%
40
- df.to_csv("../../../../_datasets/depmap/19Q2/gene_effect.csv")
41
-
42
- # %%
@@ -1,106 +0,0 @@
1
- #%%
2
- # Run this in Jupyter to test the two approaches
3
-
4
- import numpy as np
5
- import pandas as pd
6
- from pythonflex.utils import dload
7
-
8
- dataset_name = "[CORUM] 19Q2"
9
-
10
- pra = dload("pra", dataset_name)
11
- mpr = dload("mpr", dataset_name)
12
-
13
- filter_ids = set(mpr["filters"]["no_mtRibo_ETCI"])
14
- print(f"Filter IDs: {filter_ids}")
15
-
16
- cid_col = "complex_id" if "complex_id" in pra.columns else "complex_ids"
17
-
18
- # Sort by score descending
19
- pra_sorted = pra.sort_values("score", ascending=False).reset_index(drop=True)
20
-
21
- def has_filter_id(cids, filter_ids):
22
- """Check if any complex ID is in filter_ids"""
23
- if isinstance(cids, (np.ndarray, list)):
24
- ids = [int(x) for x in cids if pd.notnull(x)]
25
- else:
26
- return False
27
- return any(c in filter_ids for c in ids)
28
-
29
- # Mark which pairs should be filtered
30
- pra_sorted["should_filter"] = pra_sorted[cid_col].apply(lambda x: has_filter_id(x, filter_ids))
31
-
32
- print(f"\nTotal pairs: {len(pra_sorted)}")
33
- print(f"Pairs to filter: {pra_sorted['should_filter'].sum()}")
34
- print(f"TPs to filter: {(pra_sorted['should_filter'] & (pra_sorted['prediction']==1)).sum()}")
35
-
36
- # APPROACH 1: Mark as negative (what your Python does)
37
- # Keep all rows, but filtered TPs become FPs
38
- print("\n" + "=" * 70)
39
- print("APPROACH 1: Mark filtered TPs as negatives (keep rows)")
40
- print("=" * 70)
41
-
42
- df1 = pra_sorted.copy()
43
- df1["true_filtered"] = df1["prediction"].copy()
44
- df1.loc[df1["should_filter"] & (df1["prediction"]==1), "true_filtered"] = 0
45
-
46
- tp_cum_1 = df1["true_filtered"].cumsum()
47
- prec_1 = tp_cum_1 / (np.arange(len(df1)) + 1)
48
-
49
- # Show precision at key TP counts
50
- print("\nPrecision at key TP counts:")
51
- for target_tp in [10, 50, 100, 500, 1000]:
52
- if target_tp <= tp_cum_1.max():
53
- idx = np.where(tp_cum_1 >= target_tp)[0][0]
54
- print(f" TP={target_tp}: precision={prec_1.iloc[idx]:.3f} (at rank {idx+1})")
55
-
56
- # APPROACH 2: Remove rows entirely (what R does with replace=FALSE)
57
- print("\n" + "=" * 70)
58
- print("APPROACH 2: Remove filtered rows entirely")
59
- print("=" * 70)
60
-
61
- df2 = pra_sorted[~pra_sorted["should_filter"]].copy().reset_index(drop=True)
62
-
63
- tp_cum_2 = df2["prediction"].cumsum()
64
- prec_2 = tp_cum_2 / (np.arange(len(df2)) + 1)
65
-
66
- print(f"\nRows remaining after removal: {len(df2)}")
67
- print(f"TPs remaining: {df2['prediction'].sum()}")
68
-
69
- print("\nPrecision at key TP counts:")
70
- for target_tp in [10, 50, 100, 500, 1000]:
71
- if target_tp <= tp_cum_2.max():
72
- idx = np.where(tp_cum_2 >= target_tp)[0][0]
73
- print(f" TP={target_tp}: precision={prec_2.iloc[idx]:.3f} (at rank {idx+1})")
74
-
75
- # APPROACH 3: Only remove filtered POSITIVE pairs, keep negatives
76
- print("\n" + "=" * 70)
77
- print("APPROACH 3: Remove only filtered TPs (keep filtered negatives)")
78
- print("=" * 70)
79
-
80
- # This removes TP rows that contain filter IDs, but keeps negative rows
81
- remove_mask = pra_sorted["should_filter"] & (pra_sorted["prediction"] == 1)
82
- df3 = pra_sorted[~remove_mask].copy().reset_index(drop=True)
83
-
84
- tp_cum_3 = df3["prediction"].cumsum()
85
- prec_3 = tp_cum_3 / (np.arange(len(df3)) + 1)
86
-
87
- print(f"\nRows remaining: {len(df3)}")
88
- print(f"TPs remaining: {df3['prediction'].sum()}")
89
-
90
- print("\nPrecision at key TP counts:")
91
- for target_tp in [10, 50, 100, 500, 1000]:
92
- if target_tp <= tp_cum_3.max():
93
- idx = np.where(tp_cum_3 >= target_tp)[0][0]
94
- print(f" TP={target_tp}: precision={prec_3.iloc[idx]:.3f} (at rank {idx+1})")
95
-
96
- print("\n" + "=" * 70)
97
- print("COMPARISON")
98
- print("=" * 70)
99
- print("""
100
- Approach 1 (mark as negative): Filtered TPs become FPs, lowering precision
101
- Approach 2 (remove all filtered): Both TPs and negatives removed
102
- Approach 3 (remove only TPs): Only filtered TPs removed, negatives kept
103
-
104
- The R code uses Approach 3 (remove positive pairs that contain the filter ID).
105
- """)
106
- # %%
@@ -1,104 +0,0 @@
1
- #%%
2
- import pythonflex as flex
3
- import os
4
-
5
- # # Define specific cell line types you're interested in
6
- DATA_DIR = "C:/Users/yd/Desktop/projects/_datasets/depmap/25Q2/subset/"
7
-
8
- # Specific cell lines of interest with "_cell_lines" suffix removed
9
- cell_line_files = [
10
- "soft_tissue_cell_lines.csv",
11
- "skin_cell_lines.csv",
12
- # "lung_cell_lines.csv",
13
- # "head_and_neck_cell_lines.csv",
14
- # "esophagus_stomach_cell_lines.csv",
15
- ]
16
-
17
- inputs = {}
18
-
19
- # Create inputs dict with shortened names (removing "_cell_lines" suffix)
20
- for filename in cell_line_files:
21
- # Remove .csv extension and _cell_lines suffix
22
- key = filename.replace("_cell_lines.csv", "")
23
- full_path = os.path.join(DATA_DIR, filename)
24
-
25
- inputs[key] = {
26
- "path": full_path,
27
- "sort": "high"
28
- }
29
-
30
- inputs['depmap'] = {
31
- "path": "C:/Users/yd/Desktop/projects/_datasets/depmap/25Q2/gene_effect.csv",
32
- "sort": "high"
33
- }
34
-
35
- # Print the resulting inputs dictionary
36
- print("Configured inputs:")
37
- for key, value in inputs.items():
38
- print(f" {key}: {value['path']}")
39
-
40
-
41
-
42
- default_config = {
43
- "min_genes_in_complex": 2,
44
- "min_genes_per_complex_analysis": 2,
45
- "output_folder": "25q2_min_genes_2",
46
- "gold_standard": "CORUM",
47
- "color_map": "RdYlBu",
48
- "jaccard": True,
49
- "plotting": {
50
- "save_plot": True,
51
- "output_type": "pdf",
52
- },
53
- "preprocessing": {
54
- "fill_na": True,
55
- "normalize": False,
56
- },
57
- "corr_function": "numpy",
58
- "logging": {
59
- "visible_levels": ["DONE","STARTED"] # "PROGRESS", "STARTED", ,"INFO","WARNING"
60
- }
61
- }
62
-
63
- # Initialize logger, config, and output folder
64
- flex.initialize(default_config)
65
-
66
- # Load datasets and gold standard terms
67
- data, _ = flex.load_datasets(inputs)
68
- terms, genes_in_terms = flex.load_gold_standard()
69
-
70
-
71
- #%%
72
- # Run analysis
73
- for name, dataset in data.items():
74
- pra = flex.pra(name, dataset, is_corr=False)
75
- fpc = flex.pra_percomplex(name, dataset, is_corr=False)
76
- cc = flex.complex_contributions(name)
77
-
78
-
79
-
80
- #%%
81
- # Generate plots
82
- flex.plot_auc_scores()
83
- flex.plot_precision_recall_curve()
84
- flex.plot_percomplex_scatter()
85
- flex.plot_percomplex_scatter_bysize()
86
- flex.plot_significant_complexes()
87
- flex.plot_complex_contributions()
88
-
89
-
90
- #%%
91
- # Save results to CSV
92
- flex.save_results_to_csv()
93
-
94
-
95
-
96
-
97
-
98
-
99
-
100
-
101
-
102
- #%%
103
-
104
-
@@ -1,11 +0,0 @@
1
- #%%
2
- import anndata as ad
3
-
4
- adata = ad.read_h5ad(
5
- "C:/Users/yd/Desktop/22mcell/GWCD4i.pseudobulk_merged.h5ad",
6
- backed="r" # read-only, disk-backed
7
- )
8
-
9
- #%%
10
- adata
11
- # %%