chatspatial 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. chatspatial/__init__.py +11 -0
  2. chatspatial/__main__.py +141 -0
  3. chatspatial/cli/__init__.py +7 -0
  4. chatspatial/config.py +53 -0
  5. chatspatial/models/__init__.py +85 -0
  6. chatspatial/models/analysis.py +513 -0
  7. chatspatial/models/data.py +2462 -0
  8. chatspatial/server.py +1763 -0
  9. chatspatial/spatial_mcp_adapter.py +720 -0
  10. chatspatial/tools/__init__.py +3 -0
  11. chatspatial/tools/annotation.py +1903 -0
  12. chatspatial/tools/cell_communication.py +1603 -0
  13. chatspatial/tools/cnv_analysis.py +605 -0
  14. chatspatial/tools/condition_comparison.py +595 -0
  15. chatspatial/tools/deconvolution/__init__.py +402 -0
  16. chatspatial/tools/deconvolution/base.py +318 -0
  17. chatspatial/tools/deconvolution/card.py +244 -0
  18. chatspatial/tools/deconvolution/cell2location.py +326 -0
  19. chatspatial/tools/deconvolution/destvi.py +144 -0
  20. chatspatial/tools/deconvolution/flashdeconv.py +101 -0
  21. chatspatial/tools/deconvolution/rctd.py +317 -0
  22. chatspatial/tools/deconvolution/spotlight.py +216 -0
  23. chatspatial/tools/deconvolution/stereoscope.py +109 -0
  24. chatspatial/tools/deconvolution/tangram.py +135 -0
  25. chatspatial/tools/differential.py +625 -0
  26. chatspatial/tools/embeddings.py +298 -0
  27. chatspatial/tools/enrichment.py +1863 -0
  28. chatspatial/tools/integration.py +807 -0
  29. chatspatial/tools/preprocessing.py +723 -0
  30. chatspatial/tools/spatial_domains.py +808 -0
  31. chatspatial/tools/spatial_genes.py +836 -0
  32. chatspatial/tools/spatial_registration.py +441 -0
  33. chatspatial/tools/spatial_statistics.py +1476 -0
  34. chatspatial/tools/trajectory.py +495 -0
  35. chatspatial/tools/velocity.py +405 -0
  36. chatspatial/tools/visualization/__init__.py +155 -0
  37. chatspatial/tools/visualization/basic.py +393 -0
  38. chatspatial/tools/visualization/cell_comm.py +699 -0
  39. chatspatial/tools/visualization/cnv.py +320 -0
  40. chatspatial/tools/visualization/core.py +684 -0
  41. chatspatial/tools/visualization/deconvolution.py +852 -0
  42. chatspatial/tools/visualization/enrichment.py +660 -0
  43. chatspatial/tools/visualization/integration.py +205 -0
  44. chatspatial/tools/visualization/main.py +164 -0
  45. chatspatial/tools/visualization/multi_gene.py +739 -0
  46. chatspatial/tools/visualization/persistence.py +335 -0
  47. chatspatial/tools/visualization/spatial_stats.py +469 -0
  48. chatspatial/tools/visualization/trajectory.py +639 -0
  49. chatspatial/tools/visualization/velocity.py +411 -0
  50. chatspatial/utils/__init__.py +115 -0
  51. chatspatial/utils/adata_utils.py +1372 -0
  52. chatspatial/utils/compute.py +327 -0
  53. chatspatial/utils/data_loader.py +499 -0
  54. chatspatial/utils/dependency_manager.py +462 -0
  55. chatspatial/utils/device_utils.py +165 -0
  56. chatspatial/utils/exceptions.py +185 -0
  57. chatspatial/utils/image_utils.py +267 -0
  58. chatspatial/utils/mcp_utils.py +137 -0
  59. chatspatial/utils/path_utils.py +243 -0
  60. chatspatial/utils/persistence.py +78 -0
  61. chatspatial/utils/scipy_compat.py +143 -0
  62. chatspatial-1.1.0.dist-info/METADATA +242 -0
  63. chatspatial-1.1.0.dist-info/RECORD +67 -0
  64. chatspatial-1.1.0.dist-info/WHEEL +5 -0
  65. chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
  66. chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
  67. chatspatial-1.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,808 @@
