chatspatial 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. chatspatial/__init__.py +11 -0
  2. chatspatial/__main__.py +141 -0
  3. chatspatial/cli/__init__.py +7 -0
  4. chatspatial/config.py +53 -0
  5. chatspatial/models/__init__.py +85 -0
  6. chatspatial/models/analysis.py +513 -0
  7. chatspatial/models/data.py +2462 -0
  8. chatspatial/server.py +1763 -0
  9. chatspatial/spatial_mcp_adapter.py +720 -0
  10. chatspatial/tools/__init__.py +3 -0
  11. chatspatial/tools/annotation.py +1903 -0
  12. chatspatial/tools/cell_communication.py +1603 -0
  13. chatspatial/tools/cnv_analysis.py +605 -0
  14. chatspatial/tools/condition_comparison.py +595 -0
  15. chatspatial/tools/deconvolution/__init__.py +402 -0
  16. chatspatial/tools/deconvolution/base.py +318 -0
  17. chatspatial/tools/deconvolution/card.py +244 -0
  18. chatspatial/tools/deconvolution/cell2location.py +326 -0
  19. chatspatial/tools/deconvolution/destvi.py +144 -0
  20. chatspatial/tools/deconvolution/flashdeconv.py +101 -0
  21. chatspatial/tools/deconvolution/rctd.py +317 -0
  22. chatspatial/tools/deconvolution/spotlight.py +216 -0
  23. chatspatial/tools/deconvolution/stereoscope.py +109 -0
  24. chatspatial/tools/deconvolution/tangram.py +135 -0
  25. chatspatial/tools/differential.py +625 -0
  26. chatspatial/tools/embeddings.py +298 -0
  27. chatspatial/tools/enrichment.py +1863 -0
  28. chatspatial/tools/integration.py +807 -0
  29. chatspatial/tools/preprocessing.py +723 -0
  30. chatspatial/tools/spatial_domains.py +808 -0
  31. chatspatial/tools/spatial_genes.py +836 -0
  32. chatspatial/tools/spatial_registration.py +441 -0
  33. chatspatial/tools/spatial_statistics.py +1476 -0
  34. chatspatial/tools/trajectory.py +495 -0
  35. chatspatial/tools/velocity.py +405 -0
  36. chatspatial/tools/visualization/__init__.py +155 -0
  37. chatspatial/tools/visualization/basic.py +393 -0
  38. chatspatial/tools/visualization/cell_comm.py +699 -0
  39. chatspatial/tools/visualization/cnv.py +320 -0
  40. chatspatial/tools/visualization/core.py +684 -0
  41. chatspatial/tools/visualization/deconvolution.py +852 -0
  42. chatspatial/tools/visualization/enrichment.py +660 -0
  43. chatspatial/tools/visualization/integration.py +205 -0
  44. chatspatial/tools/visualization/main.py +164 -0
  45. chatspatial/tools/visualization/multi_gene.py +739 -0
  46. chatspatial/tools/visualization/persistence.py +335 -0
  47. chatspatial/tools/visualization/spatial_stats.py +469 -0
  48. chatspatial/tools/visualization/trajectory.py +639 -0
  49. chatspatial/tools/visualization/velocity.py +411 -0
  50. chatspatial/utils/__init__.py +115 -0
  51. chatspatial/utils/adata_utils.py +1372 -0
  52. chatspatial/utils/compute.py +327 -0
  53. chatspatial/utils/data_loader.py +499 -0
  54. chatspatial/utils/dependency_manager.py +462 -0
  55. chatspatial/utils/device_utils.py +165 -0
  56. chatspatial/utils/exceptions.py +185 -0
  57. chatspatial/utils/image_utils.py +267 -0
  58. chatspatial/utils/mcp_utils.py +137 -0
  59. chatspatial/utils/path_utils.py +243 -0
  60. chatspatial/utils/persistence.py +78 -0
  61. chatspatial/utils/scipy_compat.py +143 -0
  62. chatspatial-1.1.0.dist-info/METADATA +242 -0
  63. chatspatial-1.1.0.dist-info/RECORD +67 -0
  64. chatspatial-1.1.0.dist-info/WHEEL +5 -0
  65. chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
  66. chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
  67. chatspatial-1.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,441 @@
