grasp-tool 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1556 @@
+ import os
+ import time
+ import warnings
+ from typing import List, Dict, Tuple, Optional, Union, Any
+
+ import igraph as ig
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import pandas as pd
+ import seaborn as sns
+ import torch
+ from leidenalg import find_partition
+ from scipy.optimize import linear_sum_assignment
+ from scipy.stats import mode
+ from sklearn.cluster import (
+     KMeans,
+     AgglomerativeClustering,
+     DBSCAN,
+     SpectralClustering,
+ )
+ from sklearn.decomposition import PCA
+ from sklearn.feature_selection import SelectKBest, f_classif
+ from sklearn.manifold import TSNE
+ from sklearn.metrics import (
+     accuracy_score,
+     adjusted_rand_score,
+     classification_report,
+     normalized_mutual_info_score,
+ )
+ from sklearn.metrics.pairwise import cosine_similarity
+ from sklearn.mixture import GaussianMixture
+ from sklearn.preprocessing import StandardScaler
+
+ warnings.filterwarnings("ignore", category=RuntimeWarning)
+ warnings.filterwarnings("ignore", category=UserWarning)
+
+
+ def _lazy_import_umap():
+     """Import `umap` only when actually needed.
+
+     Some environments may pull in TensorFlow when importing `umap`, which can
+     emit noisy logs. Avoid importing it at module import time.
+     """
+
+     try:
+         os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
+         import umap  # type: ignore
+
+         return umap
+     except ModuleNotFoundError as e:
+         raise ModuleNotFoundError(
+             "Missing optional dependency `umap-learn` (import name: `umap`).\n"
+             "Install: pip install umap-learn"
+         ) from e
+
+
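+ # A minimal usage sketch (hypothetical, for illustration only): the returned
+ # module is the regular `umap` package, so reducers are built from it as usual.
+ #
+ #     umap = _lazy_import_umap()
+ #     reducer = umap.UMAP(n_components=2, random_state=42)
+ #     embedding_2d = reducer.fit_transform(features)
+
+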
+ ########################################################
+ # Unified evaluation + visualization utilities.
+ #
+ # The main entrypoint is evaluate_and_visualize(...), which replaces older
+ # per-mode plotting helpers.
+ ########################################################
+
+ # Preprocessing method names.
+ PREPROCESS_BASIC = "basic"  # Basic (raw features)
+ PREPROCESS_SCALER = "scaler"  # StandardScaler
+ PREPROCESS_PCA = "pca"  # PCA
+ PREPROCESS_SELECT = "select"  # Feature selection
+
+ # All supported preprocessing methods.
+ ALL_PREPROCESS_METHODS = [
+     PREPROCESS_BASIC,
+     PREPROCESS_SCALER,
+     PREPROCESS_PCA,
+     PREPROCESS_SELECT,
+ ]
+
+ # TODO: performance/memory improvements for TSNE/UMAP on large datasets.
+
+
+ def compute_metrics(true_labels, predicted_labels):
+     """Compute clustering metrics.
+
+     Returns ARI (adjusted Rand index) and NMI (normalized mutual information).
+     """
+     ari = adjusted_rand_score(true_labels, predicted_labels)
+     nmi = normalized_mutual_info_score(true_labels, predicted_labels)
+     return ari, nmi
+
+
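+ # A small sanity check (hypothetical values): both ARI and NMI are invariant to
+ # a permutation of cluster ids, so a perfect-but-relabelled clustering scores 1.0.
+ #
+ #     ari, nmi = compute_metrics([0, 0, 1, 1], [1, 1, 0, 0])
+ #     # ari == 1.0 and nmi == 1.0
+
+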
+ def map_labels_with_hungarian(true_labels, predicted_labels):
+     """Map predicted cluster labels to true labels via Hungarian matching."""
+     # Ensure numpy arrays.
+     true_labels = np.asarray(true_labels)
+     predicted_labels = np.asarray(predicted_labels)
+
+     # Unique label values.
+     true_classes = np.unique(true_labels)
+     predicted_classes = np.unique(predicted_labels)
+
+     # Build cost matrix.
+     n_true = len(true_classes)
+     n_pred = len(predicted_classes)
+     max_size = max(n_true, n_pred)
+     cost_matrix = np.zeros((max_size, max_size))
+
+     # Fill cost matrix (negative overlap -> Hungarian solves min-cost).
+     for i, t_class in enumerate(true_classes):
+         for j, p_class in enumerate(predicted_classes):
+             # Overlap between (t_class, p_class).
+             match = np.sum((true_labels == t_class) & (predicted_labels == p_class))
+             cost_matrix[i, j] = -match
+
+     # Pad extra rows/cols with a large cost.
+     if n_true < max_size or n_pred < max_size:
+         large_val = abs(np.min(cost_matrix)) * 10
+         for i in range(n_true, max_size):
+             cost_matrix[i, :] = large_val
+         for j in range(n_pred, max_size):
+             cost_matrix[:, j] = large_val
+
+     # Hungarian matching.
+     row_ind, col_ind = linear_sum_assignment(cost_matrix)
+
+     # Build mapping.
+     label_mapping = {}
+     for row, col in zip(row_ind, col_ind):
+         if row < n_true and col < n_pred:  # Valid match
+             label_mapping[predicted_classes[col]] = true_classes[row]
+
+     # Handle unmatched predicted labels (if any).
+     unmatched_pred_labels = set(predicted_classes) - set(label_mapping.keys())
+     if unmatched_pred_labels:
+         # Find a best-effort mapping for unmatched labels.
+         for unmatched_label in unmatched_pred_labels:
+             # Overlap with each true class.
+             overlaps = []
+             for t_class in true_classes:
+                 overlap = np.sum(
+                     (predicted_labels == unmatched_label) & (true_labels == t_class)
+                 )
+                 overlaps.append((t_class, overlap))
+
+             # Assign to the most overlapping true class.
+             best_class = max(overlaps, key=lambda x: x[1])[0]
+             label_mapping[unmatched_label] = best_class
+
+     # Mapped label array.
+     mapped_labels = np.array([label_mapping[label] for label in predicted_labels])
+
+     return mapped_labels, label_mapping
+
+
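+ # A worked example (hypothetical data): cluster ids disagree with the true ids,
+ # but each predicted cluster overlaps exactly one true class, so Hungarian
+ # matching recovers the permutation.
+ #
+ #     true = np.array([0, 0, 1, 1, 2])
+ #     pred = np.array([2, 2, 0, 0, 1])
+ #     mapped, mapping = map_labels_with_hungarian(true, pred)
+ #     # mapping == {2: 0, 0: 1, 1: 2} and (mapped == true).all()
+
+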
+ def compute_metrics_with_classification_report(true_labels, predicted_labels):
+     """Compute metrics from sklearn's classification_report.
+
+     Returns: (accuracy, precision, recall, f1_score) using macro average.
+     """
+     report = classification_report(
+         true_labels, predicted_labels, output_dict=True, zero_division="warn"
+     )
+     accuracy = report.get("accuracy", 0.0)
+
+     macro_avg_report = report.get("macro avg", {})
+     precision = macro_avg_report.get("precision", 0.0)
+     recall = macro_avg_report.get("recall", 0.0)
+     f1_score = macro_avg_report.get("f1-score", 0.0)
+     return accuracy, precision, recall, f1_score
+
+
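+ # Typically called on Hungarian-mapped labels so cluster ids are comparable to
+ # class ids. A small sketch (hypothetical values):
+ #
+ #     acc, prec, rec, f1 = compute_metrics_with_classification_report(
+ #         [0, 0, 1, 1], [0, 1, 1, 1]
+ #     )
+ #     # acc == 0.75; prec/rec/f1 are macro-averaged over both classes.
+
+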
+ def load_location_data(df, dataset, graphs_number=None, specific_label_file=None):
+     """
+     Load location/label data and merge it into the input DataFrame.
+
+     Args:
+         df: Input DataFrame.
+         dataset: Dataset name.
+         graphs_number: Optional graph count suffix used in label file naming.
+         specific_label_file: Optional label filename or an absolute path.
+
+     Returns:
+         DataFrame: DataFrame with an added/updated 'location' column.
+     """
+     # Priority order: label column names from high to low.
+     priority_label_cols = [
+         "groundtruth_wzx",
+         "groundtruth",
+         "label",
+         "location",
+         "cluster",
+         "category",
+         "type",
+     ]
+
+     # Warn if required columns are missing; do not early-return.
+     if "cell" not in df.columns or "gene" not in df.columns:
+         print(
+             "Warning: DataFrame missing required columns ('cell' and/or 'gene'), merge may fail"
+         )
+
+     # Candidate base paths for label files.
+     base_paths = [
+         "../1_input/label",
+         "../1_input/label_annotation",
+         "../../GCN_CL/1_input/label",
+         "../../GCN_CL/1_input/label_annotation",
+         "./1_input/label",
+         "./1_input/label_annotation",
+     ]
+
+     # Try to locate a label file.
+     label_file = None
+
+     # Allow absolute paths.
+     # 1) If the user provided an absolute path and it exists, use it directly.
+     if (
+         specific_label_file
+         and os.path.isabs(specific_label_file)
+         and os.path.exists(specific_label_file)
+     ):
+         label_file = specific_label_file
+         print(f"Using absolute label file path: {label_file}")
+     # 2) If the user provided a relative filename, search under base_paths.
+     elif specific_label_file:
+         for base_path in base_paths:
+             path = f"{base_path}/{specific_label_file}"
+             if os.path.exists(path):
+                 label_file = path
+                 print(f"Using specified label file: {label_file}")
+                 break
+
+     # If no specific file was found, try common naming patterns.
+     if label_file is None:
+         for base_path in base_paths:
+             possible_files = [
+                 f"{base_path}/{dataset}_label.csv",
+                 f"{base_path}/{dataset}_labeled.csv",
+             ]
+             if graphs_number:
+                 possible_files.extend(
+                     [
+                         f"{base_path}/{dataset}_graph{graphs_number}_label.csv",
+                         f"{base_path}/{dataset}_graph{graphs_number}_labeled.csv",
+                     ]
+                 )
+
+             for file_path in possible_files:
+                 if os.path.exists(file_path):
+                     label_file = file_path
+                     print(f"Found label file: {label_file}")
+                     break
+
+             if label_file:
+                 break
+
+     # If a label file is found, load and process it.
+     if label_file:
+         try:
+             # Read label file.
+             label_df = pd.read_csv(label_file)
+
+             # Print label file columns for debugging.
+             print(f"Label file columns: {label_df.columns.tolist()}")
+
+             # Find label columns by priority order.
+             found_label_cols = []
+             for col in priority_label_cols:
+                 if col in label_df.columns:
+                     found_label_cols.append(col)
+                     print(
+                         f"Found label column: '{col}' (priority {priority_label_cols.index(col) + 1})"
+                     )
+
+             if found_label_cols:
+                 primary_label_col = found_label_cols[0]  # highest-priority column
+                 print(f"Using '{primary_label_col}' as primary label column")
+             else:
+                 print("Warning: No recognized label column found in label file")
+                 primary_label_col = None
+
+             # Robustly infer the label file format.
+             # Long format: has 'cell' and 'gene' columns (+ one or more label columns).
+             # Wide format: one row per cell, many gene columns.
+
+             if "cell" in label_df.columns and "gene" in label_df.columns:
+                 print("Detected long format label file")
+
+                 # Keep only required columns.
+                 keep_cols = ["cell", "gene"] + found_label_cols
+                 # Only keep columns that exist.
+                 label_df = label_df[
+                     [col for col in keep_cols if col in label_df.columns]
+                 ]
+
+                 # Merge into the input dataframe.
+                 try:
+                     merged_df = df.merge(label_df, on=["gene", "cell"], how="left")
+                     print("Merged on both 'gene' and 'cell' columns")
+                 except Exception as e:
+                     print(f"Error in gene+cell merge: {e}, trying cell-only merge")
+                     try:
+                         # Fall back to merging only on 'cell'.
+                         merged_df = df.merge(label_df, on=["cell"], how="left")
+                         print("Merged on 'cell' column only")
+                     except Exception as e2:
+                         print(f"All merge attempts failed: {e2}")
+                         # Fall back to the default location.
+                         df["location"] = "unknown"
+                         return df
+
+                 # Normalize the primary label column to 'location' (if needed).
+                 if primary_label_col and primary_label_col != "location":
+                     if primary_label_col in merged_df.columns:
+                         print(f"Renaming '{primary_label_col}' to 'location'")
+                         merged_df["location"] = merged_df[primary_label_col]
+
+                 # Ensure 'location' exists.
+                 if "location" not in merged_df.columns:
+                     print("Creating 'location' column from highest priority label")
+                     if found_label_cols:
+                         merged_df["location"] = merged_df[found_label_cols[0]]
+                     else:
+                         print("Warning: No label columns found, using 'unknown'")
+                         merged_df["location"] = "unknown"
+
+                 merged_df["location"] = merged_df["location"].fillna("unknown")
+                 print(f"Merged data, final shape: {merged_df.shape}")
+
+                 # Print label distribution.
+                 label_counts = merged_df["location"].value_counts()
+                 print(f"Label distribution (top 10): \n{label_counts.head(10)}")
+
+                 return merged_df
+
+             # Wide format: each row is a cell, columns are genes (and possibly metadata).
+             elif "cell" in label_df.columns and len(label_df.columns) > 2:
+                 print(
+                     f"Detected wide format label file, columns: {label_df.columns.tolist()}"
+                 )
+
+                 # Confirm it looks like a typical wide format with gene columns.
+                 non_gene_cols = ["cell", "Cell"] + priority_label_cols
+                 potential_gene_cols = [
+                     col for col in label_df.columns if col not in non_gene_cols
+                 ]
+
+                 # Heuristic: treat as wide format if there are many gene columns.
+                 if len(potential_gene_cols) > 5:  # assume at least 5 gene columns
+                     print(
+                         f"Converting wide format to long format with {len(potential_gene_cols)} gene columns"
+                     )
+
+                     # Preserve non-gene columns for a later merge.
+                     cell_info_cols = []
+                     if primary_label_col:
+                         cell_info_cols = [primary_label_col]
+                     else:
+                         # If no label column is found, keep all non-gene columns.
+                         cell_info_cols = [
+                             col
+                             for col in non_gene_cols
+                             if col in label_df.columns and col != "cell"
+                         ]
+
+                     cell_info_df = None
+                     if cell_info_cols:
+                         cell_info_df = label_df[["cell"] + cell_info_cols].copy()
+                         print(f"Preserved cell information columns: {cell_info_cols}")
+
+                     # Convert wide -> long.
+                     long_df = pd.melt(
+                         label_df,
+                         id_vars="cell",
+                         value_vars=potential_gene_cols,
+                         var_name="gene",
+                         value_name="gene_value",
+                     )
+
+                     # Merge into the input dataframe.
+                     try:
+                         merged_df = df.merge(long_df, on=["gene", "cell"], how="left")
+                         print("Merged gene-cell data")
+                     except Exception as e:
+                         print(f"Error in gene-cell merge for wide format: {e}")
+                         # Fall back to the default location.
+                         df["location"] = "unknown"
+                         return df
+
+                     # Merge preserved cell info columns (if any).
+                     if cell_info_df is not None:
+                         try:
+                             merged_df = merged_df.merge(
+                                 cell_info_df, on="cell", how="left"
+                             )
+                             print("Added cell information columns")
+                         except Exception as e:
+                             print(f"Error merging cell info: {e}")
+
+                     # Set the 'location' column.
+                     if primary_label_col and primary_label_col in merged_df.columns:
+                         merged_df["location"] = merged_df[primary_label_col]
+                         print(f"Set 'location' from '{primary_label_col}'")
+                     else:
+                         # Fall back to 'gene_value'.
+                         merged_df["location"] = merged_df["gene_value"]
+                         print("Set 'location' from gene_value column")
+
+                     # Ensure 'location' has values.
+                     merged_df["location"] = merged_df["location"].fillna("unknown")
+                     print(
+                         f"Merged data from wide format, final shape: {merged_df.shape}"
+                     )
+
+                     # Print label distribution.
+                     label_counts = merged_df["location"].value_counts()
+                     print(f"Label distribution (top 10): \n{label_counts.head(10)}")
+
+                     return merged_df
+                 else:
+                     print(
+                         "Label file has 'cell' column but doesn't appear to be a typical wide format"
+                     )
+
+                     # Try merging based on 'cell'.
+                     try:
+                         # Use the discovered label columns.
+                         useful_cols = ["cell"]
+                         if found_label_cols:
+                             useful_cols.extend(found_label_cols)
+                         else:
+                             # If no label columns are found, keep all likely label cols.
+                             useful_cols.extend(
+                                 [
+                                     col
+                                     for col in label_df.columns
+                                     if col in priority_label_cols
+                                     or col not in non_gene_cols
+                                 ]
+                             )
+
+                         # Only keep existing columns.
+                         label_subset = label_df[
+                             [col for col in useful_cols if col in label_df.columns]
+                         ].copy()
+
+                         # Merge.
+                         merged_df = df.merge(label_subset, on="cell", how="left")
+
+                         # Set the 'location' column.
+                         if primary_label_col and primary_label_col in merged_df.columns:
+                             merged_df["location"] = merged_df[primary_label_col]
+                             print(f"Set 'location' from '{primary_label_col}'")
+                         else:
+                             # If no primary label column is available, use the default.
+                             print("No primary label column available")
+                             merged_df["location"] = "unknown"
+
+                         merged_df["location"] = merged_df["location"].fillna("unknown")
+
+                         print(
+                             f"Merged data using cell-based join, final shape: {merged_df.shape}"
+                         )
+
+                         # Print label distribution.
+                         label_counts = merged_df["location"].value_counts()
+                         print(f"Label distribution (top 10): \n{label_counts.head(10)}")
+
+                         return merged_df
+                     except Exception as e:
+                         print(f"Error in cell-based merge: {e}")
+
+             # Unrecognized format.
+             print("Unrecognized label file format, using default 'unknown' location")
+             df["location"] = "unknown"
+             return df
+
+         except Exception as e:
+             print(f"Error processing label file {label_file}: {e}")
+             import traceback
+
+             traceback.print_exc()
+
+     # If all attempts fail, add a default 'location' column.
+     print(
+         "Could not find or process a suitable label file, using default 'unknown' location"
+     )
+     if "location" not in df.columns:
+         df["location"] = "unknown"
+
+     return df
+
+
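+ # A usage sketch (hypothetical dataset name and files): the function searches
+ # the base paths above for e.g. '<dataset>_label.csv' and merges it on cell/gene.
+ #
+ #     df = pd.DataFrame({"cell": ["c1", "c2"], "gene": ["g1", "g2"], "f0": [0.1, 0.2]})
+ #     df = load_location_data(df, "demo_dataset")
+ #     # df["location"] now holds the merged labels, or "unknown" if no file was found.
+
+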
+ def apply_clustering(features, n_clusters, clustering_methods=None):
+     """
+     Apply multiple clustering methods.
+
+     Args:
+         features: Feature matrix.
+         n_clusters: Number of clusters.
+         clustering_methods: Optional list of method names to run.
+
+     Returns:
+         Dict of method name -> cluster labels.
+     """
+     # Build each method lazily so that only the requested ones are actually fit
+     # (SpectralClustering in particular is expensive on large feature matrices).
+     all_methods = {
+         "KMeans": lambda: KMeans(n_clusters=n_clusters, random_state=2025)
+         .fit(features)
+         .labels_,
+         "Agglomerative": lambda: AgglomerativeClustering(
+             n_clusters=n_clusters
+         ).fit_predict(features),
+         "SpectralClustering": lambda: SpectralClustering(
+             n_clusters=n_clusters, random_state=2025, affinity="nearest_neighbors"
+         ).fit_predict(features),
+         "GaussianMixture": lambda: GaussianMixture(
+             n_components=n_clusters, random_state=2025
+         )
+         .fit(features)
+         .predict(features),
+     }
+
+     # If methods are specified, run only those.
+     if clustering_methods is not None:
+         return {k: fit() for k, fit in all_methods.items() if k in clustering_methods}
+
+     # Otherwise run all methods.
+     return {k: fit() for k, fit in all_methods.items()}
+
+
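+ # A minimal sketch (hypothetical data): restricting `clustering_methods` runs
+ # only the requested algorithms.
+ #
+ #     X = np.random.RandomState(0).rand(100, 8)
+ #     labels = apply_clustering(X, n_clusters=3, clustering_methods=["KMeans"])
+ #     # labels == {"KMeans": array of 100 cluster ids in {0, 1, 2}}
+
+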
+ def evaluate_clustering(
+     true_labels, clustering_methods, df, save_path, num_epochs, lr, suffix=""
+ ):
+     """
+     Evaluate clustering results.
+
+     Args:
+         true_labels: Ground-truth labels.
+         clustering_methods: Dict of method name -> predicted labels.
+         df: DataFrame.
+         save_path: Output directory.
+         num_epochs: Current epoch.
+         lr: Learning rate.
+         suffix: Filename suffix.
+
+     Returns:
+         metrics: Dict of evaluation metrics.
+     """
+     # Check label distribution.
+     unique_labels = np.unique(true_labels)
+     unique_labels_count = len(unique_labels)
+     print("\nLabel distribution:")
+     print(f"Unique label count: {unique_labels_count}")
+     print(
+         f"Label values: {unique_labels[:10]}{'...' if len(unique_labels) > 10 else ''}"
+     )
+
+     # If all samples share one label, ARI/NMI become 0 by definition.
+     if unique_labels_count == 1:
+         print(f"Warning: all samples share the same label '{unique_labels[0]}'")
+         print(
+             "This makes ARI/NMI equal to 0 and can make accuracy-like metrics appear high, "
+             "because there is no class variation to compare."
+         )
+         print("Please check the label file or load_location_data() processing.\n")
+     elif unique_labels_count == 0:
+         print("Error: no labels available for clustering evaluation")
+         return {"error": "insufficient_labels"}
+
+     metrics = {}
+     for method, predicted_labels in clustering_methods.items():
+         mapped_labels, label_mapping = map_labels_with_hungarian(
+             true_labels, predicted_labels
+         )
+         accuracy, precision, recall, f1_score = (
+             compute_metrics_with_classification_report(true_labels, mapped_labels)
+         )
+         print(f"{method} Label Mapping: {label_mapping}")
+
+         ari, nmi = compute_metrics(true_labels, predicted_labels)
+         metrics[method] = {
+             "ARI": ari,
+             "NMI": nmi,
+             "Accuracy": accuracy,
+             "Precision": precision,
+             "Recall": recall,
+             "F1-Score": f1_score,
+             "Label Mapping": label_mapping,
+         }
+         df[method + "_cluster"] = predicted_labels
+
+     print(f"Epoch {num_epochs} lr={lr} suffix={suffix} clustering metrics:")
+     for method, values in metrics.items():
+         print(
+             f"{method}: ARI = {values['ARI']:.4f}, NMI = {values['NMI']:.4f}, "
+             f"Accuracy = {values['Accuracy']:.4f}, Precision = {values['Precision']:.4f}, "
+             f"Recall = {values['Recall']:.4f}, F1-Score = {values['F1-Score']:.4f}"
+         )
+
+     # Save clustering outputs and metrics; the suffix only changes the filenames.
+     suffix_part = f"_{suffix}" if suffix else ""
+     df.to_csv(
+         f"{save_path}/epoch{num_epochs}_lr{lr}_clusters{suffix_part}.csv", index=False
+     )
+     with open(
+         f"{save_path}/epoch{num_epochs}_lr{lr}_metrics{suffix_part}.txt", "w"
+     ) as f:
+         for method, values in metrics.items():
+             f.write(
+                 f"{method}:\nARI = {values['ARI']:.4f}, NMI = {values['NMI']:.4f},\n "
+                 f"Accuracy = {values['Accuracy']:.4f}, Precision = {values['Precision']:.4f},\n"
+                 f"Recall = {values['Recall']:.4f}, F1-Score = {values['F1-Score']:.4f}\n"
+             )
+             f.write(f"{method} Label Mapping: {values['Label Mapping']}\n")
+             f.write("\n")
+
+     return metrics
+
+
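+ # The returned dict is keyed by clustering method (hypothetical values shown):
+ #
+ #     preds = {"KMeans": kmeans_labels}
+ #     metrics = evaluate_clustering(true, preds, df, "./out", 100, 0.001)
+ #     metrics["KMeans"]["ARI"]            # e.g. 0.7321
+ #     metrics["KMeans"]["Label Mapping"]  # predicted cluster id -> true label
+
+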
+ def visualize_clustering(
+     reduced_features_tsne,
+     reduced_features_umap,
+     true_labels,
+     clustering_methods,
+     num_epochs,
+     lr,
+     size=15,
+ ):
+     """
+     Visualize clustering results.
+
+     Args:
+         reduced_features_tsne: t-SNE 2D features.
+         reduced_features_umap: UMAP 2D features.
+         true_labels: Ground-truth labels.
+         clustering_methods: Dict of method name -> predicted labels.
+         num_epochs: Current epoch.
+         lr: Learning rate.
+         size: Point size.
+
+     Returns:
+         fig: Matplotlib figure.
+     """
+     # squeeze=False keeps `axes` 2D even when there are no clustering methods.
+     fig, axes = plt.subplots(
+         len(clustering_methods) + 1,
+         2,
+         figsize=(16, 4 * (len(clustering_methods) + 1)),
+         squeeze=False,
+     )
+     fig.suptitle(f"Epoch {num_epochs} lr={lr}", fontsize=16)
+
+     # True labels.
+     sns.scatterplot(
+         x=reduced_features_tsne[:, 0],
+         y=reduced_features_tsne[:, 1],
+         hue=true_labels,
+         palette="Set1",
+         ax=axes[0, 0],
+         s=size,
+         legend="full",
+     )
+     axes[0, 0].set_title("t-SNE with True Labels")
+     axes[0, 0].legend(loc="upper left", bbox_to_anchor=(1, 1))
+
+     sns.scatterplot(
+         x=reduced_features_umap[:, 0],
+         y=reduced_features_umap[:, 1],
+         hue=true_labels,
+         palette="Set1",
+         ax=axes[0, 1],
+         s=size,
+         legend="full",
+     )
+     axes[0, 1].set_title("UMAP with True Labels")
+     axes[0, 1].legend(loc="upper left", bbox_to_anchor=(1, 1))
+
+     # Predicted clusters.
+     for i, (method, predicted_labels) in enumerate(clustering_methods.items(), start=1):
+         sns.scatterplot(
+             x=reduced_features_tsne[:, 0],
+             y=reduced_features_tsne[:, 1],
+             hue=predicted_labels,
+             palette="Set2",
+             ax=axes[i, 0],
+             s=size,
+             legend="full",
+         )
+         axes[i, 0].set_title(f"t-SNE with {method} Clusters")
+         axes[i, 0].legend(loc="upper left", bbox_to_anchor=(1, 1))
+
+         sns.scatterplot(
+             x=reduced_features_umap[:, 0],
+             y=reduced_features_umap[:, 1],
+             hue=predicted_labels,
+             palette="Set3",
+             ax=axes[i, 1],
+             s=size,
+             legend="full",
+         )
+         axes[i, 1].set_title(f"UMAP with {method} Clusters")
+         axes[i, 1].legend(loc="upper left", bbox_to_anchor=(1, 1))
+
+     plt.tight_layout()
+     return fig
+
+
+ def preprocess_features(df, preprocess_method="basic", true_labels=None):
+     """
+     Preprocess features according to the given method.
+
+     Args:
+         df: DataFrame.
+         preprocess_method: One of 'basic', 'scaler', 'pca', 'select'.
+         true_labels: Labels used by feature selection.
+
+     Returns:
+         features: Preprocessed feature matrix.
+     """
+     # Exclude known non-feature columns.
+     non_feature_cols = [
+         "cell",
+         "gene",
+         "graph_id",
+         "original_graph_id",
+         "augmented_graph_id",
+         "location",
+         "groundtruth",
+         "groundtruth_wzx",
+         "label",
+         "cluster",
+         "category",
+         "type",
+         "SCT_cluster",
+         "cell_type",
+         "celltype",  # additional common non-feature column name
+         # Also exclude columns generated by clustering methods.
+         "KMeans_cluster",
+         "Agglomerative_cluster",
+         "SpectralClustering_cluster",
+         "GaussianMixture_cluster",
+     ]
+
+     # Candidate feature columns.
+     candidate_feature_names = [
+         col
+         for col in df.columns
+         if col not in non_feature_cols and not str(col).endswith("_cluster")
+     ]
+
+     if not candidate_feature_names:
+         print(
+             f"Warning: No candidate feature columns found after excluding non_feature_cols for method '{preprocess_method}'. Original columns: {df.columns.tolist()}"
+         )
+         return np.array([])
+
+     # Work on a copy to safely attempt numeric conversion.
+     df_candidate_features = df[candidate_feature_names].copy()
+
+     actually_numeric_cols = []
+     for col_name in df_candidate_features.columns:
+         try:
+             # Coerce the column to numeric; non-convertible values become NaN.
+             converted_series = pd.to_numeric(
+                 df_candidate_features[col_name], errors="coerce"
+             )
+
+             # Keep numeric columns that are not all-NaN after coercion.
+             if (
+                 pd.api.types.is_numeric_dtype(converted_series)
+                 and not converted_series.isnull().all()
+             ):
+                 # Heuristic: if coercion wipes out most values, it's likely not a true numeric feature.
+                 original_non_na_count = (
+                     df_candidate_features[col_name].dropna().shape[0]
+                 )
+                 converted_non_na_count = converted_series.dropna().shape[0]
+                 if original_non_na_count > 0 and (
+                     converted_non_na_count / original_non_na_count < 0.5
+                 ):
+                     print(
+                         f"Warning: Column '{col_name}' lost significant data after numeric coercion ({original_non_na_count} -> {converted_non_na_count} non-NaNs). Excluding."
+                     )
+                 else:
+                     actually_numeric_cols.append(col_name)
+                     # Replace with the numeric series to help downstream .values.
+                     df_candidate_features[col_name] = converted_series
+             else:
+                 print(
+                     f"Warning: Column '{col_name}' excluded. Not purely numeric or became all NaNs after coercion."
+                 )
+         except Exception as e:
+             print(
+                 f"Warning: Column '{col_name}' encountered an error during numeric conversion ('{e}') and will be excluded."
+             )
+
+     if not actually_numeric_cols:
+         print(
+             f"Error: No numeric feature columns remaining after filtering for method '{preprocess_method}'."
+         )
+         return np.array([])
+
+     # Extract features from the converted dataframe.
+     features = df_candidate_features[actually_numeric_cols].values
+
+     # Ensure float dtype.
+     if features.dtype == np.object_:
+         print(
+             f"Warning: Features array for method '{preprocess_method}' has dtype 'object' before scaling. Attempting astype(float)."
+         )
+         try:
+             features = features.astype(float)
+         except ValueError as e_astype:
+             error_msg = (
+                 f"Failed to convert object-type features to float for method '{preprocess_method}'. "
+                 f"Problematic data likely still exists. Original error: {e_astype}"
+             )
+             print(f"Error: {error_msg}")
+             # If this happens, upstream data likely contains non-numeric features.
+             raise ValueError(error_msg) from e_astype
+
+     # Apply preprocessing.
+     if preprocess_method == "basic":
+         pass  # no-op
+     elif preprocess_method == "scaler":
+         if features.size == 0:
+             print(
+                 f"Warning: No features to scale for method '{preprocess_method}'. Skipping scaling."
+             )
+             return features  # avoid StandardScaler errors
+         scaler = StandardScaler()
+         features = scaler.fit_transform(features)
+     elif preprocess_method == "pca":
+         if features.size == 0:
+             print(
+                 f"Warning: No features for PCA for method '{preprocess_method}'. Skipping PCA."
+             )
+             return features
+         scaler = StandardScaler()
+         features = scaler.fit_transform(features)
+         pca = PCA(n_components=0.95)  # keep 95% variance
+         features = pca.fit_transform(features)
+         print(f"  PCA applied. Features shape after PCA: {features.shape}")
+     elif preprocess_method == "select":
+         if features.size == 0:
+             print(
+                 f"Warning: No features for selection for method '{preprocess_method}'. Skipping selection."
+             )
+             return features
+         scaler = StandardScaler()
+         features = scaler.fit_transform(features)
+         if true_labels is not None:
+             selector = SelectKBest(score_func=f_classif, k=min(50, features.shape[1]))
+             features = selector.fit_transform(features, true_labels)
+     else:
+         raise ValueError(f"Unsupported preprocess method: {preprocess_method}")
+
+     return features
+
+
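+ # A quick sketch (hypothetical frame): non-numeric and id-like columns are
+ # dropped automatically, and 'pca' keeps the components explaining 95% variance.
+ #
+ #     frame = pd.DataFrame(np.random.rand(60, 20))
+ #     frame["cell"], frame["gene"] = "c", "g"
+ #     X = preprocess_features(frame, "pca")
+ #     # X.shape[1] <= 20, chosen to retain 95% of the variance.
+
+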
+ def evaluate_and_visualize(
+     dataset,
+     df,
+     save_path,
+     num_epochs,
+     lr,
+     n_clusters=8,
+     size=15,
+     vis_methods=None,
+     visualize=True,
+     graphs_number=None,
+     clustering_methods=None,
+     reduce_dims=True,
+     specific_label_file=None,
+     tsne_perplexity: float = 30.0,
+     umap_n_neighbors: int = 15,
+     umap_min_dist: float = 0.2,
+ ):
+     """
+     Evaluate clustering results and optionally create visualizations.
+
+     Args:
+         dataset: Dataset name.
+         df: DataFrame containing embeddings.
+         save_path: Output directory.
+         num_epochs: Current epoch.
+         lr: Learning rate.
+         n_clusters: Number of clusters.
+         size: Point size.
+         vis_methods: List of visualization/preprocess methods.
+         visualize: Whether to generate visualization images.
+         graphs_number: Optional graph count suffix used in label file naming.
+         clustering_methods: Optional list of clustering method names.
+         reduce_dims: Whether to run dimensionality reduction.
+         specific_label_file: Optional label filename or absolute path.
+         tsne_perplexity: t-SNE perplexity.
+         umap_n_neighbors: UMAP n_neighbors.
+         umap_min_dist: UMAP min_dist.
+
+     Returns:
+         Tuple[Dict, Dict]: (metrics for all methods, figures for all methods)
+     """
+     # Timing.
+     start_time = time.time()
+
+     # Normalize visualization methods.
+     if vis_methods is None:
+         # Default: run all.
+         vis_methods = ALL_PREPROCESS_METHODS
+     elif isinstance(vis_methods, str):
+         # Normalize a single string into a list.
+         vis_methods = [vis_methods]
+
+     # Validate visualization methods.
+     for method in vis_methods:
+         if method not in ALL_PREPROCESS_METHODS:
+             raise ValueError(
+                 f"Unsupported visualization method: {method}. Available: {ALL_PREPROCESS_METHODS}"
+             )
+
+     # Load location/label data.
+     df = load_location_data(df, dataset, graphs_number, specific_label_file)
+
+     # Extract labels.
+     true_labels = df["location"].astype(str).values
+
+     # Initialize result dicts.
+     all_metrics = {}
+     all_figures = {}
+
+     # If no methods are requested, return empty results.
+     if not vis_methods:
+         print("No visualization methods specified. Returning empty results.")
+         return all_metrics, all_figures
+
+     # Track the total method count for progress printing.
+     total_methods = len(vis_methods)
+
+     # Process each method.
+     for i, method in enumerate(vis_methods, 1):
+         method_start = time.time()
+         print(f"\n[{i}/{total_methods}] Start processing method: {method}")
+
+         # Preprocess features.
+         features = preprocess_features(df, method, true_labels)
+
+         # Dimensionality reduction (only when requested).
+         reduced_features_tsne = None
+         reduced_features_umap = None
+
+         if visualize and reduce_dims:
+             dim_start = time.time()
+             print("  Running dimensionality reduction...")
+             try:
+                 tsne = TSNE(
+                     n_components=2,
+                     random_state=42,
+                     perplexity=min(tsne_perplexity, len(df) - 1),
+                 )
+                 reduced_features_tsne = tsne.fit_transform(features)
+             except Exception as e:
+                 print(f"  t-SNE reduction failed: {e}")
+                 reduced_features_tsne = np.zeros((len(df), 2))  # blank fallback
+
+             try:
+                 umap_model = _lazy_import_umap().UMAP(
+                     n_components=2,
+                     random_state=42,
+                     n_neighbors=min(umap_n_neighbors, len(df) - 1),
+                     min_dist=umap_min_dist,
+                 )
+                 reduced_features_umap = umap_model.fit_transform(features)
+             except Exception as e:
+                 print(f"  UMAP reduction failed: {e}")
+                 reduced_features_umap = np.zeros((len(df), 2))  # blank fallback
+
+             dim_time = time.time() - dim_start
+             print(f"  Dimensionality reduction done in {dim_time:.2f}s")
+
+         # Clustering.
+         cluster_start = time.time()
+         clustering_methods_dict = apply_clustering(
+             features, n_clusters, clustering_methods
+         )
+         cluster_time = time.time() - cluster_start
+         print(f"  Clustering done in {cluster_time:.2f}s")
+
+         # Evaluation.
+         eval_start = time.time()
+         metrics = evaluate_clustering(
+             true_labels,
+             clustering_methods_dict,
+             df,
+             save_path,
+             num_epochs,
+             lr,
+             suffix=method if method != PREPROCESS_BASIC else "",
+         )
+         all_metrics[method] = metrics
+         eval_time = time.time() - eval_start
+         print(f"  Evaluation done in {eval_time:.2f}s")
+
+         # Visualization.
+         if visualize and reduce_dims:
+             vis_start = time.time()
+             print("  Creating visualization...")
+             fig = visualize_clustering(
+                 reduced_features_tsne,
+                 reduced_features_umap,
+                 true_labels,
+                 clustering_methods_dict,
+                 num_epochs,
+                 lr,
+                 size,
+             )
+
+             fig_suffix = "" if method == PREPROCESS_BASIC else f"_{method}"
+             fig_path = (
+                 f"{save_path}/epoch{num_epochs}_lr{lr}_visualization{fig_suffix}.png"
+             )
+             fig.savefig(fig_path)
+             all_figures[method] = fig
+
+             vis_time = time.time() - vis_start
+             print(f"  Visualization saved to {fig_path} in {vis_time:.2f}s")
+
+         method_time = time.time() - method_start
+         print(f"Method {method} done in {method_time:.2f}s")
+
+     total_time = time.time() - start_time
+     print(f"\nAll methods completed in {total_time:.2f}s")
+
+     return all_metrics, all_figures
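+
+
+ # A typical call (hypothetical paths/arguments): run only the 'basic' and 'pca'
+ # pipelines, restricted to KMeans, and collect metrics plus saved figures.
+ #
+ #     metrics, figures = evaluate_and_visualize(
+ #         dataset="demo_dataset",
+ #         df=embedding_df,  # must contain 'cell' and 'gene' columns
+ #         save_path="./out",
+ #         num_epochs=100,
+ #         lr=0.001,
+ #         n_clusters=8,
+ #         vis_methods=["basic", "pca"],
+ #         clustering_methods=["KMeans"],
+ #     )
+ #     metrics["pca"]["KMeans"]["ARI"]  # e.g. the ARI for KMeans on PCA features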
+
+
+ def plot_scatter_mutilbatch_tensorboard(
+     dataset,
+     df,
+     save_path,
+     num_epochs,
+     lr,
+     n_clusters,
+     size=15,
+     visualize=True,
+     graphs_number=None,
+     clustering_methods=None,
+ ):
+     """
+     Basic visualization using raw features.
+
+     Deprecated: prefer evaluate_and_visualize(vis_methods=['basic']).
+
+     Args:
+         dataset: Dataset name.
+         df: DataFrame.
+         save_path: Output directory.
+         num_epochs: Current epoch.
+         lr: Learning rate.
+         n_clusters: Number of clusters.
+         size: Point size.
+         visualize: Whether to create visualization images.
+         graphs_number: Optional graph count suffix used in label file naming.
+         clustering_methods: Optional list of clustering method names.
+
+     Returns:
+         tuple: (metrics dict, figure if visualize=True)
+     """
+     # Load location labels.
+     df = load_location_data(df, dataset, graphs_number)
+
+     # Extract features and labels.
+     features = df.drop(columns=["cell", "gene", "location"]).values
+     true_labels = df["location"].astype(str).values
+
+     # Dimensionality reduction (only when visualization is requested).
+     reduced_features_tsne = None
+     reduced_features_umap = None
+     if visualize:
+         tsne = TSNE(n_components=2, random_state=42, perplexity=30)
+         reduced_features_tsne = tsne.fit_transform(features)
+
+         umap_model = _lazy_import_umap().UMAP(
+             n_components=2, random_state=42, n_neighbors=15, min_dist=0.2
+         )
+         reduced_features_umap = umap_model.fit_transform(features)
+
+     # Clustering.
+     clustering_methods_dict = apply_clustering(features, n_clusters, clustering_methods)
+
+     # Evaluation.
+     metrics = evaluate_clustering(
+         true_labels, clustering_methods_dict, df, save_path, num_epochs, lr
+     )
+
+     # Visualization.
+     if visualize:
+         fig = visualize_clustering(
+             reduced_features_tsne,
+             reduced_features_umap,
+             true_labels,
+             clustering_methods_dict,
+             num_epochs,
+             lr,
+             size,
+         )
+         fig_path = f"{save_path}/epoch{num_epochs}_lr{lr}_visualization.png"
+         fig.savefig(fig_path)
+         return metrics, fig
+     else:
+         return metrics, None
+
+
+ def plot_scatter_mutilbatch_scaler_tensorboard(
+     dataset,
+     df,
+     save_path,
+     num_epochs,
+     lr,
+     n_clusters,
+     size=15,
+     visualize=True,
+     graphs_number=None,
+     clustering_methods=None,
+ ):
+     """
+     Visualization using standardized features.
+
+     Deprecated: prefer evaluate_and_visualize(vis_methods=['scaler']).
+
+     Args:
+         Same as plot_scatter_mutilbatch_tensorboard.
+     """
+     # Load location labels.
+     df = load_location_data(df, dataset, graphs_number)
+
+     # Extract features and labels.
+     features = df.drop(columns=["cell", "gene", "location"]).values
+     scaler = StandardScaler()
+     features = scaler.fit_transform(features)
+     true_labels = df["location"].astype(str).values
+
+     # Dimensionality reduction (only when visualization is requested).
+     reduced_features_tsne = None
+     reduced_features_umap = None
+     if visualize:
+         tsne = TSNE(n_components=2, random_state=42, perplexity=30)
+         reduced_features_tsne = tsne.fit_transform(features)
+
+         umap_model = _lazy_import_umap().UMAP(
+             n_components=2, random_state=42, n_neighbors=15, min_dist=0.2
+         )
+         reduced_features_umap = umap_model.fit_transform(features)
+
+     # Clustering.
+     clustering_methods_dict = apply_clustering(features, n_clusters, clustering_methods)
+
+     # Evaluation.
+     metrics = evaluate_clustering(
+         true_labels, clustering_methods_dict, df, save_path, num_epochs, lr, "scaler"
+     )
+
+     # Visualization.
+     if visualize:
+         fig = visualize_clustering(
+             reduced_features_tsne,
+             reduced_features_umap,
+             true_labels,
+             clustering_methods_dict,
+             num_epochs,
+             lr,
+             size,
+         )
+         fig_path = f"{save_path}/epoch{num_epochs}_lr{lr}_visualization_scaler.png"
+         fig.savefig(fig_path)
+         return metrics, fig
+     else:
+         return metrics, None
+
+
+ def plot_scatter_mutilbatch_pca_tensorboard(
+     dataset,
+     df,
+     save_path,
+     num_epochs,
+     lr,
+     n_clusters,
+     size=15,
+     visualize=True,
+     graphs_number=None,
+     clustering_methods=None,
+ ):
+     """
+     Visualization with PCA-reduced features.
+
+     Deprecated: prefer evaluate_and_visualize(vis_methods=['pca']).
+
+     Args:
+         Same as plot_scatter_mutilbatch_tensorboard.
+     """
+     # Load location labels.
+     df = load_location_data(df, dataset, graphs_number)
+
+     # Extract features and labels.
+     features = df.drop(columns=["cell", "gene", "location"]).values
+     scaler = StandardScaler()
+     features = scaler.fit_transform(features)
+     true_labels = df["location"].astype(str).values
+
+     # PCA.
+     pca = PCA(n_components=0.95)  # keep 95% variance
+     features = pca.fit_transform(features)
+
+     # Dimensionality reduction for plotting (only when requested).
+     reduced_features_tsne = None
+     reduced_features_umap = None
+     if visualize:
+         tsne = TSNE(n_components=2, random_state=42, perplexity=30)
+         reduced_features_tsne = tsne.fit_transform(features)
+
+         umap_model = _lazy_import_umap().UMAP(
+             n_components=2, random_state=42, n_neighbors=15, min_dist=0.2
+         )
+         reduced_features_umap = umap_model.fit_transform(features)
+
+     # Clustering.
+     clustering_methods_dict = apply_clustering(features, n_clusters, clustering_methods)
+
+     # Evaluation.
+     metrics = evaluate_clustering(
+         true_labels, clustering_methods_dict, df, save_path, num_epochs, lr, "pca"
+     )
+
+     # Visualization.
+     if visualize:
+         fig = visualize_clustering(
+             reduced_features_tsne,
+             reduced_features_umap,
+             true_labels,
+             clustering_methods_dict,
+             num_epochs,
+             lr,
+             size,
+         )
+         fig_path = f"{save_path}/epoch{num_epochs}_lr{lr}_visualization_pca.png"
+         fig.savefig(fig_path)
+         return metrics, fig
+     else:
+         return metrics, None
+
+
+ def plot_scatter_mutilbatch_select_tensorboard(
+     dataset,
+     df,
+     save_path,
+     num_epochs,
+     lr,
+     n_clusters,
+     size=15,
+     visualize=True,
+     graphs_number=None,
+     clustering_methods=None,
+ ):
+     """
+     Visualization using feature selection.
+
+     Deprecated: prefer evaluate_and_visualize(vis_methods=['select']).
+
+     Args:
+         dataset: Dataset name.
+         df: DataFrame.
+         save_path: Output directory.
+         num_epochs: Current epoch.
+         lr: Learning rate.
+         n_clusters: Number of clusters.
+         size: Point size.
+         visualize: Whether to create visualization images.
+         graphs_number: Optional graph count suffix used in label file naming.
+         clustering_methods: Optional list of clustering method names.
+
+     Returns:
+         tuple: (metrics dict, figure if visualize=True)
+     """
+     # Load location labels.
+     df = load_location_data(df, dataset, graphs_number)
+
+     # Extract features and labels.
+     features = df.drop(columns=["cell", "gene", "location"]).values
+     scaler = StandardScaler()
+     features = scaler.fit_transform(features)
+     true_labels = df["location"].astype(str).values
+
+     # Feature selection (top-50 features, capped at the available feature count).
+     selector = SelectKBest(score_func=f_classif, k=min(50, features.shape[1]))
+     features = selector.fit_transform(features, true_labels)
+
+     # Dimensionality reduction for plotting (only when requested).
+     reduced_features_tsne = None
+     reduced_features_umap = None
+     if visualize:
+         tsne = TSNE(n_components=2, random_state=42, perplexity=30)
+         reduced_features_tsne = tsne.fit_transform(features)
+
+         umap_model = _lazy_import_umap().UMAP(
+             n_components=2, random_state=42, n_neighbors=15, min_dist=0.2
+         )
+         reduced_features_umap = umap_model.fit_transform(features)
+
+     # Clustering.
+     clustering_methods_dict = apply_clustering(features, n_clusters, clustering_methods)
+
+     # Evaluation.
+     metrics = evaluate_clustering(
+         true_labels, clustering_methods_dict, df, save_path, num_epochs, lr, "select"
+     )
+
+     # Visualization.
+     if visualize:
+         fig = visualize_clustering(
+             reduced_features_tsne,
+             reduced_features_umap,
+             true_labels,
+             clustering_methods_dict,
+             num_epochs,
+             lr,
+             size,
+         )
+         fig_path = f"{save_path}/epoch{num_epochs}_lr{lr}_visualization_select.png"
+         fig.savefig(fig_path)
+         return metrics, fig
+     else:
+         return metrics, None
+
+
+ def plot_scatter_nolabel(
+     dataset,
+     df,
+     save_path,
+     num_epochs,
+     lr,
+     n_clusters,
+     size=15,
+     visualize=True,
+     clustering_methods=None,
+ ):
+     """
+     Clustering and visualization without ground-truth labels.
+
+     Args:
+         dataset: Dataset name.
+         df: DataFrame.
+         save_path: Output directory.
+         num_epochs: Current epoch.
+         lr: Learning rate.
+         n_clusters: Number of clusters.
+         size: Point size.
+         visualize: Whether to create visualization images.
+         clustering_methods: Optional list of clustering method names.
+     """
+     df["location"] = "other"
+
+     features = df.drop(columns=["cell", "gene", "location"]).values
+     scaler = StandardScaler()
+     features = scaler.fit_transform(features)
+
+     # Cluster even if visualization is disabled.
+     clustering_methods_dict = apply_clustering(features, n_clusters, clustering_methods)
+
+     # Save clustering results.
+     for method, labels in clustering_methods_dict.items():
+         df[method + "_cluster"] = labels
+
+     df.to_csv(f"{save_path}/epoch{num_epochs}_lr{lr}_clusters.csv", index=False)
+
+     # Visualization (only when requested).
+     if visualize:
+         # Dimensionality reduction.
+         tsne = TSNE(n_components=2, random_state=42, perplexity=30)
+         reduced_features_tsne = tsne.fit_transform(features)
+
+         umap_model = _lazy_import_umap().UMAP(
+             n_components=2, random_state=42, n_neighbors=15, min_dist=0.2
+         )
+         reduced_features_umap = umap_model.fit_transform(features)
+
+         # squeeze=False keeps `axes` 2D even when only one method was run.
+         fig, axes = plt.subplots(
+             len(clustering_methods_dict),
+             2,
+             figsize=(16, 4 * len(clustering_methods_dict)),
+             squeeze=False,
+         )
+         fig.suptitle(f"Epoch {num_epochs} lr={lr}", fontsize=16)
+
+         for i, (method, predicted_labels) in enumerate(clustering_methods_dict.items()):
+             sns.scatterplot(
+                 x=reduced_features_tsne[:, 0],
+                 y=reduced_features_tsne[:, 1],
+                 hue=predicted_labels,
+                 palette="Set2",
+                 ax=axes[i, 0],
+                 s=size,
+                 legend="full",
+             )
+             axes[i, 0].set_title(f"t-SNE with {method} Clusters")
+             axes[i, 0].legend(loc="upper left", bbox_to_anchor=(1, 1))
+
+             sns.scatterplot(
+                 x=reduced_features_umap[:, 0],
+                 y=reduced_features_umap[:, 1],
+                 hue=predicted_labels,
+                 palette="Set3",
+                 ax=axes[i, 1],
+                 s=size,
+                 legend="full",
+             )
+             axes[i, 1].set_title(f"UMAP with {method} Clusters")
+             axes[i, 1].legend(loc="upper left", bbox_to_anchor=(1, 1))
+
+         plt.tight_layout()
+         fig.savefig(f"{save_path}/epoch{num_epochs}_lr{lr}_visualization_nolabel.png")
+         return fig
+     return None
+
+
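+ # An unlabeled-run sketch (hypothetical frame): cluster assignments are written
+ # to '<save_path>/epoch<N>_lr<lr>_clusters.csv' even when visualize=False.
+ #
+ #     fig = plot_scatter_nolabel("demo_dataset", embedding_df, "./out",
+ #                                100, 0.001, n_clusters=8, visualize=False)
+ #     # fig is None; the CSV with per-method cluster columns is still saved.
+
+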
+ def plot_embeddings_only(df, save_path, num_epochs, lr, visualize=False):
+     """
+     Save embeddings only (no clustering/evaluation).
+
+     Args:
+         df: DataFrame containing embeddings.
+         save_path: Output directory.
+         num_epochs: Current epoch.
+         lr: Learning rate.
+         visualize: Whether to create a quick embedding scatter plot.
+
+     Returns:
+         fig: Figure if visualize=True, else None.
+     """
+     # Save embeddings.
+     df.to_csv(f"{save_path}/epoch{num_epochs}_lr{lr}_embedding.csv", index=False)
+     print(f"Embeddings saved to {save_path}/epoch{num_epochs}_lr{lr}_embedding.csv")
+
+     # Create a simple embedding visualization on demand.
+     if visualize:
+         # Extract features.
+         features = df.drop(columns=["cell", "gene"]).values
+
+         n_samples = features.shape[0]
+         # Very small demos (e.g., smoke tests) can have only a handful of graphs.
+         # t-SNE / UMAP can fail in these edge cases. Prefer skipping visualization
+         # over crashing the whole training run.
+         if n_samples < 3:
+             print(f"Skip embedding visualization: too few samples (n={n_samples}).")
+             return None
+
+         # Standardize.
+         scaler = StandardScaler()
+         features = scaler.fit_transform(features)
+
+         # Dimensionality reduction.
+         tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, n_samples - 1))
+         reduced_features_tsne = tsne.fit_transform(features)
+
+         # UMAP's default init='spectral' can fail when n is extremely small (k >= n).
+         # Use init='random' for small n to keep smoke tests robust.
+         umap_init = "random" if n_samples < 4 else "spectral"
+         umap_model = _lazy_import_umap().UMAP(
+             n_components=2,
+             random_state=42,
+             n_neighbors=min(15, n_samples - 1),
+             min_dist=0.2,
+             init=umap_init,
+         )
+         reduced_features_umap = umap_model.fit_transform(features)
+
+         # Simple scatter plots.
+         fig, axes = plt.subplots(1, 2, figsize=(16, 6))
+
+         # Color by cell.
+         sns.scatterplot(
+             x=reduced_features_tsne[:, 0],
+             y=reduced_features_tsne[:, 1],
+             hue=df["cell"],
+             ax=axes[0],
+             s=15,
+             legend="auto",
+         )
+         axes[0].set_title("t-SNE Embedding by Cell Type")
+         axes[0].legend(loc="upper left", bbox_to_anchor=(1, 1))
+
+         # Color by gene.
+         sns.scatterplot(
+             x=reduced_features_umap[:, 0],
+             y=reduced_features_umap[:, 1],
+             hue=df["gene"],
+             ax=axes[1],
+             s=15,
+             legend="auto",
+         )
+         axes[1].set_title("UMAP Embedding by Gene Type")
+         axes[1].legend(loc="upper left", bbox_to_anchor=(1, 1))
+
+         plt.tight_layout()
+         fig.savefig(
+             f"{save_path}/epoch{num_epochs}_lr{lr}_embeddings_visualization.png"
+         )
+         return fig
+
+     return None
+
+
+ def plot_loss_weights(num_epochs, weights_a, weights_b, weights_c, save_path, lr):
+     """
+     Plot dynamic loss weights over epochs.
+
+     Args:
+         num_epochs: Total number of epochs.
+         weights_a: Reconstruction loss weights.
+         weights_b: Contrastive loss weights.
+         weights_c: Clustering loss weights.
+         save_path: Output directory.
+         lr: Learning rate.
+     """
+     plt.figure(figsize=(6, 4))
+     plt.plot(
+         range(num_epochs),
+         weights_a,
+         label="Reconstruction Loss Weight (a)",
+         color="blue",
+     )
+     plt.plot(
+         range(num_epochs),
+         weights_b,
+         label="Contrastive Loss Weight (b)",
+         color="orange",
+     )
+     plt.plot(
+         range(num_epochs), weights_c, label="Clustering Loss Weight (c)", color="green"
+     )
+     plt.xlabel("Epoch")
+     plt.ylabel("Weight")
+     plt.title("Dynamic Loss Weights Over Epochs")
+     plt.legend()
+     # Note: the output filename says 'learning_rate' but the figure shows the
+     # loss weights; the name is kept as-is for compatibility with existing runs.
+     plt.savefig(f"{save_path}/epoch{num_epochs}_lr{lr}_plot_learning_rate.png")
+     plt.close()  # free the figure; this helper does not return it
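+
+
+ # A usage sketch (hypothetical schedules): the three weight lists must each
+ # have `num_epochs` entries.
+ #
+ #     epochs = 100
+ #     a = [0.5] * epochs                           # constant reconstruction weight
+ #     b = [e / epochs for e in range(epochs)]      # ramp up contrastive weight
+ #     c = [1 - e / epochs for e in range(epochs)]  # ramp down clustering weight
+ #     plot_loss_weights(epochs, a, b, c, "./out", 0.001)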