1
+ """
2
+ A module for identifying spatial domains in spatial transcriptomics data.
3
+
4
+ This module provides an interface to several algorithms designed to partition
5
+ spatial data into distinct domains based on gene expression and spatial proximity.
6
+ It includes graph-based clustering methods (SpaGCN, STAGATE, GraphST) and standard clustering
7
+ algorithms (Leiden, Louvain) adapted for spatial data. The primary entry point is the `identify_spatial_domains`
8
+ function, which handles data preparation and dispatches to the selected method.
9
+ """
10
+
11
+ from collections import Counter
12
+ from typing import TYPE_CHECKING, Any
13
+
14
+ import numpy as np
15
+ import pandas as pd
16
+ import scanpy as sc
17
+
18
+ if TYPE_CHECKING:
19
+ from ..spatial_mcp_adapter import ToolContext
20
+
21
+ from ..models.analysis import SpatialDomainResult
22
+ from ..models.data import SpatialDomainParameters
23
+ from ..utils.adata_utils import (
24
+ ensure_categorical,
25
+ get_spatial_key,
26
+ require_spatial_coords,
27
+ )
28
+ from ..utils.compute import ensure_neighbors, ensure_pca
29
+ from ..utils.dependency_manager import require
30
+ from ..utils.device_utils import get_device, resolve_device_async
31
+ from ..utils.exceptions import (
32
+ DataError,
33
+ DataNotFoundError,
34
+ ParameterError,
35
+ ProcessingError,
36
+ )
37
+
38
+
39
async def identify_spatial_domains(
    data_id: str,
    ctx: "ToolContext",
    params: SpatialDomainParameters = SpatialDomainParameters(),
) -> SpatialDomainResult:
    """
    Identify spatial domains by clustering spots on expression and location.

    Main entry point for the spatial-domain methods in this module. It fetches
    the dataset, builds a working copy (optionally restricted to highly
    variable genes), inspects the expression matrix for signs of scaled or raw
    data, replaces NaN/inf values with 0, dispatches to the method named by
    ``params.method`` and writes the resulting labels (and embeddings, when the
    method produced any) back onto the stored AnnData object.

    Args:
        data_id: The identifier for the dataset.
        ctx: The unified ToolContext for data access and logging.
        params: Analysis parameters, including the method to use and its
            method-specific settings.

    Returns:
        A SpatialDomainResult with the obs key holding the labels, per-domain
        spot counts, the refined key (only if refinement was requested and
        succeeded), the method statistics and the embeddings key.

    Raises:
        ProcessingError: For any failure after the dataset has been fetched.
            Errors raised in the body (including DataNotFoundError and
            ParameterError) are wrapped, with the original as ``__cause__``.
    """
    # Direct reference, not a copy: results are added to adata.obs / adata.obsm
    # in place, so nothing has to be written back to the data store.
    adata = await ctx.get_adata(data_id)

    try:
        spatial_key = get_spatial_key(adata)
        if spatial_key is None:
            raise DataNotFoundError("No spatial coordinates found in the dataset")

        # Working copy handed to the algorithms, optionally limited to HVGs.
        if params.use_highly_variable and "highly_variable" in adata.var.columns:
            adata_subset = adata[:, adata.var["highly_variable"]].copy()
        else:
            adata_subset = adata.copy()

        from scipy.sparse import issparse

        # Probe the value range to guess the preprocessing state. For sparse
        # input only the explicitly stored entries are inspected.
        if issparse(adata_subset.X):
            stored_values = adata_subset.X.data
        else:
            stored_values = adata_subset.X

        if stored_values.size > 0:
            data_min = stored_values.min()
            data_max = stored_values.max()
        else:
            # An all-zero sparse matrix (or a zero-size matrix) has nothing to
            # reduce over; min()/max() would raise ValueError. Treat as zeros.
            data_min = 0.0
            data_max = 0.0

        has_negatives = data_min < 0
        has_large_values = data_max > 100

        # Informative warnings only; nothing is enforced.
        if has_negatives:
            await ctx.warning(
                f"Data contains negative values (min={data_min:.2f}). "
                "This might indicate scaled/z-scored data. "
                "SpaGCN typically works best with normalized, log-transformed data."
            )

            # Fall back to adata.raw (restricted to the same genes) when present.
            if adata.raw is not None:
                gene_mask = adata.raw.var_names.isin(adata_subset.var_names)
                adata_subset = adata.raw[:, gene_mask].to_adata()

        elif has_large_values:
            await ctx.warning(
                f"Data contains large values (max={data_max:.2f}). "
                "This might indicate raw count data. "
                "Consider normalizing and log-transforming for better results."
            )

        # Ensure a floating-point matrix (SpaGCN compatibility).
        if adata_subset.X.dtype != np.float32 and adata_subset.X.dtype != np.float64:
            adata_subset.X = adata_subset.X.astype(np.float32)

        # NaN/inf values can make SpaGCN hang, so they are replaced with 0.
        if issparse(adata_subset.X):
            if np.any(np.isnan(adata_subset.X.data)) or np.any(
                np.isinf(adata_subset.X.data)
            ):
                await ctx.warning(
                    "Found NaN or infinite values in sparse data, replacing with 0"
                )
                adata_subset.X.data = np.nan_to_num(
                    adata_subset.X.data, nan=0.0, posinf=0.0, neginf=0.0
                )
        else:
            if np.any(np.isnan(adata_subset.X)) or np.any(np.isinf(adata_subset.X)):
                await ctx.warning(
                    "Found NaN or infinite values in data, replacing with 0"
                )
                adata_subset.X = np.nan_to_num(
                    adata_subset.X, nan=0.0, posinf=0.0, neginf=0.0
                )

        # Restrict to pre-selected highly variable genes when the column exists.
        # NOTE(review): this step is not gated on params.use_highly_variable, so
        # it also applies when the caller disabled HVG selection above - confirm
        # that this is intended.
        if "highly_variable" in adata_subset.var.columns:
            hvg_count = adata_subset.var["highly_variable"].sum()
            if hvg_count > 0:
                adata_subset = adata_subset[
                    :, adata_subset.var["highly_variable"]
                ].copy()

        # Dispatch to the requested method.
        if params.method == "spagcn":
            domain_labels, embeddings_key, statistics = await _identify_domains_spagcn(
                adata_subset, params, ctx
            )
        elif params.method in ["leiden", "louvain"]:
            domain_labels, embeddings_key, statistics = (
                await _identify_domains_clustering(adata_subset, params, ctx)
            )
        elif params.method == "stagate":
            domain_labels, embeddings_key, statistics = await _identify_domains_stagate(
                adata_subset, params, ctx
            )
        elif params.method == "graphst":
            domain_labels, embeddings_key, statistics = await _identify_domains_graphst(
                adata_subset, params, ctx
            )
        else:
            raise ParameterError(
                f"Unsupported method: {params.method}. Available methods: spagcn, leiden, louvain, stagate, graphst"
            )

        # Store the labels on the original AnnData as a categorical column.
        domain_key = f"spatial_domains_{params.method}"
        adata.obs[domain_key] = domain_labels
        ensure_categorical(adata, domain_key)

        # Copy the method's embeddings over when it produced any.
        if embeddings_key and embeddings_key in adata_subset.obsm:
            adata.obsm[embeddings_key] = adata_subset.obsm[embeddings_key]

        # Optional spatial smoothing; failure is reported but not fatal.
        refined_domain_key = None
        if params.refine_domains:
            try:
                refined_domain_key = f"{domain_key}_refined"
                refined_labels = _refine_spatial_domains(
                    adata,
                    domain_key,
                    refined_domain_key,
                    threshold=params.refinement_threshold,
                )
                adata.obs[refined_domain_key] = refined_labels
                adata.obs[refined_domain_key] = adata.obs[refined_domain_key].astype(
                    "category"
                )
            except Exception as e:
                await ctx.warning(
                    f"Domain refinement failed: {e}. Proceeding with unrefined domains."
                )
                refined_domain_key = None  # Do not advertise a key that was not written

        # Per-domain spot counts with plain str/int types for the result model.
        domain_counts = adata.obs[domain_key].value_counts().to_dict()
        domain_counts = {str(k): int(v) for k, v in domain_counts.items()}

        result = SpatialDomainResult(
            data_id=data_id,
            method=params.method,
            n_domains=len(domain_counts),
            domain_key=domain_key,
            domain_counts=domain_counts,
            refined_domain_key=refined_domain_key,
            statistics=statistics,
            embeddings_key=embeddings_key,
        )

        return result

    except Exception as e:
        raise ProcessingError(
            f"Error in spatial domain identification: {e}"
        ) from e