1
+ """
2
+ Spatial Registration Tool
3
+
4
+ Aligns and registers multiple spatial transcriptomics slices using
5
+ optimal transport (PASTE) or diffeomorphic mapping (STalign).
6
+ """
7
+
8
+ import logging
9
+ from typing import TYPE_CHECKING, Optional
10
+
11
+ import numpy as np
12
+
13
+ if TYPE_CHECKING:
14
+ import anndata as ad
15
+ from ..spatial_mcp_adapter import ToolContext
16
+
17
+ from ..models.data import RegistrationParameters
18
+ from ..utils.adata_utils import (
19
+ ensure_unique_var_names,
20
+ find_common_genes,
21
+ get_spatial_key,
22
+ )
23
+ from ..utils.dependency_manager import require
24
+ from ..utils.device_utils import get_device, get_ot_backend
25
+ from ..utils.exceptions import ParameterError, ProcessingError
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+ # =============================================================================
30
+ # Validation Helpers
31
+ # =============================================================================
32
+
33
+
34
def _validate_spatial_coords(adata_list: list["ad.AnnData"]) -> str:
    """
    Check that every slice exposes spatial coordinates.

    Each slice is probed with get_spatial_key in list order; the first slice
    for which no key is found aborts validation.

    Returns the key reported for the first slice, or "spatial" when the list
    is empty or that key is falsy.
    Raises ParameterError naming the index of the first slice without
    coordinates.
    """
    detected_keys = []
    for position, slice_adata in enumerate(adata_list):
        detected = get_spatial_key(slice_adata)
        if detected is None:
            raise ParameterError(
                f"Slice {position} missing spatial coordinates. "
                f"Expected in adata.obsm['spatial'] or similar."
            )
        detected_keys.append(detected)

    # Only the first slice's key is reported; later keys are not compared.
    first_key = detected_keys[0] if detected_keys else None
    return first_key or "spatial"
52
+
53
+
54
def _get_common_genes(adata_list: list["ad.AnnData"]) -> list[str]:
    """Return the genes shared by every slice, after de-duplicating names."""
    # Pass 1: every slice gets unique var names before any are compared.
    for slice_adata in adata_list:
        ensure_unique_var_names(slice_adata)

    # Pass 2: hand all per-slice gene indexes to the shared intersection helper.
    per_slice_names = [slice_adata.var_names for slice_adata in adata_list]
    return find_common_genes(*per_slice_names)
