grasp-tool 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- grasp_tool/__init__.py +17 -0
- grasp_tool/__main__.py +6 -0
- grasp_tool/cli/__init__.py +1 -0
- grasp_tool/cli/main.py +793 -0
- grasp_tool/cli/train_moco.py +778 -0
- grasp_tool/gnn/__init__.py +1 -0
- grasp_tool/gnn/embedding.py +165 -0
- grasp_tool/gnn/gat_moco_final.py +990 -0
- grasp_tool/gnn/graphloader.py +1748 -0
- grasp_tool/gnn/plot_refined.py +1556 -0
- grasp_tool/preprocessing/__init__.py +1 -0
- grasp_tool/preprocessing/augumentation.py +66 -0
- grasp_tool/preprocessing/cellplot.py +475 -0
- grasp_tool/preprocessing/filter.py +171 -0
- grasp_tool/preprocessing/network.py +79 -0
- grasp_tool/preprocessing/partition.py +654 -0
- grasp_tool/preprocessing/portrait.py +1862 -0
- grasp_tool/preprocessing/register.py +1021 -0
- grasp_tool-0.1.0.dist-info/METADATA +511 -0
- grasp_tool-0.1.0.dist-info/RECORD +22 -0
- grasp_tool-0.1.0.dist-info/WHEEL +4 -0
- grasp_tool-0.1.0.dist-info/entry_points.txt +3 -0
grasp_tool/gnn/plot_refined.py
@@ -0,0 +1,1556 @@
import os
import time
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.optimize import linear_sum_assignment
from sklearn.cluster import AgglomerativeClustering, KMeans, SpectralClustering
from sklearn.decomposition import PCA
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.manifold import TSNE
from sklearn.metrics import (
    adjusted_rand_score,
    classification_report,
    normalized_mutual_info_score,
)
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import StandardScaler

warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=UserWarning)


def _lazy_import_umap():
    """Import `umap` only when actually needed.

    Some environments may pull in TensorFlow when importing `umap`, which can
    emit noisy logs. Avoid importing it at module import time.
    """
    try:
        os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
        import umap  # type: ignore

        return umap
    except ModuleNotFoundError as e:
        raise ModuleNotFoundError(
            "Missing optional dependency `umap-learn` (import name: `umap`).\n"
            "Install: pip install umap-learn"
        ) from e


########################################################
# Unified evaluation + visualization utilities.
#
# The main entrypoint is evaluate_and_visualize(...), which replaces older
# per-mode plotting helpers.
########################################################

# Preprocessing method names.
PREPROCESS_BASIC = "basic"  # Basic
PREPROCESS_SCALER = "scaler"  # StandardScaler
PREPROCESS_PCA = "pca"  # PCA
PREPROCESS_SELECT = "select"  # Feature selection

# All supported preprocessing methods.
ALL_PREPROCESS_METHODS = [
    PREPROCESS_BASIC,
    PREPROCESS_SCALER,
    PREPROCESS_PCA,
    PREPROCESS_SELECT,
]

# TODO: performance/memory improvements for TSNE/UMAP on large datasets.
def compute_metrics(true_labels, predicted_labels):
    """Compute clustering metrics.

    Returns ARI (adjusted rand index) and NMI (normalized mutual information).
    """
    ari = adjusted_rand_score(true_labels, predicted_labels)
    nmi = normalized_mutual_info_score(true_labels, predicted_labels)
    return ari, nmi

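
# Usage sketch (editorial illustration, not part of the original source):
# both scores are permutation-invariant, so a relabeled-but-identical
# partition still scores 1.0.
#
#   ari, nmi = compute_metrics([0, 0, 1, 1], [1, 1, 0, 0])
#   # ari == 1.0 and nmi == 1.0
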
def map_labels_with_hungarian(true_labels, predicted_labels):
    """Map predicted cluster labels to true labels via Hungarian matching."""
    # Ensure numpy arrays.
    true_labels = np.asarray(true_labels)
    predicted_labels = np.asarray(predicted_labels)

    # Unique label values.
    true_classes = np.unique(true_labels)
    predicted_classes = np.unique(predicted_labels)

    # Build the cost matrix.
    n_true = len(true_classes)
    n_pred = len(predicted_classes)
    max_size = max(n_true, n_pred)
    cost_matrix = np.zeros((max_size, max_size))

    # Fill the cost matrix (negative overlap -> Hungarian solves min-cost).
    for i, t_class in enumerate(true_classes):
        for j, p_class in enumerate(predicted_classes):
            # Overlap between (t_class, p_class).
            match = np.sum((true_labels == t_class) & (predicted_labels == p_class))
            cost_matrix[i, j] = -match

    # Pad extra rows/cols with a large cost.
    if n_true < max_size or n_pred < max_size:
        large_val = abs(np.min(cost_matrix)) * 10
        for i in range(n_true, max_size):
            cost_matrix[i, :] = large_val
        for j in range(n_pred, max_size):
            cost_matrix[:, j] = large_val

    # Hungarian matching.
    row_ind, col_ind = linear_sum_assignment(cost_matrix)

    # Build the mapping.
    label_mapping = {}
    for row, col in zip(row_ind, col_ind):
        if row < n_true and col < n_pred:  # Valid match.
            label_mapping[predicted_classes[col]] = true_classes[row]

    # Handle unmatched predicted labels (if any).
    unmatched_pred_labels = set(predicted_classes) - set(label_mapping.keys())
    if unmatched_pred_labels:
        # Find a best-effort mapping for unmatched labels.
        for unmatched_label in unmatched_pred_labels:
            # Overlap with each true class.
            overlaps = []
            for t_class in true_classes:
                overlap = np.sum(
                    (predicted_labels == unmatched_label) & (true_labels == t_class)
                )
                overlaps.append((t_class, overlap))

            # Assign to the most overlapping true class.
            best_class = max(overlaps, key=lambda x: x[1])[0]
            label_mapping[unmatched_label] = best_class

    # Mapped label array.
    mapped_labels = np.array([label_mapping[label] for label in predicted_labels])

    return mapped_labels, label_mapping

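
# Usage sketch (editorial illustration, not part of the original source):
#
#   true = np.array([0, 0, 1, 1, 2, 2])
#   pred = np.array([2, 2, 0, 0, 1, 1])
#   mapped, mapping = map_labels_with_hungarian(true, pred)
#   # mapped -> array([0, 0, 1, 1, 2, 2]); each predicted cluster id has been
#   # renamed to the ground-truth class it overlaps most.
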
def compute_metrics_with_classification_report(true_labels, predicted_labels):
    """Compute metrics from sklearn's classification_report.

    Returns: (accuracy, precision, recall, f1_score) using the macro average.
    """
    report = classification_report(
        true_labels, predicted_labels, output_dict=True, zero_division="warn"
    )
    accuracy = report.get("accuracy", 0.0)

    macro_avg_report = report.get("macro avg", {})
    precision = macro_avg_report.get("precision", 0.0)
    recall = macro_avg_report.get("recall", 0.0)
    f1_score = macro_avg_report.get("f1-score", 0.0)
    return accuracy, precision, recall, f1_score

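
# Usage sketch (editorial illustration, not part of the original source):
#
#   acc, prec, rec, f1 = compute_metrics_with_classification_report(
#       ["a", "a", "b", "b"], ["a", "b", "b", "b"]
#   )
#   # acc == 0.75; prec/rec/f1 are macro averages over both classes.
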
def load_location_data(df, dataset, graphs_number=None, specific_label_file=None):
    """
    Load location/label data and merge it into the input DataFrame.

    Args:
        df: Input DataFrame.
        dataset: Dataset name.
        graphs_number: Optional graph count suffix used in label file naming.
        specific_label_file: Optional label filename or an absolute path.

    Returns:
        DataFrame: DataFrame with an added/updated 'location' column.
    """
    # Label column names in priority order, from high to low.
    priority_label_cols = [
        "groundtruth_wzx",
        "groundtruth",
        "label",
        "location",
        "cluster",
        "category",
        "type",
    ]

    # Warn if required columns are missing; do not early-return.
    if "cell" not in df.columns or "gene" not in df.columns:
        print(
            "Warning: DataFrame missing required columns ('cell' and/or 'gene'), "
            "merge may fail"
        )

    # Candidate base paths for label files.
    base_paths = [
        "../1_input/label",
        "../1_input/label_annotation",
        "../../GCN_CL/1_input/label",
        "../../GCN_CL/1_input/label_annotation",
        "./1_input/label",
        "./1_input/label_annotation",
    ]

    # Try to locate a label file.
    label_file = None

    # Allow absolute paths.
    # 1) If the user provided an absolute path and it exists, use it directly.
    if (
        specific_label_file
        and os.path.isabs(specific_label_file)
        and os.path.exists(specific_label_file)
    ):
        label_file = specific_label_file
        print(f"Using absolute label file path: {label_file}")
    # 2) If the user provided a relative filename, search under base_paths.
    elif specific_label_file:
        for base_path in base_paths:
            path = f"{base_path}/{specific_label_file}"
            if os.path.exists(path):
                label_file = path
                print(f"Using specified label file: {label_file}")
                break

    # If no specific file was found, try common naming patterns.
    if label_file is None:
        for base_path in base_paths:
            possible_files = [
                f"{base_path}/{dataset}_label.csv",
                f"{base_path}/{dataset}_labeled.csv",
            ]
            if graphs_number:
                possible_files.extend(
                    [
                        f"{base_path}/{dataset}_graph{graphs_number}_label.csv",
                        f"{base_path}/{dataset}_graph{graphs_number}_labeled.csv",
                    ]
                )

            for file_path in possible_files:
                if os.path.exists(file_path):
                    label_file = file_path
                    print(f"Found label file: {label_file}")
                    break

            if label_file:
                break

    # If a label file is found, load and process it.
    if label_file:
        try:
            # Read the label file.
            label_df = pd.read_csv(label_file)

            # Print the label file columns for debugging.
            print(f"Label file columns: {label_df.columns.tolist()}")

            # Find label columns by priority order.
            found_label_cols = []
            for col in priority_label_cols:
                if col in label_df.columns:
                    found_label_cols.append(col)
                    print(
                        f"Found label column: '{col}' "
                        f"(priority {priority_label_cols.index(col) + 1})"
                    )

            if found_label_cols:
                primary_label_col = found_label_cols[0]  # highest-priority column
                print(f"Using '{primary_label_col}' as primary label column")
            else:
                print("Warning: No recognized label column found in label file")
                primary_label_col = None

            # Robustly infer the label file format.
            # Long format: has 'cell' and 'gene' columns (+ one or more label columns).
            # Wide format: one row per cell, many gene columns.

            if "cell" in label_df.columns and "gene" in label_df.columns:
                print("Detected long format label file")

                # Keep only the required columns.
                keep_cols = ["cell", "gene"] + found_label_cols
                # Only keep columns that exist.
                label_df = label_df[
                    [col for col in keep_cols if col in label_df.columns]
                ]

                # Merge into the input dataframe.
                try:
                    merged_df = df.merge(label_df, on=["gene", "cell"], how="left")
                    print("Merged on both 'gene' and 'cell' columns")
                except Exception as e:
                    print(f"Error in gene+cell merge: {e}, trying cell-only merge")
                    try:
                        # Fall back to merging only on 'cell'.
                        merged_df = df.merge(label_df, on=["cell"], how="left")
                        print("Merged on 'cell' column only")
                    except Exception as e2:
                        print(f"All merge attempts failed: {e2}")
                        # Fall back to the default location.
                        df["location"] = "unknown"
                        return df

                # Normalize the primary label column to 'location' (if needed).
                if primary_label_col and primary_label_col != "location":
                    if primary_label_col in merged_df.columns:
                        print(f"Renaming '{primary_label_col}' to 'location'")
                        merged_df["location"] = merged_df[primary_label_col]

                # Ensure 'location' exists.
                if "location" not in merged_df.columns:
                    print("Creating 'location' column from highest priority label")
                    if found_label_cols:
                        merged_df["location"] = merged_df[found_label_cols[0]]
                    else:
                        print("Warning: No label columns found, using 'unknown'")
                        merged_df["location"] = "unknown"

                merged_df["location"] = merged_df["location"].fillna("unknown")
                print(f"Merged data, final shape: {merged_df.shape}")

                # Print the label distribution.
                label_counts = merged_df["location"].value_counts()
                print(f"Label distribution (top 10): \n{label_counts.head(10)}")

                return merged_df

            # Wide format: each row is a cell, columns are genes (and possibly metadata).
            elif "cell" in label_df.columns and len(label_df.columns) > 2:
                print(
                    f"Detected wide format label file, columns: "
                    f"{label_df.columns.tolist()}"
                )

                # Confirm it looks like a typical wide format with gene columns.
                non_gene_cols = ["cell", "Cell"] + priority_label_cols
                potential_gene_cols = [
                    col for col in label_df.columns if col not in non_gene_cols
                ]

                # Heuristic: treat as wide format if there are many gene columns.
                if len(potential_gene_cols) > 5:  # assume at least 5 gene columns
                    print(
                        f"Converting wide format to long format with "
                        f"{len(potential_gene_cols)} gene columns"
                    )

                    # Preserve non-gene columns for a later merge.
                    cell_info_cols = []
                    if primary_label_col:
                        cell_info_cols = [primary_label_col]
                    else:
                        # If no label column is found, keep all non-gene columns.
                        cell_info_cols = [
                            col
                            for col in non_gene_cols
                            if col in label_df.columns and col != "cell"
                        ]

                    cell_info_df = None
                    if cell_info_cols:
                        cell_info_df = label_df[["cell"] + cell_info_cols].copy()
                        print(f"Preserved cell information columns: {cell_info_cols}")

                    # Convert wide -> long.
                    long_df = pd.melt(
                        label_df,
                        id_vars="cell",
                        value_vars=potential_gene_cols,
                        var_name="gene",
                        value_name="gene_value",
                    )

                    # Merge into the input dataframe.
                    try:
                        merged_df = df.merge(long_df, on=["gene", "cell"], how="left")
                        print("Merged gene-cell data")
                    except Exception as e:
                        print(f"Error in gene-cell merge for wide format: {e}")
                        # Fall back to the default location.
                        df["location"] = "unknown"
                        return df

                    # Merge the preserved cell info columns (if any).
                    if cell_info_df is not None:
                        try:
                            merged_df = merged_df.merge(
                                cell_info_df, on="cell", how="left"
                            )
                            print("Added cell information columns")
                        except Exception as e:
                            print(f"Error merging cell info: {e}")

                    # Set the 'location' column.
                    if primary_label_col and primary_label_col in merged_df.columns:
                        merged_df["location"] = merged_df[primary_label_col]
                        print(f"Set 'location' from '{primary_label_col}'")
                    else:
                        # Fall back to 'gene_value'.
                        merged_df["location"] = merged_df["gene_value"]
                        print("Set 'location' from gene_value column")

                    # Ensure 'location' has values.
                    merged_df["location"] = merged_df["location"].fillna("unknown")
                    print(
                        f"Merged data from wide format, final shape: {merged_df.shape}"
                    )

                    # Print the label distribution.
                    label_counts = merged_df["location"].value_counts()
                    print(f"Label distribution (top 10): \n{label_counts.head(10)}")

                    return merged_df
                else:
                    print(
                        "Label file has 'cell' column but doesn't appear to be a "
                        "typical wide format"
                    )

                    # Try merging based on 'cell'.
                    try:
                        # Use the discovered label columns.
                        useful_cols = ["cell"]
                        if found_label_cols:
                            useful_cols.extend(found_label_cols)
                        else:
                            # If no label columns are found, keep all likely label cols.
                            useful_cols.extend(
                                [
                                    col
                                    for col in label_df.columns
                                    if col in priority_label_cols
                                    or col not in non_gene_cols
                                ]
                            )

                        # Only keep existing columns.
                        label_subset = label_df[
                            [col for col in useful_cols if col in label_df.columns]
                        ].copy()

                        # Merge.
                        merged_df = df.merge(label_subset, on="cell", how="left")

                        # Set the 'location' column.
                        if primary_label_col and primary_label_col in merged_df.columns:
                            merged_df["location"] = merged_df[primary_label_col]
                            print(f"Set 'location' from '{primary_label_col}'")
                        else:
                            # If no primary label column is available, use the default.
                            print("No primary label column available")
                            merged_df["location"] = "unknown"

                        merged_df["location"] = merged_df["location"].fillna("unknown")

                        print(
                            f"Merged data using cell-based join, final shape: "
                            f"{merged_df.shape}"
                        )

                        # Print the label distribution.
                        label_counts = merged_df["location"].value_counts()
                        print(f"Label distribution (top 10): \n{label_counts.head(10)}")

                        return merged_df
                    except Exception as e:
                        print(f"Error in cell-based merge: {e}")

            # Unrecognized format.
            print("Unrecognized label file format, using default 'unknown' location")
            df["location"] = "unknown"
            return df

        except Exception as e:
            print(f"Error processing label file {label_file}: {e}")
            import traceback

            traceback.print_exc()

    # If all attempts fail, add a default 'location' column.
    print(
        "Could not find or process a suitable label file, using default "
        "'unknown' location"
    )
    if "location" not in df.columns:
        df["location"] = "unknown"

    return df

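
# Usage sketch (editorial illustration; the dataset name and label file layout
# below are assumed, not taken from the package):
#
#   # With a long-format file such as ../1_input/label/mouse_label.csv that
#   # has columns [cell, gene, groundtruth], the merge adds 'location':
#   df = load_location_data(embeddings_df, dataset="mouse")
#   print(df["location"].value_counts())
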
def apply_clustering(features, n_clusters, clustering_methods=None):
    """
    Apply multiple clustering methods.

    Args:
        features: Feature matrix.
        n_clusters: Number of clusters.
        clustering_methods: Optional list of method names to run.

    Returns:
        dict: Method name -> cluster labels.
    """
    # Build the estimators lazily so that only the requested methods are
    # actually fitted (the eager version fitted all four every call).
    all_methods = {
        "KMeans": lambda: KMeans(n_clusters=n_clusters, random_state=2025)
        .fit(features)
        .labels_,
        "Agglomerative": lambda: AgglomerativeClustering(
            n_clusters=n_clusters
        ).fit_predict(features),
        "SpectralClustering": lambda: SpectralClustering(
            n_clusters=n_clusters, random_state=2025, affinity="nearest_neighbors"
        ).fit_predict(features),
        "GaussianMixture": lambda: GaussianMixture(
            n_components=n_clusters, random_state=2025
        )
        .fit(features)
        .predict(features),
    }

    # If methods are specified, run only those.
    if clustering_methods is not None:
        return {k: fit() for k, fit in all_methods.items() if k in clustering_methods}

    # Otherwise run all methods.
    return {k: fit() for k, fit in all_methods.items()}

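
# Usage sketch (editorial illustration, not part of the original source):
#
#   from sklearn.datasets import make_blobs
#   X, _ = make_blobs(n_samples=60, centers=3, random_state=0)
#   labels = apply_clustering(X, n_clusters=3, clustering_methods=["KMeans"])
#   # labels == {"KMeans": <array of 60 cluster ids in {0, 1, 2}>}
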
def evaluate_clustering(
    true_labels, clustering_methods, df, save_path, num_epochs, lr, suffix=""
):
    """
    Evaluate clustering results.

    Args:
        true_labels: Ground-truth labels.
        clustering_methods: Dict of method name -> predicted labels.
        df: DataFrame.
        save_path: Output directory.
        num_epochs: Current epoch.
        lr: Learning rate.
        suffix: Filename suffix.

    Returns:
        metrics: Dict of evaluation metrics.
    """
    # Check the label distribution.
    unique_labels = np.unique(true_labels)
    unique_labels_count = len(unique_labels)
    print("\nLabel distribution:")
    print(f"Unique label count: {unique_labels_count}")
    print(
        f"Label values: {unique_labels[:10]}{'...' if len(unique_labels) > 10 else ''}"
    )

    # If all samples share one label, ARI/NMI become 0 by definition.
    if unique_labels_count == 1:
        print(f"Warning: all samples share the same label '{unique_labels[0]}'")
        print(
            "This makes ARI/NMI equal to 0 and can make accuracy-like metrics appear "
            "high, because there is no class variation to compare."
        )
        print("Please check the label file or load_location_data() processing.\n")
    elif unique_labels_count < 2:
        # Only reachable when the label array is empty.
        print("Error: too few unique labels for meaningful clustering evaluation")
        return {"error": "insufficient_labels"}

    metrics = {}
    for method, predicted_labels in clustering_methods.items():
        mapped_labels, label_mapping = map_labels_with_hungarian(
            true_labels, predicted_labels
        )
        accuracy, precision, recall, f1_score = (
            compute_metrics_with_classification_report(true_labels, mapped_labels)
        )
        print(f"{method} Label Mapping: {label_mapping}")

        ari, nmi = compute_metrics(true_labels, predicted_labels)
        metrics[method] = {
            "ARI": ari,
            "NMI": nmi,
            "Accuracy": accuracy,
            "Precision": precision,
            "Recall": recall,
            "F1-Score": f1_score,
            "Label Mapping": label_mapping,
        }
        df[method + "_cluster"] = predicted_labels

    print(f"Epoch {num_epochs} lr={lr} suffix={suffix} clustering metrics:")
    for method, values in metrics.items():
        print(
            f"{method}: ARI = {values['ARI']:.4f}, NMI = {values['NMI']:.4f}, "
            f"Accuracy = {values['Accuracy']:.4f}, Precision = {values['Precision']:.4f}, "
            f"Recall = {values['Recall']:.4f}, F1-Score = {values['F1-Score']:.4f}"
        )

    # Save the clustering outputs and metrics. An optional suffix distinguishes
    # the preprocessing variants; both paths previously duplicated this logic.
    suffix_part = f"_{suffix}" if suffix else ""
    df.to_csv(
        f"{save_path}/epoch{num_epochs}_lr{lr}_clusters{suffix_part}.csv", index=False
    )

    with open(
        f"{save_path}/epoch{num_epochs}_lr{lr}_metrics{suffix_part}.txt", "w"
    ) as f:
        for method, values in metrics.items():
            f.write(
                f"{method}:\nARI = {values['ARI']:.4f}, NMI = {values['NMI']:.4f},\n "
                f"Accuracy = {values['Accuracy']:.4f}, Precision = {values['Precision']:.4f},\n"
                f"Recall = {values['Recall']:.4f}, F1-Score = {values['F1-Score']:.4f}\n"
            )
            f.write(f"{method} Label Mapping: {values['Label Mapping']}\n")
            f.write("\n")

    return metrics

def visualize_clustering(
    reduced_features_tsne,
    reduced_features_umap,
    true_labels,
    clustering_methods,
    num_epochs,
    lr,
    size=15,
):
    """
    Visualize clustering results.

    Args:
        reduced_features_tsne: t-SNE 2D features.
        reduced_features_umap: UMAP 2D features.
        true_labels: Ground-truth labels.
        clustering_methods: Dict of method name -> predicted labels.
        num_epochs: Current epoch.
        lr: Learning rate.
        size: Point size.

    Returns:
        fig: Matplotlib figure.
    """
    fig, axes = plt.subplots(
        len(clustering_methods) + 1, 2, figsize=(16, 4 * (len(clustering_methods) + 1))
    )
    fig.suptitle(f"Epoch {num_epochs} lr={lr}", fontsize=16)

    # True labels.
    sns.scatterplot(
        x=reduced_features_tsne[:, 0],
        y=reduced_features_tsne[:, 1],
        hue=true_labels,
        palette="Set1",
        ax=axes[0, 0],
        s=size,
        legend="full",
    )
    axes[0, 0].set_title("t-SNE with True Labels")
    axes[0, 0].legend(loc="upper left", bbox_to_anchor=(1, 1))

    sns.scatterplot(
        x=reduced_features_umap[:, 0],
        y=reduced_features_umap[:, 1],
        hue=true_labels,
        palette="Set1",
        ax=axes[0, 1],
        s=size,
        legend="full",
    )
    axes[0, 1].set_title("UMAP with True Labels")
    axes[0, 1].legend(loc="upper left", bbox_to_anchor=(1, 1))

    # Predicted clusters.
    for i, (method, predicted_labels) in enumerate(clustering_methods.items(), start=1):
        sns.scatterplot(
            x=reduced_features_tsne[:, 0],
            y=reduced_features_tsne[:, 1],
            hue=predicted_labels,
            palette="Set2",
            ax=axes[i, 0],
            s=size,
            legend="full",
        )
        axes[i, 0].set_title(f"t-SNE with {method} Clusters")
        axes[i, 0].legend(loc="upper left", bbox_to_anchor=(1, 1))

        sns.scatterplot(
            x=reduced_features_umap[:, 0],
            y=reduced_features_umap[:, 1],
            hue=predicted_labels,
            palette="Set3",
            ax=axes[i, 1],
            s=size,
            legend="full",
        )
        axes[i, 1].set_title(f"UMAP with {method} Clusters")
        axes[i, 1].legend(loc="upper left", bbox_to_anchor=(1, 1))

    plt.tight_layout()
    return fig

def preprocess_features(df, preprocess_method="basic", true_labels=None):
    """
    Preprocess features according to the given method.

    Args:
        df: DataFrame.
        preprocess_method: One of 'basic', 'scaler', 'pca', 'select'.
        true_labels: Labels used by feature selection.

    Returns:
        features: Preprocessed feature matrix.
    """
    # Exclude known non-feature columns.
    non_feature_cols = [
        "cell",
        "gene",
        "graph_id",
        "original_graph_id",
        "augmented_graph_id",
        "location",
        "groundtruth",
        "groundtruth_wzx",
        "label",
        "cluster",
        "category",
        "type",
        "SCT_cluster",
        "cell_type",
        "celltype",  # additional common non-feature column name
        # Also exclude columns generated by clustering methods.
        "KMeans_cluster",
        "Agglomerative_cluster",
        "SpectralClustering_cluster",
        "GaussianMixture_cluster",
    ]

    # Candidate feature columns.
    candidate_feature_names = [
        col
        for col in df.columns
        if col not in non_feature_cols and not str(col).endswith("_cluster")
    ]

    if not candidate_feature_names:
        print(
            f"Warning: No candidate feature columns found after excluding "
            f"non_feature_cols for method '{preprocess_method}'. "
            f"Original columns: {df.columns.tolist()}"
        )
        return np.array([])

    # Work on a copy to safely attempt numeric conversion.
    df_candidate_features = df[candidate_feature_names].copy()

    actually_numeric_cols = []
    for col_name in df_candidate_features.columns:
        try:
            # Coerce the column to numeric; non-convertible values become NaN.
            converted_series = pd.to_numeric(
                df_candidate_features[col_name], errors="coerce"
            )

            # Keep numeric columns that are not all-NaN after coercion.
            if (
                pd.api.types.is_numeric_dtype(converted_series)
                and not converted_series.isnull().all()
            ):
                # Heuristic: if coercion wipes most values, it's likely not a
                # true numeric feature.
                original_non_na_count = (
                    df_candidate_features[col_name].dropna().shape[0]
                )
                converted_non_na_count = converted_series.dropna().shape[0]
                if original_non_na_count > 0 and (
                    converted_non_na_count / original_non_na_count < 0.5
                ):
                    print(
                        f"Warning: Column '{col_name}' lost significant data after "
                        f"numeric coercion ({original_non_na_count} -> "
                        f"{converted_non_na_count} non-NaNs). Excluding."
                    )
                else:
                    actually_numeric_cols.append(col_name)
                    # Replace with the numeric series to help downstream .values.
                    df_candidate_features[col_name] = converted_series
            else:
                print(
                    f"Warning: Column '{col_name}' excluded. Not purely numeric or "
                    f"became all NaNs after coercion."
                )
        except Exception as e:
            print(
                f"Warning: Column '{col_name}' encountered an error during numeric "
                f"conversion ('{e}') and will be excluded."
            )

    if not actually_numeric_cols:
        print(
            f"Error: No numeric feature columns remaining after filtering for "
            f"method '{preprocess_method}'."
        )
        return np.array([])

    # Extract features from the converted dataframe.
    features = df_candidate_features[actually_numeric_cols].values

    # Ensure float dtype.
    if features.dtype == np.object_:
        print(
            f"Warning: Features array for method '{preprocess_method}' has dtype "
            f"'object' before scaling. Attempting astype(float)."
        )
        try:
            features = features.astype(float)
        except ValueError as e_astype:
            error_msg = (
                f"Failed to convert object-type features to float for method "
                f"'{preprocess_method}'. Problematic data likely still exists. "
                f"Original error: {e_astype}"
            )
            print(f"Error: {error_msg}")
            # If this happens, upstream data likely contains non-numeric features.
            raise ValueError(error_msg) from e_astype

    # Apply the requested preprocessing.
    if preprocess_method == "basic":
        pass  # no-op
    elif preprocess_method == "scaler":
        if features.size == 0:
            print(
                f"Warning: No features to scale for method '{preprocess_method}'. "
                f"Skipping scaling."
            )
            return features  # avoid StandardScaler errors
        scaler = StandardScaler()
        features = scaler.fit_transform(features)
    elif preprocess_method == "pca":
        if features.size == 0:
            print(
                f"Warning: No features for PCA for method '{preprocess_method}'. "
                f"Skipping PCA."
            )
            return features
        scaler = StandardScaler()
        features = scaler.fit_transform(features)
        pca = PCA(n_components=0.95)  # keep 95% of the variance
        features = pca.fit_transform(features)
        print(f"  PCA applied. Features shape after PCA: {features.shape}")
    elif preprocess_method == "select":
        if features.size == 0:
            print(
                f"Warning: No features for selection for method "
                f"'{preprocess_method}'. Skipping selection."
            )
            return features
        scaler = StandardScaler()
        features = scaler.fit_transform(features)
        if true_labels is not None:
            selector = SelectKBest(score_func=f_classif, k=min(50, features.shape[1]))
            features = selector.fit_transform(features, true_labels)
    else:
        raise ValueError(f"Unsupported preprocess method: {preprocess_method}")

    return features

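
# Usage sketch (editorial illustration, not part of the original source):
# only the numeric embedding columns survive; 'cell' and 'gene' are dropped.
#
#   toy = pd.DataFrame({
#       "cell": ["c1", "c2", "c3"],
#       "gene": ["g1", "g1", "g2"],
#       "e0": [0.1, 0.4, 0.9],
#       "e1": [1.0, 0.2, 0.3],
#   })
#   feats = preprocess_features(toy, "scaler")
#   # feats -> standardized (3, 2) array built from e0/e1.
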
def evaluate_and_visualize(
    dataset,
    df,
    save_path,
    num_epochs,
    lr,
    n_clusters=8,
    size=15,
    vis_methods=None,
    visualize=True,
    graphs_number=None,
    clustering_methods=None,
    reduce_dims=True,
    specific_label_file=None,
    tsne_perplexity: float = 30.0,
    umap_n_neighbors: int = 15,
    umap_min_dist: float = 0.2,
):
    """
    Evaluate clustering results and optionally create visualizations.

    Args:
        dataset: Dataset name.
        df: DataFrame containing embeddings.
        save_path: Output directory.
        num_epochs: Current epoch.
        lr: Learning rate.
        n_clusters: Number of clusters.
        size: Point size.
        vis_methods: List of visualization/preprocess methods.
        visualize: Whether to generate visualization images.
        graphs_number: Optional graph count suffix used in label file naming.
        clustering_methods: Optional list of clustering method names.
        reduce_dims: Whether to run dimensionality reduction.
        specific_label_file: Optional label filename or absolute path.
        tsne_perplexity: t-SNE perplexity.
        umap_n_neighbors: UMAP n_neighbors.
        umap_min_dist: UMAP min_dist.

    Returns:
        Tuple[Dict, Dict]: (metrics for all methods, figures for all methods)
    """
    # Timing.
    start_time = time.time()

    # Normalize the visualization methods argument.
    if vis_methods is None:
        # Default: run all.
        vis_methods = ALL_PREPROCESS_METHODS
    elif isinstance(vis_methods, str):
        # Normalize a single string into a list.
        vis_methods = [vis_methods]

    # Validate the visualization methods.
    for method in vis_methods:
        if method not in ALL_PREPROCESS_METHODS:
            raise ValueError(
                f"Unsupported visualization method: {method}. "
                f"Available: {ALL_PREPROCESS_METHODS}"
            )

    # Load location/label data.
    df = load_location_data(df, dataset, graphs_number, specific_label_file)

    # Extract labels.
    true_labels = df["location"].astype(str).values

    # Initialize the result dicts.
    all_metrics = {}
    all_figures = {}

    # If no methods are requested, return empty results.
    if not vis_methods:
        print("No visualization methods specified. Returning empty results.")
        return all_metrics, all_figures

    # Track the total method count for progress printing.
    total_methods = len(vis_methods)

    # Process each method.
    for i, method in enumerate(vis_methods, 1):
        method_start = time.time()
        print(f"\n[{i}/{total_methods}] Start processing method: {method}")

        # Preprocess features.
        features = preprocess_features(df, method, true_labels)

        # Dimensionality reduction (only when requested).
        reduced_features_tsne = None
        reduced_features_umap = None

        if visualize and reduce_dims:
            dim_start = time.time()
            print("  Running dimensionality reduction...")
            try:
                tsne = TSNE(
                    n_components=2,
                    random_state=42,
                    perplexity=min(tsne_perplexity, len(df) - 1),
                )
                reduced_features_tsne = tsne.fit_transform(features)
            except Exception as e:
                print(f"  t-SNE reduction failed: {e}")
                reduced_features_tsne = np.zeros((len(df), 2))  # blank fallback

            try:
                umap_model = _lazy_import_umap().UMAP(
                    n_components=2,
                    random_state=42,
                    n_neighbors=min(umap_n_neighbors, len(df) - 1),
                    min_dist=umap_min_dist,
                )
                reduced_features_umap = umap_model.fit_transform(features)
            except Exception as e:
                print(f"  UMAP reduction failed: {e}")
                reduced_features_umap = np.zeros((len(df), 2))  # blank fallback

            dim_time = time.time() - dim_start
            print(f"  Dimensionality reduction done in {dim_time:.2f}s")

        # Clustering.
        cluster_start = time.time()
        clustering_methods_dict = apply_clustering(
            features, n_clusters, clustering_methods
        )
        cluster_time = time.time() - cluster_start
        print(f"  Clustering done in {cluster_time:.2f}s")

        # Evaluation.
        eval_start = time.time()
        metrics = evaluate_clustering(
            true_labels,
            clustering_methods_dict,
            df,
            save_path,
            num_epochs,
            lr,
            suffix=method if method != PREPROCESS_BASIC else "",
        )
        all_metrics[method] = metrics
        eval_time = time.time() - eval_start
        print(f"  Evaluation done in {eval_time:.2f}s")

        # Visualization.
        if visualize and reduce_dims:
            vis_start = time.time()
            print("  Creating visualization...")
            fig = visualize_clustering(
                reduced_features_tsne,
                reduced_features_umap,
                true_labels,
                clustering_methods_dict,
                num_epochs,
                lr,
                size,
            )

            fig_suffix = "" if method == PREPROCESS_BASIC else f"_{method}"
            fig_path = (
                f"{save_path}/epoch{num_epochs}_lr{lr}_visualization{fig_suffix}.png"
            )
            fig.savefig(fig_path)
            all_figures[method] = fig

            vis_time = time.time() - vis_start
            print(f"  Visualization saved to {fig_path} in {vis_time:.2f}s")

        method_time = time.time() - method_start
        print(f"Method {method} done in {method_time:.2f}s")

    total_time = time.time() - start_time
    print(f"\nAll methods completed in {total_time:.2f}s")

    return all_metrics, all_figures

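
# Usage sketch (editorial illustration; `embeddings_df`, the dataset name, and
# the output paths are assumed, not taken from the package):
#
#   metrics, figures = evaluate_and_visualize(
#       dataset="mouse",
#       df=embeddings_df,          # columns: cell, gene, embedding dims
#       save_path="./out",
#       num_epochs=100,
#       lr=1e-3,
#       n_clusters=8,
#       vis_methods=["basic", "pca"],
#       clustering_methods=["KMeans", "GaussianMixture"],
#   )
#   # e.g. metrics["pca"]["KMeans"]["ARI"]; figures["pca"] is a matplotlib
#   # Figure saved under ./out.
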
def plot_scatter_mutilbatch_tensorboard(*args, **kwargs):
    """
    Deprecated: use evaluate_and_visualize(vis_methods=['basic']).
    """
    # NOTE: this shim is immediately overridden by the legacy implementation
    # defined below, which reuses the same name, so the definition below wins
    # at import time and this wrapper is effectively dead code.
    warnings.warn(
        "plot_scatter_mutilbatch_tensorboard is deprecated; use evaluate_and_visualize",
        DeprecationWarning,
    )
    metrics, figs = evaluate_and_visualize(*args, vis_methods=["basic"], **kwargs)
    return metrics.get("basic", {}), figs.get("basic")

def plot_scatter_mutilbatch_tensorboard(
    dataset,
    df,
    save_path,
    num_epochs,
    lr,
    n_clusters,
    size=15,
    visualize=True,
    graphs_number=None,
    clustering_methods=None,
):
    """
    Basic visualization using raw features.

    Args:
        dataset: Dataset name.
        df: DataFrame.
        save_path: Output directory.
        num_epochs: Current epoch.
        lr: Learning rate.
        n_clusters: Number of clusters.
        size: Point size.
        visualize: Whether to create visualization images.
        graphs_number: Optional graph count suffix used in label file naming.
        clustering_methods: Optional list of clustering method names.

    Returns:
        tuple: (metrics dict, figure if visualize=True else None)
    """
    # Load location labels.
    df = load_location_data(df, dataset, graphs_number)

    # Extract features and labels.
    features = df.drop(columns=["cell", "gene", "location"]).values
    true_labels = df["location"].astype(str).values

    # Dimensionality reduction (only when visualization is requested).
    reduced_features_tsne = None
    reduced_features_umap = None
    if visualize:
        tsne = TSNE(n_components=2, random_state=42, perplexity=30)
        reduced_features_tsne = tsne.fit_transform(features)

        umap_model = _lazy_import_umap().UMAP(
            n_components=2, random_state=42, n_neighbors=15, min_dist=0.2
        )
        reduced_features_umap = umap_model.fit_transform(features)

    # Clustering.
    clustering_methods_dict = apply_clustering(features, n_clusters, clustering_methods)

    # Evaluation.
    metrics = evaluate_clustering(
        true_labels, clustering_methods_dict, df, save_path, num_epochs, lr
    )

    # Visualization.
    if visualize:
        fig = visualize_clustering(
            reduced_features_tsne,
            reduced_features_umap,
            true_labels,
            clustering_methods_dict,
            num_epochs,
            lr,
            size,
        )
        fig_path = f"{save_path}/epoch{num_epochs}_lr{lr}_visualization.png"
        fig.savefig(fig_path)
        return metrics, fig
    else:
        return metrics, None

def plot_scatter_mutilbatch_scaler_tensorboard(
    dataset,
    df,
    save_path,
    num_epochs,
    lr,
    n_clusters,
    size=15,
    visualize=True,
    graphs_number=None,
    clustering_methods=None,
):
    """
    Visualization using standardized features.

    Args:
        Same as plot_scatter_mutilbatch_tensorboard.
    """
    # Load location labels.
    df = load_location_data(df, dataset, graphs_number)

    # Extract features and labels.
    features = df.drop(columns=["cell", "gene", "location"]).values
    scaler = StandardScaler()
    features = scaler.fit_transform(features)
    true_labels = df["location"].astype(str).values

    # Dimensionality reduction (only when visualization is requested).
    reduced_features_tsne = None
    reduced_features_umap = None
    if visualize:
        tsne = TSNE(n_components=2, random_state=42, perplexity=30)
        reduced_features_tsne = tsne.fit_transform(features)

        umap_model = _lazy_import_umap().UMAP(
            n_components=2, random_state=42, n_neighbors=15, min_dist=0.2
        )
        reduced_features_umap = umap_model.fit_transform(features)

    # Clustering.
    clustering_methods_dict = apply_clustering(features, n_clusters, clustering_methods)

    # Evaluation.
    metrics = evaluate_clustering(
        true_labels, clustering_methods_dict, df, save_path, num_epochs, lr, "scaler"
    )

    # Visualization.
    if visualize:
        fig = visualize_clustering(
            reduced_features_tsne,
            reduced_features_umap,
            true_labels,
            clustering_methods_dict,
            num_epochs,
            lr,
            size,
        )
        fig_path = f"{save_path}/epoch{num_epochs}_lr{lr}_visualization_scaler.png"
        fig.savefig(fig_path)
        return metrics, fig
    else:
        return metrics, None

def plot_scatter_mutilbatch_pca_tensorboard(
    dataset,
    df,
    save_path,
    num_epochs,
    lr,
    n_clusters,
    size=15,
    visualize=True,
    graphs_number=None,
    clustering_methods=None,
):
    """
    Visualization with PCA-reduced features.

    Args:
        Same as plot_scatter_mutilbatch_tensorboard.
    """
    # Load location labels.
    df = load_location_data(df, dataset, graphs_number)

    # Extract features and labels.
    features = df.drop(columns=["cell", "gene", "location"]).values
    scaler = StandardScaler()
    features = scaler.fit_transform(features)
    true_labels = df["location"].astype(str).values

    # PCA.
    pca = PCA(n_components=0.95)  # keep 95% of the variance
    features = pca.fit_transform(features)

    # Dimensionality reduction for plotting (only when requested).
    reduced_features_tsne = None
    reduced_features_umap = None
    if visualize:
        tsne = TSNE(n_components=2, random_state=42, perplexity=30)
        reduced_features_tsne = tsne.fit_transform(features)

        umap_model = _lazy_import_umap().UMAP(
            n_components=2, random_state=42, n_neighbors=15, min_dist=0.2
        )
        reduced_features_umap = umap_model.fit_transform(features)

    # Clustering.
    clustering_methods_dict = apply_clustering(features, n_clusters, clustering_methods)

    # Evaluation.
    metrics = evaluate_clustering(
        true_labels, clustering_methods_dict, df, save_path, num_epochs, lr, "pca"
    )

    # Visualization.
    if visualize:
        fig = visualize_clustering(
            reduced_features_tsne,
            reduced_features_umap,
            true_labels,
            clustering_methods_dict,
            num_epochs,
            lr,
            size,
        )
        fig_path = f"{save_path}/epoch{num_epochs}_lr{lr}_visualization_pca.png"
        fig.savefig(fig_path)
        return metrics, fig
    else:
        return metrics, None

def plot_scatter_mutilbatch_select_tensorboard(
    dataset,
    df,
    save_path,
    num_epochs,
    lr,
    n_clusters,
    size=15,
    visualize=True,
    graphs_number=None,
    clustering_methods=None,
):
    """
    Visualization using feature selection.

    Args:
        dataset: Dataset name.
        df: DataFrame.
        save_path: Output directory.
        num_epochs: Current epoch.
        lr: Learning rate.
        n_clusters: Number of clusters.
        size: Point size.
        visualize: Whether to create visualization images.
        graphs_number: Optional graph count suffix used in label file naming.
        clustering_methods: Optional list of clustering method names.

    Returns:
        tuple: (metrics dict, figure if visualize=True else None)
    """
    # Load location labels.
    df = load_location_data(df, dataset, graphs_number)

    # Extract features and labels.
    features = df.drop(columns=["cell", "gene", "location"]).values
    scaler = StandardScaler()
    features = scaler.fit_transform(features)
    true_labels = df["location"].astype(str).values

    # Feature selection: keep the top-50 features (or all of them when fewer
    # than 50 are available, which would otherwise make SelectKBest raise).
    selector = SelectKBest(score_func=f_classif, k=min(50, features.shape[1]))
    features = selector.fit_transform(features, true_labels)

    # Dimensionality reduction for plotting (only when requested).
    reduced_features_tsne = None
    reduced_features_umap = None
    if visualize:
        tsne = TSNE(n_components=2, random_state=42, perplexity=30)
        reduced_features_tsne = tsne.fit_transform(features)

        umap_model = _lazy_import_umap().UMAP(
            n_components=2, random_state=42, n_neighbors=15, min_dist=0.2
        )
        reduced_features_umap = umap_model.fit_transform(features)

    # Clustering.
    clustering_methods_dict = apply_clustering(features, n_clusters, clustering_methods)

    # Evaluation.
    metrics = evaluate_clustering(
        true_labels, clustering_methods_dict, df, save_path, num_epochs, lr, "select"
    )

    # Visualization.
    if visualize:
        fig = visualize_clustering(
            reduced_features_tsne,
            reduced_features_umap,
            true_labels,
            clustering_methods_dict,
            num_epochs,
            lr,
            size,
        )
        fig_path = f"{save_path}/epoch{num_epochs}_lr{lr}_visualization_select.png"
        fig.savefig(fig_path)
        return metrics, fig
    else:
        return metrics, None

def plot_scatter_nolabel(
    dataset,
    df,
    save_path,
    num_epochs,
    lr,
    n_clusters,
    size=15,
    visualize=True,
    clustering_methods=None,
):
    """
    Clustering and visualization without ground-truth labels.

    Args:
        dataset: Dataset name.
        df: DataFrame.
        save_path: Output directory.
        num_epochs: Current epoch.
        lr: Learning rate.
        n_clusters: Number of clusters.
        size: Point size.
        visualize: Whether to create visualization images.
        clustering_methods: Optional list of clustering method names.
    """
    df["location"] = "other"

    features = df.drop(columns=["cell", "gene", "location"]).values
    scaler = StandardScaler()
    features = scaler.fit_transform(features)

    # Cluster even if visualization is disabled.
    clustering_methods_dict = apply_clustering(features, n_clusters, clustering_methods)

    # Save the clustering results.
    for method, labels in clustering_methods_dict.items():
        df[method + "_cluster"] = labels

    df.to_csv(f"{save_path}/epoch{num_epochs}_lr{lr}_clusters.csv", index=False)

    # Visualization (only when requested).
    if visualize:
        # Dimensionality reduction.
        tsne = TSNE(n_components=2, random_state=42, perplexity=30)
        reduced_features_tsne = tsne.fit_transform(features)

        umap_model = _lazy_import_umap().UMAP(
            n_components=2, random_state=42, n_neighbors=15, min_dist=0.2
        )
        reduced_features_umap = umap_model.fit_transform(features)

        # squeeze=False keeps `axes` 2D even when only one clustering method
        # was requested (a single subplot row).
        fig, axes = plt.subplots(
            len(clustering_methods_dict),
            2,
            figsize=(16, 4 * len(clustering_methods_dict)),
            squeeze=False,
        )
        fig.suptitle(f"Epoch {num_epochs} lr={lr}", fontsize=16)

        for i, (method, predicted_labels) in enumerate(clustering_methods_dict.items()):
            sns.scatterplot(
                x=reduced_features_tsne[:, 0],
                y=reduced_features_tsne[:, 1],
                hue=predicted_labels,
                palette="Set2",
                ax=axes[i, 0],
                s=size,
                legend="full",
            )
            axes[i, 0].set_title(f"t-SNE with {method} Clusters")
            axes[i, 0].legend(loc="upper left", bbox_to_anchor=(1, 1))

            sns.scatterplot(
                x=reduced_features_umap[:, 0],
                y=reduced_features_umap[:, 1],
                hue=predicted_labels,
                palette="Set3",
                ax=axes[i, 1],
                s=size,
                legend="full",
            )
            axes[i, 1].set_title(f"UMAP with {method} Clusters")
            axes[i, 1].legend(loc="upper left", bbox_to_anchor=(1, 1))

        plt.tight_layout()
        fig.savefig(f"{save_path}/epoch{num_epochs}_lr{lr}_visualization_nolabel.png")
        return fig
    return None

def plot_embeddings_only(df, save_path, num_epochs, lr, visualize=False):
    """
    Save embeddings only (no clustering/evaluation).

    Args:
        df: DataFrame containing embeddings.
        save_path: Output directory.
        num_epochs: Current epoch.
        lr: Learning rate.
        visualize: Whether to create a quick embedding scatter plot.

    Returns:
        fig: Figure if visualize=True, else None.
    """
    # Save the embeddings.
    df.to_csv(f"{save_path}/epoch{num_epochs}_lr{lr}_embedding.csv", index=False)
    print(f"Embeddings saved to {save_path}/epoch{num_epochs}_lr{lr}_embedding.csv")

    # Create a simple embedding visualization on demand.
    if visualize:
        # Extract features.
        features = df.drop(columns=["cell", "gene"]).values

        n_samples = features.shape[0]
        # Very small demos (e.g., smoke tests) can have only a handful of graphs.
        # t-SNE / UMAP can fail in these edge cases. Prefer skipping visualization
        # over crashing the whole training run.
        if n_samples < 3:
            print(f"Skip embedding visualization: too few samples (n={n_samples}).")
            return None

        # Standardize.
        scaler = StandardScaler()
        features = scaler.fit_transform(features)

        # Dimensionality reduction.
        tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, n_samples - 1))
        reduced_features_tsne = tsne.fit_transform(features)

        # UMAP's default init='spectral' can fail when n is extremely small (k >= n).
        # Use init='random' for small n to keep smoke tests robust.
        umap_init = "random" if n_samples < 4 else "spectral"
        umap_model = _lazy_import_umap().UMAP(
            n_components=2,
            random_state=42,
            n_neighbors=min(15, n_samples - 1),
            min_dist=0.2,
            init=umap_init,
        )
        reduced_features_umap = umap_model.fit_transform(features)

        # Simple scatter plots.
        fig, axes = plt.subplots(1, 2, figsize=(16, 6))

        # Color by cell.
        sns.scatterplot(
            x=reduced_features_tsne[:, 0],
            y=reduced_features_tsne[:, 1],
            hue=df["cell"],
            ax=axes[0],
            s=15,
            legend="auto",
        )
        axes[0].set_title("t-SNE Embedding by Cell Type")
        axes[0].legend(loc="upper left", bbox_to_anchor=(1, 1))

        # Color by gene.
        sns.scatterplot(
            x=reduced_features_umap[:, 0],
            y=reduced_features_umap[:, 1],
            hue=df["gene"],
            ax=axes[1],
            s=15,
            legend="auto",
        )
        axes[1].set_title("UMAP Embedding by Gene Type")
        axes[1].legend(loc="upper left", bbox_to_anchor=(1, 1))

        plt.tight_layout()
        fig.savefig(
            f"{save_path}/epoch{num_epochs}_lr{lr}_embeddings_visualization.png"
        )
        return fig

    return None

def plot_loss_weights(num_epochs, weights_a, weights_b, weights_c, save_path, lr):
    """
    Plot dynamic loss weights over epochs.

    Args:
        num_epochs: Total number of epochs.
        weights_a: Reconstruction loss weights.
        weights_b: Contrastive loss weights.
        weights_c: Clustering loss weights.
        save_path: Output directory.
        lr: Learning rate.
    """
    plt.figure(figsize=(6, 4))
    plt.plot(
        range(num_epochs),
        weights_a,
        label="Reconstruction Loss Weight (a)",
        color="blue",
    )
    plt.plot(
        range(num_epochs),
        weights_b,
        label="Contrastive Loss Weight (b)",
        color="orange",
    )
    plt.plot(
        range(num_epochs), weights_c, label="Clustering Loss Weight (c)", color="green"
    )
    plt.xlabel("Epoch")
    plt.ylabel("Weight")
    plt.title("Dynamic Loss Weights Over Epochs")
    plt.legend()
    # Note: the output keeps its historical "plot_learning_rate" filename even
    # though the figure shows loss weights, so existing pipelines keep working.
    plt.savefig(f"{save_path}/epoch{num_epochs}_lr{lr}_plot_learning_rate.png")