235
+
236
+
237
async def _identify_domains_spagcn(
    adata: Any, params: SpatialDomainParameters, ctx: "ToolContext"
) -> tuple:
    """
    Identify spatial domains using the SpaGCN algorithm.

    SpaGCN (Spatial Graph Convolutional Network) constructs a spatial graph where
    each spot is a node and learns an embedding that integrates gene expression,
    spatial relationships and, optionally, histology image features. This
    function prepares coordinates and an image for
    ``SpaGCN.ez_mode.detect_spatial_domains_ez_mode`` and runs it in a worker
    thread with a timeout. Requires the `SpaGCN` package.

    Args:
        adata: AnnData object to cluster. SpaGCN's gene prefiltering is applied
            to it directly.
        params: Analysis parameters (``n_domains``, ``spagcn_*``, ``timeout``).
            The object is not modified.
        ctx: ToolContext used for warnings and dependency resolution.

    Returns:
        A ``(domain_labels, embeddings_key, statistics)`` tuple where
        ``domain_labels`` is a string-typed pandas Series indexed like
        ``adata.obs``, ``embeddings_key`` is always ``None`` and ``statistics``
        is a dict of the settings actually used.

    Raises:
        ProcessingError: On timeout or any other failure while running SpaGCN.
    """
    import asyncio
    import concurrent.futures

    spg = require("SpaGCN", ctx, feature="SpaGCN spatial domain identification")

    # SpaGCN-specific gene filtering; best effort, a failure only warns.
    try:
        spg.prefilter_genes(adata, min_cells=3)
        spg.prefilter_specialgenes(adata)
    except Exception as e:
        await ctx.warning(
            f"SpaGCN gene filtering failed: {e}. Continuing without filtering."
        )

    try:
        # Spatial coordinates via the shared validating helper.
        coords = require_spatial_coords(adata)
        n_spots = coords.shape[0]

        # Warn about potentially unstable domain assignments.
        spots_per_domain = n_spots / params.n_domains
        if spots_per_domain < 10:
            await ctx.warning(
                f"Requesting {params.n_domains} domains for {n_spots} spots "
                f"({spots_per_domain:.1f} spots per domain). "
                "This may result in unstable or noisy domain assignments."
            )

        # Pixel coordinates default to the array coordinates; they are only
        # rescaled below when a histology image is found.
        x_array = coords[:, 0].tolist()
        y_array = coords[:, 1].tolist()
        x_pixel = x_array.copy()
        y_pixel = y_array.copy()

        img = None
        scale_factor = 1.0  # Used when the image has no matching scale factor

        # Effective histology flag, kept local so the caller's params object
        # (possibly a shared default instance) is never mutated.
        use_histology = params.spagcn_use_histology

        # Look for a histology image under adata.uns["spatial"] (10x Visium
        # layout), using the first library ID found.
        if use_histology and "spatial" in adata.uns:
            library_ids = list(adata.uns["spatial"].keys())

            if library_ids:
                lib_id = library_ids[0]
                spatial_data = adata.uns["spatial"][lib_id]

                if "images" in spatial_data:
                    img_dict = spatial_data["images"]
                    scalefactors = spatial_data.get("scalefactors", {})

                    # Prefer hires over lowres, and an image with a known
                    # scale factor over one without.
                    if (
                        "hires" in img_dict
                        and "tissue_hires_scalef" in scalefactors
                    ):
                        img = img_dict["hires"]
                        scale_factor = scalefactors["tissue_hires_scalef"]
                    elif (
                        "lowres" in img_dict
                        and "tissue_lowres_scalef" in scalefactors
                    ):
                        img = img_dict["lowres"]
                        scale_factor = scalefactors["tissue_lowres_scalef"]
                    elif "hires" in img_dict:
                        img = img_dict["hires"]
                    elif "lowres" in img_dict:
                        img = img_dict["lowres"]

        if img is None:
            # No image available: disable histology and pass a white dummy image.
            use_histology = False
            img = np.ones((100, 100, 3), dtype=np.uint8) * 255
        else:
            # Convert coordinates to image pixel space.
            x_pixel = [int(x * scale_factor) for x in x_array]
            y_pixel = [int(y * scale_factor) for y in y_array]

        from SpaGCN.ez_mode import detect_spatial_domains_ez_mode

        try:
            if len(x_array) != adata.shape[0]:
                raise DataError(
                    f"Spatial coordinates length ({len(x_array)}) doesn't match data ({adata.shape[0]})"
                )

            timeout_seconds = (
                params.timeout if params.timeout else 600
            )  # Default 10 minutes

            loop = asyncio.get_running_loop()

            # The executor is managed by hand: the context-manager form calls
            # shutdown(wait=True) on exit, which would block until the worker
            # thread finishes and so defeat the timeout below.
            executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
            try:
                future = loop.run_in_executor(
                    executor,
                    lambda: detect_spatial_domains_ez_mode(
                        adata,
                        img,
                        x_array,
                        y_array,
                        x_pixel,
                        y_pixel,
                        n_clusters=params.n_domains,
                        histology=use_histology,
                        s=params.spagcn_s,
                        b=params.spagcn_b,
                        p=params.spagcn_p,
                        r_seed=params.spagcn_random_seed,
                    ),
                )

                try:
                    domain_labels = await asyncio.wait_for(
                        future, timeout=timeout_seconds
                    )
                except asyncio.TimeoutError as e:
                    error_msg = (
                        f"SpaGCN timed out after {timeout_seconds:.0f} seconds. "
                        f"Dataset: {n_spots} spots, {adata.n_vars} genes. "
                        "Try: 1) Reducing n_domains, 2) Using leiden/louvain instead, "
                        "3) Preprocessing with fewer genes/spots, or 4) Adjusting parameters (s, b, p)."
                    )
                    raise ProcessingError(error_msg) from e
            finally:
                # A running thread cannot be cancelled; return without waiting.
                executor.shutdown(wait=False)
        except Exception as spagcn_error:
            raise ProcessingError(
                f"SpaGCN detect_spatial_domains_ez_mode failed: {str(spagcn_error)}"
            ) from spagcn_error

        domain_labels = pd.Series(domain_labels, index=adata.obs.index).astype(str)

        statistics = {
            "method": "spagcn",
            "n_clusters": params.n_domains,
            "s_parameter": params.spagcn_s,
            "b_parameter": params.spagcn_b,
            "p_parameter": params.spagcn_p,
            "use_histology": use_histology,
        }

        return domain_labels, None, statistics

    except Exception as e:
        raise ProcessingError(f"SpaGCN execution failed: {e}") from e