63
+
64
+
65
+ # =============================================================================
66
+ # STalign Image Preparation (module-level, not nested)
67
+ # =============================================================================
68
+
69
+
70
+ def _prepare_stalign_image(
71
+ coords: np.ndarray,
72
+ intensity: np.ndarray,
73
+ image_size: tuple,
74
+ ) -> tuple:
75
+ """
76
+ Convert point cloud to rasterized image for STalign.
77
+
78
+ Args:
79
+ coords: Spatial coordinates (N, 2)
80
+ intensity: Intensity values per point (N,)
81
+ image_size: Output image dimensions (height, width)
82
+
83
+ Returns:
84
+ Tuple of (xgrid, image_tensor)
85
+ """
86
+ import torch
87
+
88
+ # Normalize coordinates to image space with padding
89
+ coords_norm = coords.copy()
90
+ padding = 0.1
91
+
92
+ for dim in range(2):
93
+ cmin, cmax = coords[:, dim].min(), coords[:, dim].max()
94
+ crange = cmax - cmin
95
+ if crange > 0:
96
+ target_min = padding * image_size[dim]
97
+ target_max = (1 - padding) * image_size[dim]
98
+ coords_norm[:, dim] = (coords[:, dim] - cmin) / crange
99
+ coords_norm[:, dim] = (
100
+ coords_norm[:, dim] * (target_max - target_min) + target_min
101
+ )
102
+
103
+ # Create coordinate grid
104
+ xgrid = [
105
+ torch.linspace(0, image_size[0], image_size[0], dtype=torch.float32),
106
+ torch.linspace(0, image_size[1], image_size[1], dtype=torch.float32),
107
+ ]
108
+
109
+ # Rasterize with Gaussian smoothing
110
+ image = np.zeros(image_size, dtype=np.float32)
111
+ for norm_coord, intens in zip(coords_norm, intensity):
112
+ x_idx = int(np.clip(norm_coord[1], 0, image_size[0] - 1))
113
+ y_idx = int(np.clip(norm_coord[0], 0, image_size[1] - 1))
114
+
115
+ # Gaussian kernel (radius 2)
116
+ for dx in range(-2, 3):
117
+ for dy in range(-2, 3):
118
+ xi, yi = x_idx + dx, y_idx + dy
119
+ if 0 <= xi < image_size[0] and 0 <= yi < image_size[1]:
120
+ weight = np.exp(-(dx * dx + dy * dy) / 2.0)
121
+ image[xi, yi] += intens * weight
122
+
123
+ # Normalize
124
+ if image.max() > 0:
125
+ image /= image.max()
126
+
127
+ return xgrid, torch.tensor(image, dtype=torch.float32)
128
+
129
+
130
+ # =============================================================================
131
+ # Core Registration Functions
132
+ # =============================================================================
133
+
134
+
135
def _register_paste(
    adata_list: list["ad.AnnData"],
    params: RegistrationParameters,
    spatial_key: str = "spatial",
) -> list["ad.AnnData"]:
    """
    Register slices using PASTE optimal transport.

    Exactly two slices are aligned pairwise and stacked with PASTE. Any other
    number of slices goes through PASTE center alignment against a reference
    slice; non-reference coordinates are then obtained as transport-weighted
    averages of the reference coordinates.

    Args:
        adata_list: Slices to register; they are copied, not modified
        params: Registration parameters; reads reference_idx, paste_alpha,
            paste_numItermax, paste_n_components and use_gpu
        spatial_key: obsm key holding the input coordinates

    Returns:
        Copies of the input slices with 'spatial_registered' in obsm

    Raises:
        ParameterError: If reference_idx does not address one of the slices
            (checked for center alignment only)
    """
    import paste as pst
    import scanpy as sc

    reference_idx = params.reference_idx or 0
    registered = [adata.copy() for adata in adata_list]
    common_genes = _get_common_genes(registered)

    if len(registered) == 2:
        # Pairwise alignment

        slice1 = registered[0][:, common_genes].copy()
        slice2 = registered[1][:, common_genes].copy()

        # Normalize
        sc.pp.normalize_total(slice1, target_sum=1e4)
        sc.pp.log1p(slice1)
        sc.pp.normalize_total(slice2, target_sum=1e4)
        sc.pp.log1p(slice2)

        # Run PASTE
        pi = pst.pairwise_align(
            slice1,
            slice2,
            alpha=params.paste_alpha,
            numItermax=params.paste_numItermax,
            verbose=True,
        )

        # Stack and extract aligned coordinates
        aligned = pst.stack_slices_pairwise([slice1, slice2], [pi])
        registered[0].obsm["spatial_registered"] = aligned[0].obsm["spatial"]
        registered[1].obsm["spatial_registered"] = aligned[1].obsm["spatial"]

    else:
        # Multi-slice center alignment

        n_slices = len(registered)
        if not -n_slices <= reference_idx < n_slices:
            raise ParameterError(
                f"reference_idx {reference_idx} is out of range for "
                f"{n_slices} slices"
            )
        # Resolve negative indices so the `i == reference_idx` checks below
        # actually recognise the reference slice.
        reference_idx %= n_slices

        slices = [adata[:, common_genes] for adata in registered]
        backend = get_ot_backend(params.use_gpu)

        # Initial pairwise alignments to reference
        pis = []
        for i, slice_data in enumerate(slices):
            if i == reference_idx:
                pis.append(np.eye(slices[i].shape[0]))
            else:
                pi = pst.pairwise_align(
                    slices[reference_idx],
                    slice_data,
                    alpha=params.paste_alpha,
                    backend=backend,
                    use_gpu=params.use_gpu,
                    verbose=False,
                    gpu_verbose=False,
                )
                pis.append(pi)

        # Center alignment
        _, pis_new = pst.center_align(
            slices[reference_idx],
            slices,
            pis_init=pis,
            alpha=params.paste_alpha,
            backend=backend,
            use_gpu=params.use_gpu,
            n_components=params.paste_n_components,
            verbose=False,
            gpu_verbose=False,
        )

        # Apply transformations
        reference_coords = slices[reference_idx].obsm[spatial_key]
        for i, (adata, pi) in enumerate(zip(registered, pis_new, strict=False)):
            if i == reference_idx:
                adata.obsm["spatial_registered"] = adata.obsm[spatial_key].copy()
            else:
                # PASTE documents each plan as (center spots x slice spots).
                # _transform_coordinates normalises rows and multiplies by the
                # reference coordinates, so it needs one row per spot of the
                # slice being moved: pass the transpose.
                adata.obsm["spatial_registered"] = _transform_coordinates(
                    pi.T, reference_coords
                )

    return registered