403
+
404
+
405
async def _identify_domains_clustering(
    adata: Any, params: SpatialDomainParameters, ctx: "ToolContext"
) -> tuple:
    """
    Identify spatial domains with Leiden or Louvain clustering on a composite graph.

    PCA and an expression neighbor graph are ensured first. If
    ``adata.obsm["spatial"]`` exists, a spatial neighbor graph is built with
    squidpy and blended into ``adata.obsp["connectivities"]`` as
    ``(1 - w) * expression + w * spatial``. The selected algorithm is then run
    on that graph.

    Args:
        adata: AnnData object to cluster; modified in place (neighbor graphs and
            the ``spatial_<method>`` obs column are added).
        params: Analysis parameters (``method``, ``resolution``,
            ``cluster_n_neighbors``, ``cluster_spatial_weight``).
        ctx: ToolContext used for warnings and dependency resolution.

    Returns:
        A ``(domain_labels, "X_pca", statistics)`` tuple; ``domain_labels`` is a
        string-typed pandas Series taken from ``adata.obs``.

    Raises:
        ProcessingError: If graph construction or clustering fails.
    """
    try:
        # A neighbor count of 0 is not meaningful, so any falsy value falls
        # back to the default.
        n_neighbors = params.cluster_n_neighbors or 15

        # Only an unset weight falls back to the default: an explicit 0.0
        # (expression-only graph) must be honoured, which `or` would not do.
        if params.cluster_spatial_weight is None:
            spatial_weight = 0.3
        else:
            spatial_weight = params.cluster_spatial_weight

        # Lazily compute PCA and the expression neighbor graph.
        ensure_pca(adata)
        ensure_neighbors(adata, n_neighbors=n_neighbors)

        # Blend spatial proximity into the neighborhood graph.
        if "spatial" in adata.obsm:

            try:
                sq = require("squidpy", ctx, feature="spatial neighborhood graph")

                sq.gr.spatial_neighbors(adata, coord_type="generic")

                expr_weight = 1 - spatial_weight

                if "spatial_connectivities" in adata.obsp:
                    combined_conn = (
                        expr_weight * adata.obsp["connectivities"]
                        + spatial_weight * adata.obsp["spatial_connectivities"]
                    )
                    adata.obsp["connectivities"] = combined_conn

            except Exception as spatial_error:
                raise ProcessingError(
                    f"Spatial graph construction failed: {spatial_error}"
                ) from spatial_error

        # One obs key for both algorithms, e.g. "spatial_leiden".
        key_added = f"spatial_{params.method}"

        if params.method == "leiden":
            sc.tl.leiden(adata, resolution=params.resolution, key_added=key_added)
        else:  # louvain
            await ctx.warning(
                "Louvain clustering is deprecated and may not be available on all platforms "
                "(especially macOS due to compilation issues). "
                "Consider using 'leiden' instead, which is an improved algorithm with better performance. "
                "Automatic fallback to Leiden will be used if Louvain is unavailable."
            )
            try:
                sc.tl.louvain(adata, resolution=params.resolution, key_added=key_added)
            except ImportError as e:
                # Louvain backend missing: fall back to Leiden under the same key.
                await ctx.warning(
                    f"Louvain not available: {e}. Using Leiden clustering instead."
                )
                sc.tl.leiden(adata, resolution=params.resolution, key_added=key_added)

        domain_labels = adata.obs[key_added].astype(str)

        statistics = {
            "method": params.method,
            "resolution": params.resolution,
            "n_neighbors": n_neighbors,
            "spatial_weight": spatial_weight if "spatial" in adata.obsm else 0.0,
        }

        return domain_labels, "X_pca", statistics

    except Exception as e:
        raise ProcessingError(f"{params.method} clustering failed: {e}") from e