220
+
221
+
222
def _register_stalign(
    adata_list: list["ad.AnnData"],
    params: RegistrationParameters,
    spatial_key: str = "spatial",
) -> list["ad.AnnData"]:
    """
    Register slices using STalign diffeomorphic mapping.

    The first slice is treated as the source and the second as the target.
    Both are rasterized to images, STalign's LDDMM is fitted between them,
    and the fitted transform is applied to the source coordinates. The
    target keeps its original coordinates.

    Args:
        adata_list: Exactly two slices; they are copied, not modified
        params: Registration parameters; reads stalign_use_expression,
            stalign_image_size, stalign_a, stalign_niter and use_gpu
        spatial_key: obsm key holding the input coordinates

    Returns:
        Copies of the two slices with 'spatial_registered' in obsm

    Raises:
        ParameterError: If the number of slices is not 2
        ProcessingError: If STalign fails or returns no transformation
    """
    import STalign.STalign as ST
    import torch

    if len(adata_list) != 2:
        raise ParameterError(
            f"STalign only supports pairwise registration, got {len(adata_list)} slices. "
            f"Use PASTE for multi-slice alignment."
        )

    registered = [adata.copy() for adata in adata_list]
    source, target = registered[0], registered[1]

    # Prepare coordinates
    source_coords = source.obsm[spatial_key].astype(np.float32)
    target_coords = target.obsm[spatial_key].astype(np.float32)

    # Prepare intensity
    if params.stalign_use_expression:
        common_genes = _get_common_genes(registered)
        if len(common_genes) < 100:
            logger.warning(f"Only {len(common_genes)} common genes found")

        def _safe_sum(X):
            # Sparse matrices return a 2-D np.matrix from sum(); asarray +
            # ravel gives a flat float32 vector for sparse and dense alike.
            return np.asarray(X.sum(axis=1)).ravel().astype(np.float32)

        # Per-spot total expression over the shared genes
        source_intensity = _safe_sum(source[:, common_genes].X)
        target_intensity = _safe_sum(target[:, common_genes].X)
    else:
        source_intensity = np.ones(len(source_coords), dtype=np.float32)
        target_intensity = np.ones(len(target_coords), dtype=np.float32)

    # Prepare images
    image_size = params.stalign_image_size
    source_grid, source_image = _prepare_stalign_image(
        source_coords, source_intensity, image_size
    )
    target_grid, target_image = _prepare_stalign_image(
        target_coords, target_intensity, image_size
    )

    # STalign documents I and J as channel-first images (C x rows x cols);
    # the rasterizer yields a single 2-D plane, so add the channel axis.
    if source_image.ndim == 2:
        source_image = source_image.unsqueeze(0)
    if target_image.ndim == 2:
        target_image = target_image.unsqueeze(0)

    # STalign parameters
    device = get_device(prefer_gpu=params.use_gpu)
    stalign_params = {
        "a": params.stalign_a,
        "p": 2.0,
        "expand": 2.0,
        "nt": 3,
        "niter": params.stalign_niter,
        "diffeo_start": 0,
        "epL": 2e-08,
        "epT": 0.2,
        "epV": 2000.0,
        "sigmaM": 1.0,
        "sigmaB": 2.0,
        "sigmaA": 5.0,
        "sigmaR": 500000.0,
        "sigmaP": 20.0,
        "device": device,
        "dtype": torch.float32,
    }

    try:
        result = ST.LDDMM(
            xI=source_grid,
            I=source_image,
            xJ=target_grid,
            J=target_image,
            **stalign_params,
        )

        A = result.get("A")
        v = result.get("v")
        xv = result.get("xv")

        if A is None or v is None or xv is None:
            raise ProcessingError("STalign did not return valid transformation")

        # Transform coordinates. Keep the points on the same device as the
        # fitted transform so a GPU fit does not mix CPU and CUDA tensors.
        # NOTE(review): the transform is fitted on the rasterized images but
        # applied to the raw source coordinates in (x, y) order; confirm
        # the coordinate space and axis order STalign expects here.
        point_device = A.device if isinstance(A, torch.Tensor) else None
        source_points = torch.tensor(
            source_coords, dtype=torch.float32, device=point_device
        )
        transformed = ST.transform_points_source_to_target(xv, v, A, source_points)

        if isinstance(transformed, torch.Tensor):
            # .numpy() alone raises for CUDA or grad-tracking tensors.
            transformed = transformed.detach().cpu().numpy()

        source.obsm["spatial_registered"] = transformed
        target.obsm["spatial_registered"] = target_coords.copy()

    except Exception as e:
        raise ProcessingError(
            f"STalign registration failed: {e}. Consider using PASTE method."
        ) from e

    return registered