489
+
490
+
491
def _refine_spatial_domains(
    adata: Any, domain_key: str, refined_key: str, threshold: float = 0.5
) -> pd.Series:
    """
    Smooth domain labels using each spot's nearest spatial neighbors.

    For every spot, up to 10 nearest neighbors are looked up in coordinate
    space. The spot takes the most frequent label among those neighbors, but
    only when the fraction of neighbors whose label differs from the spot's
    current label is at least ``threshold``; otherwise its label is kept.

    Args:
        adata: AnnData object containing spatial data.
        domain_key: Column in ``adata.obs`` holding the labels to refine.
        refined_key: Name for the refined domain key (not used inside this
            function; the caller stores the result).
        threshold: Minimum proportion of differing neighbors that triggers
            relabeling (default: 0.5).

    Returns:
        pd.Series: Refined labels as strings, indexed like ``adata.obs``. With
        fewer than two spots the labels are returned unchanged.

    Raises:
        ProcessingError: On any failure, including an empty dataset.
    """
    try:
        spot_coords = require_spatial_coords(adata)
        current = adata.obs[domain_key].astype(str)

        if len(current) == 0:
            raise DataNotFoundError("Dataset is empty, cannot refine domains")

        from sklearn.neighbors import NearestNeighbors

        # Never ask for more neighbors than there are other spots.
        n_query = min(10, len(current) - 1)
        if n_query < 1:
            # Too few spots for any smoothing.
            return current

        try:
            knn_model = NearestNeighbors(n_neighbors=n_query).fit(spot_coords)
            _, neighbor_index = knn_model.kneighbors(spot_coords)
        except Exception as nn_error:
            raise ProcessingError(
                f"Nearest neighbors computation failed: {nn_error}"
            ) from nn_error

        label_array = current.values

        def _smoothed_label(own_label, neighbor_rows):
            """Return the label one spot should carry after smoothing."""
            nearby = label_array[neighbor_rows]
            differing_fraction = np.sum(nearby != own_label) / len(nearby)
            if differing_fraction < threshold:
                # Not enough disagreement around the spot: keep its label.
                return own_label
            # Majority vote over the neighborhood.
            return Counter(nearby).most_common(1)[0][0]

        smoothed = [
            _smoothed_label(own_label, neighbor_rows)
            for own_label, neighbor_rows in zip(label_array, neighbor_index)
        ]

        return pd.Series(smoothed, index=current.index)

    except Exception as e:
        # Surface every failure to the caller as a ProcessingError.
        raise ProcessingError(f"Failed to refine spatial domains: {e}") from e
572
+
573
+
574
async def _identify_domains_stagate(
    adata: Any, params: SpatialDomainParameters, ctx: "ToolContext"
) -> tuple:
    """
    Identify spatial domains using the STAGATE algorithm.

    STAGATE (spatially-aware graph attention network) learns low-dimensional
    spot embeddings from gene expression and a spatial neighbor network. This
    function builds the network on a copy of ``adata``, trains the model in a
    worker thread with a timeout, and clusters the embeddings with R's
    ``Mclust`` through rpy2. Requires the `STAGATE_pyG` package, plus rpy2 and
    the R package ``mclust``.

    Args:
        adata: AnnData object to analyse; the ``"STAGATE"`` embeddings are
            written to its ``obsm``.
        params: Analysis parameters (``n_domains``, ``stagate_rad_cutoff``,
            ``stagate_random_seed``, ``timeout``).
        ctx: ToolContext used for dependency resolution.

    Returns:
        A ``(domain_labels, "STAGATE", statistics)`` tuple; ``domain_labels`` is
        a string-typed pandas Series of mclust cluster ids.

    Raises:
        ProcessingError: On training timeout, missing rpy2, mclust failure or
            any other error during execution.
    """
    # Imported before the try block: the `except asyncio.TimeoutError` clause at
    # the bottom needs the name bound even when an early step fails.
    import asyncio
    import concurrent.futures

    STAGATE_pyG = require(
        "STAGATE_pyG", ctx, feature="STAGATE spatial domain identification"
    )
    import torch

    try:
        # Train on a copy of the data passed in.
        adata_stagate = adata.copy()

        # Spatial network radius, 50 when not set.
        rad_cutoff = params.stagate_rad_cutoff or 50
        STAGATE_pyG.Cal_Spatial_Net(adata_stagate, rad_cutoff=rad_cutoff)

        # Network statistics are informational only; a failure here is ignored.
        try:
            STAGATE_pyG.Stats_Spatial_Net(adata_stagate)
        except Exception:
            pass

        device = torch.device(get_device(prefer_gpu=True))

        timeout_seconds = params.timeout or 600
        loop = asyncio.get_running_loop()

        # The executor is managed by hand: the context-manager form calls
        # shutdown(wait=True) on exit, which would block until training ends
        # and so defeat the timeout.
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        try:
            adata_stagate = await asyncio.wait_for(
                loop.run_in_executor(
                    executor,
                    lambda: STAGATE_pyG.train_STAGATE(adata_stagate, device=device),
                ),
                timeout=timeout_seconds,
            )
        finally:
            # A running thread cannot be cancelled; return without waiting.
            executor.shutdown(wait=False)

        embeddings_key = "STAGATE"
        n_clusters_target = params.n_domains

        # mclust is invoked directly through rpy2 instead of
        # STAGATE_pyG.mclust_R (which, per the original note, has rpy2
        # compatibility issues with newer versions).
        try:
            import rpy2.robjects as robjects
            from rpy2.robjects import numpy2ri

            # Global numpy <-> R conversion; always undone in the finally
            # below so a failed fit does not leave it active.
            numpy2ri.activate()
            try:
                # Only an unset seed falls back to 42; seed 0 is a valid choice.
                if params.stagate_random_seed is None:
                    random_seed = 42
                else:
                    random_seed = params.stagate_random_seed
                np.random.seed(random_seed)
                robjects.r["set.seed"](random_seed)

                robjects.r.library("mclust")

                # Embeddings are cast to float64 before being handed to R.
                embedding_data = adata_stagate.obsm[embeddings_key].astype(np.float64)
                robjects.r.assign("stagate_embedding", embedding_data)

                robjects.r(
                    f"mclust_result <- Mclust(stagate_embedding, G={n_clusters_target})"
                )

                mclust_labels = np.array(robjects.r("mclust_result$classification"))
            finally:
                numpy2ri.deactivate()

            # Store as an integer-coded categorical column.
            adata_stagate.obs["mclust"] = mclust_labels
            adata_stagate.obs["mclust"] = adata_stagate.obs["mclust"].astype(int)
            adata_stagate.obs["mclust"] = adata_stagate.obs["mclust"].astype("category")

            domain_labels = adata_stagate.obs["mclust"].astype(str)
            clustering_method = "mclust"

        except ImportError as e:
            raise ProcessingError(
                f"STAGATE requires rpy2 for mclust clustering: {e}. "
                "Install with: pip install rpy2"
            ) from e
        except Exception as mclust_error:
            raise ProcessingError(
                f"STAGATE mclust clustering failed with n_domains={n_clusters_target}: "
                f"{type(mclust_error).__name__}: {mclust_error}. "
                "To fix: Install R and run 'install.packages(\"mclust\")' in R, then 'pip install rpy2'. "
                "Alternatively, use method='leiden' or method='spagcn' which don't require R."
            ) from mclust_error

        # Expose the embeddings on the AnnData that was passed in.
        adata.obsm[embeddings_key] = adata_stagate.obsm["STAGATE"]

        statistics = {
            "method": "stagate_pyg",
            "n_clusters": len(domain_labels.unique()),
            "target_n_clusters": n_clusters_target,
            "clustering_method": clustering_method,
            "rad_cutoff": rad_cutoff,
            "device": str(device),
            "framework": "PyTorch Geometric",
        }

        return domain_labels, embeddings_key, statistics

    except asyncio.TimeoutError as e:
        raise ProcessingError(
            f"STAGATE training timeout after {params.timeout or 600} seconds"
        ) from e
    except Exception as e:
        raise ProcessingError(f"STAGATE execution failed: {e}") from e