327
+
328
+
329
+ def _transform_coordinates(
330
+ transport_matrix: np.ndarray,
331
+ reference_coords: np.ndarray,
332
+ ) -> np.ndarray:
333
+ """Transform coordinates via optimal transport matrix."""
334
+ # Normalize rows
335
+ row_sums = transport_matrix.sum(axis=1, keepdims=True)
336
+ row_sums[row_sums == 0] = 1 # Avoid division by zero
337
+ normalized = transport_matrix / row_sums
338
+
339
+ # Weighted average of reference coordinates
340
+ return normalized @ reference_coords
341
+
342
+
343
+ # =============================================================================
344
+ # Public API
345
+ # =============================================================================
346
+
347
+
348
def register_slices(
    adata_list: list["ad.AnnData"],
    params: Optional[RegistrationParameters] = None,
) -> list["ad.AnnData"]:
    """
    Align a set of spatial transcriptomics slices.

    Args:
        adata_list: AnnData slices to register (at least two)
        params: Registration settings; a default RegistrationParameters is
            built when omitted

    Returns:
        Registered AnnData objects carrying 'spatial_registered' in obsm

    Raises:
        ParameterError: For fewer than two slices, slices without spatial
            coordinates, or an unrecognised method
    """
    settings = RegistrationParameters() if params is None else params

    if len(adata_list) < 2:
        raise ParameterError("Registration requires at least 2 slices")

    # Coordinates are validated before the method is inspected.
    coords_key = _validate_spatial_coords(adata_list)

    chosen = settings.method
    if chosen == "paste":
        runner = _register_paste
    elif chosen == "stalign":
        runner = _register_stalign
    else:
        raise ParameterError(f"Unknown method: {chosen}")

    return runner(adata_list, settings, coords_key)
377
+
378
+
379
+ # =============================================================================
380
+ # MCP Tool Wrapper
381
+ # =============================================================================
382
+
383
+
384
async def register_spatial_slices_mcp(
    source_id: str,
    target_id: str,
    ctx: "ToolContext",
    method: str = "paste",
) -> dict:
    """
    Run spatial registration on two stored datasets (MCP entry point).

    Args:
        source_id: ID of the source dataset
        target_id: ID of the target dataset
        ctx: Tool context used to look up datasets and check dependencies
        method: Registration method ('paste' or 'stalign')

    Returns:
        Summary dictionary with the method, dataset IDs, spot counts and
        the obsm key holding the registered coordinates

    Raises:
        ProcessingError: If registration or the coordinate copy-back fails
    """
    # Dependency check for the two known backends; other values fall through.
    package = None
    feature = None
    if method == "paste":
        package, feature = "paste", "PASTE spatial registration"
    elif method == "stalign":
        package, feature = "STalign", "STalign spatial registration"
    if package is not None:
        require(package, ctx, feature=feature)

    # Fetch both datasets, source first.
    source_adata = await ctx.get_adata(source_id)
    target_adata = await ctx.get_adata(target_id)

    params = RegistrationParameters(method=method)

    try:
        registered = register_slices([source_adata, target_adata], params)

        # register_slices works on copies, so write the registered
        # coordinates back onto the stored datasets.
        for position, stored in enumerate((source_adata, target_adata)):
            aligned_obsm = registered[position].obsm
            if "spatial_registered" in aligned_obsm:
                stored.obsm["spatial_registered"] = aligned_obsm[
                    "spatial_registered"
                ]

        return {
            "method": method,
            "source_id": source_id,
            "target_id": target_id,
            "n_source_spots": source_adata.n_obs,
            "n_target_spots": target_adata.n_obs,
            "registration_completed": True,
            "spatial_key_registered": "spatial_registered",
        }

    except Exception as e:
        raise ProcessingError(f"Registration failed: {e}") from e