709
+
710
+
711
async def _identify_domains_graphst(
    adata: Any, params: SpatialDomainParameters, ctx: "ToolContext"
) -> tuple:
    """
    Identify spatial domains using the GraphST algorithm.

    GraphST (graph self-supervised contrastive learning) learns spot embeddings
    that combine gene expression with spatial relationships. This function
    trains the model on a copy of ``adata`` in a worker thread with a timeout,
    then clusters the embeddings with GraphST's own ``clustering`` utility.
    Requires the `GraphST` package.

    Args:
        adata: AnnData object to analyse; the ``"emb"`` embeddings are written
            to its ``obsm``.
        params: Analysis parameters (``n_domains``, ``graphst_*``, ``timeout``).
        ctx: ToolContext used for dependency and device resolution.

    Returns:
        A ``(domain_labels, "emb", statistics)`` tuple; ``domain_labels`` is a
        string-typed pandas Series read from the ``"domain"`` obs column.

    Raises:
        ProcessingError: On training timeout or any other failure.
    """
    require("GraphST", ctx, feature="GraphST spatial domain identification")
    import asyncio
    import concurrent.futures

    import torch
    from GraphST.GraphST import GraphST
    from GraphST.utils import clustering as graphst_clustering

    try:
        # Train on a copy of the data passed in.
        adata_graphst = adata.copy()

        # Resolve the compute device (MPS allowed).
        device_str = await resolve_device_async(
            prefer_gpu=params.graphst_use_gpu, ctx=ctx, allow_mps=True
        )
        device = torch.device(device_str)

        # The method-specific cluster count takes precedence over n_domains.
        n_clusters = params.graphst_n_clusters or params.n_domains

        model = GraphST(
            adata_graphst,
            device=device,
            random_seed=params.graphst_random_seed,
        )

        loop = asyncio.get_running_loop()
        timeout_seconds = params.timeout or 600

        # The executor is managed by hand: the context-manager form calls
        # shutdown(wait=True) on exit, which would block until training ends
        # and so defeat the timeout.
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        try:
            # Training is blocking, so it runs off the event loop.
            adata_graphst = await asyncio.wait_for(
                loop.run_in_executor(executor, model.train),
                timeout=timeout_seconds,
            )

            # Embeddings are read from obsm["emb"] after training.
            embeddings_key = "emb"

            def run_clustering():
                graphst_clustering(
                    adata_graphst,
                    n_clusters=n_clusters,
                    radius=params.graphst_radius if params.graphst_refinement else None,
                    method=params.graphst_clustering_method,
                    refinement=params.graphst_refinement,
                )

            # Clustering also runs in the worker thread (no timeout applied).
            await loop.run_in_executor(executor, run_clustering)
        finally:
            # A running thread cannot be cancelled; return without waiting.
            executor.shutdown(wait=False)

        domain_labels = adata_graphst.obs["domain"].astype(str)

        # Expose the embeddings on the AnnData that was passed in.
        adata.obsm[embeddings_key] = adata_graphst.obsm["emb"]

        statistics = {
            "method": "graphst",
            "n_clusters": len(domain_labels.unique()),
            "clustering_method": params.graphst_clustering_method,
            "refinement": params.graphst_refinement,
            "device": str(device),
            "framework": "PyTorch",
        }

        if params.graphst_refinement:
            statistics["refinement_radius"] = params.graphst_radius

        return domain_labels, embeddings_key, statistics

    except asyncio.TimeoutError as e:
        raise ProcessingError(
            f"GraphST training timeout after {params.timeout or 600} seconds"
        ) from e
    except Exception as e:
        raise ProcessingError(f"GraphST execution failed: {e}") from e