chatspatial 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatspatial/__init__.py +11 -0
- chatspatial/__main__.py +141 -0
- chatspatial/cli/__init__.py +7 -0
- chatspatial/config.py +53 -0
- chatspatial/models/__init__.py +85 -0
- chatspatial/models/analysis.py +513 -0
- chatspatial/models/data.py +2462 -0
- chatspatial/server.py +1763 -0
- chatspatial/spatial_mcp_adapter.py +720 -0
- chatspatial/tools/__init__.py +3 -0
- chatspatial/tools/annotation.py +1903 -0
- chatspatial/tools/cell_communication.py +1603 -0
- chatspatial/tools/cnv_analysis.py +605 -0
- chatspatial/tools/condition_comparison.py +595 -0
- chatspatial/tools/deconvolution/__init__.py +402 -0
- chatspatial/tools/deconvolution/base.py +318 -0
- chatspatial/tools/deconvolution/card.py +244 -0
- chatspatial/tools/deconvolution/cell2location.py +326 -0
- chatspatial/tools/deconvolution/destvi.py +144 -0
- chatspatial/tools/deconvolution/flashdeconv.py +101 -0
- chatspatial/tools/deconvolution/rctd.py +317 -0
- chatspatial/tools/deconvolution/spotlight.py +216 -0
- chatspatial/tools/deconvolution/stereoscope.py +109 -0
- chatspatial/tools/deconvolution/tangram.py +135 -0
- chatspatial/tools/differential.py +625 -0
- chatspatial/tools/embeddings.py +298 -0
- chatspatial/tools/enrichment.py +1863 -0
- chatspatial/tools/integration.py +807 -0
- chatspatial/tools/preprocessing.py +723 -0
- chatspatial/tools/spatial_domains.py +808 -0
- chatspatial/tools/spatial_genes.py +836 -0
- chatspatial/tools/spatial_registration.py +441 -0
- chatspatial/tools/spatial_statistics.py +1476 -0
- chatspatial/tools/trajectory.py +495 -0
- chatspatial/tools/velocity.py +405 -0
- chatspatial/tools/visualization/__init__.py +155 -0
- chatspatial/tools/visualization/basic.py +393 -0
- chatspatial/tools/visualization/cell_comm.py +699 -0
- chatspatial/tools/visualization/cnv.py +320 -0
- chatspatial/tools/visualization/core.py +684 -0
- chatspatial/tools/visualization/deconvolution.py +852 -0
- chatspatial/tools/visualization/enrichment.py +660 -0
- chatspatial/tools/visualization/integration.py +205 -0
- chatspatial/tools/visualization/main.py +164 -0
- chatspatial/tools/visualization/multi_gene.py +739 -0
- chatspatial/tools/visualization/persistence.py +335 -0
- chatspatial/tools/visualization/spatial_stats.py +469 -0
- chatspatial/tools/visualization/trajectory.py +639 -0
- chatspatial/tools/visualization/velocity.py +411 -0
- chatspatial/utils/__init__.py +115 -0
- chatspatial/utils/adata_utils.py +1372 -0
- chatspatial/utils/compute.py +327 -0
- chatspatial/utils/data_loader.py +499 -0
- chatspatial/utils/dependency_manager.py +462 -0
- chatspatial/utils/device_utils.py +165 -0
- chatspatial/utils/exceptions.py +185 -0
- chatspatial/utils/image_utils.py +267 -0
- chatspatial/utils/mcp_utils.py +137 -0
- chatspatial/utils/path_utils.py +243 -0
- chatspatial/utils/persistence.py +78 -0
- chatspatial/utils/scipy_compat.py +143 -0
- chatspatial-1.1.0.dist-info/METADATA +242 -0
- chatspatial-1.1.0.dist-info/RECORD +67 -0
- chatspatial-1.1.0.dist-info/WHEEL +5 -0
- chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
- chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
- chatspatial-1.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,441 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Spatial Registration Tool
|
|
3
|
+
|
|
4
|
+
Aligns and registers multiple spatial transcriptomics slices using
|
|
5
|
+
optimal transport (PASTE) or diffeomorphic mapping (STalign).
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
from typing import TYPE_CHECKING, Optional
|
|
10
|
+
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
import anndata as ad
|
|
15
|
+
from ..spatial_mcp_adapter import ToolContext
|
|
16
|
+
|
|
17
|
+
from ..models.data import RegistrationParameters
|
|
18
|
+
from ..utils.adata_utils import (
|
|
19
|
+
ensure_unique_var_names,
|
|
20
|
+
find_common_genes,
|
|
21
|
+
get_spatial_key,
|
|
22
|
+
)
|
|
23
|
+
from ..utils.dependency_manager import require
|
|
24
|
+
from ..utils.device_utils import get_device, get_ot_backend
|
|
25
|
+
from ..utils.exceptions import ParameterError, ProcessingError
|
|
26
|
+
|
|
27
|
+
logger = logging.getLogger(__name__)
|
|
28
|
+
|
|
29
|
+
# =============================================================================
|
|
30
|
+
# Validation Helpers
|
|
31
|
+
# =============================================================================
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _validate_spatial_coords(adata_list: list["ad.AnnData"]) -> str:
    """
    Validate that every slice has spatial coordinates under one shared key.

    The key is resolved with ``get_spatial_key`` on the first slice. Every
    later slice must also contain that key in ``obsm``, because the
    registration backends index ``adata.obsm[spatial_key]`` on all slices
    with a single key.

    Args:
        adata_list: Slices to be registered.

    Returns:
        The shared spatial key (``"spatial"`` when ``adata_list`` is empty).

    Raises:
        ParameterError: A slice has no spatial coordinates, or stores them
            under a different key than the first slice.
    """
    spatial_key = None
    for i, adata in enumerate(adata_list):
        key = get_spatial_key(adata)
        if key is None:
            raise ParameterError(
                f"Slice {i} missing spatial coordinates. "
                "Expected in adata.obsm['spatial'] or similar."
            )
        if spatial_key is None:
            spatial_key = key
        elif spatial_key not in adata.obsm:
            # Fail here with a clear message rather than a KeyError deep
            # inside the registration backend.
            raise ParameterError(
                f"Slice {i} stores spatial coordinates under '{key}', but "
                f"slice 0 uses '{spatial_key}'. All slices must share the "
                f"same obsm key for registration."
            )
    return spatial_key or "spatial"
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _get_common_genes(adata_list: list["ad.AnnData"]) -> list[str]:
    """
    Return the genes present in every slice.

    ``ensure_unique_var_names`` is applied to each slice first. Its return
    value is not used, so it is expected to act on the objects in place;
    pass copies if the originals must stay untouched.
    """
    for slice_adata in adata_list:
        ensure_unique_var_names(slice_adata)

    name_sets = (slice_adata.var_names for slice_adata in adata_list)
    return find_common_genes(*name_sets)
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
# =============================================================================
|
|
66
|
+
# STalign Image Preparation (module-level, not nested)
|
|
67
|
+
# =============================================================================
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def _prepare_stalign_image(
|
|
71
|
+
coords: np.ndarray,
|
|
72
|
+
intensity: np.ndarray,
|
|
73
|
+
image_size: tuple,
|
|
74
|
+
) -> tuple:
|
|
75
|
+
"""
|
|
76
|
+
Convert point cloud to rasterized image for STalign.
|
|
77
|
+
|
|
78
|
+
Args:
|
|
79
|
+
coords: Spatial coordinates (N, 2)
|
|
80
|
+
intensity: Intensity values per point (N,)
|
|
81
|
+
image_size: Output image dimensions (height, width)
|
|
82
|
+
|
|
83
|
+
Returns:
|
|
84
|
+
Tuple of (xgrid, image_tensor)
|
|
85
|
+
"""
|
|
86
|
+
import torch
|
|
87
|
+
|
|
88
|
+
# Normalize coordinates to image space with padding
|
|
89
|
+
coords_norm = coords.copy()
|
|
90
|
+
padding = 0.1
|
|
91
|
+
|
|
92
|
+
for dim in range(2):
|
|
93
|
+
cmin, cmax = coords[:, dim].min(), coords[:, dim].max()
|
|
94
|
+
crange = cmax - cmin
|
|
95
|
+
if crange > 0:
|
|
96
|
+
target_min = padding * image_size[dim]
|
|
97
|
+
target_max = (1 - padding) * image_size[dim]
|
|
98
|
+
coords_norm[:, dim] = (coords[:, dim] - cmin) / crange
|
|
99
|
+
coords_norm[:, dim] = (
|
|
100
|
+
coords_norm[:, dim] * (target_max - target_min) + target_min
|
|
101
|
+
)
|
|
102
|
+
|
|
103
|
+
# Create coordinate grid
|
|
104
|
+
xgrid = [
|
|
105
|
+
torch.linspace(0, image_size[0], image_size[0], dtype=torch.float32),
|
|
106
|
+
torch.linspace(0, image_size[1], image_size[1], dtype=torch.float32),
|
|
107
|
+
]
|
|
108
|
+
|
|
109
|
+
# Rasterize with Gaussian smoothing
|
|
110
|
+
image = np.zeros(image_size, dtype=np.float32)
|
|
111
|
+
for norm_coord, intens in zip(coords_norm, intensity):
|
|
112
|
+
x_idx = int(np.clip(norm_coord[1], 0, image_size[0] - 1))
|
|
113
|
+
y_idx = int(np.clip(norm_coord[0], 0, image_size[1] - 1))
|
|
114
|
+
|
|
115
|
+
# Gaussian kernel (radius 2)
|
|
116
|
+
for dx in range(-2, 3):
|
|
117
|
+
for dy in range(-2, 3):
|
|
118
|
+
xi, yi = x_idx + dx, y_idx + dy
|
|
119
|
+
if 0 <= xi < image_size[0] and 0 <= yi < image_size[1]:
|
|
120
|
+
weight = np.exp(-(dx * dx + dy * dy) / 2.0)
|
|
121
|
+
image[xi, yi] += intens * weight
|
|
122
|
+
|
|
123
|
+
# Normalize
|
|
124
|
+
if image.max() > 0:
|
|
125
|
+
image /= image.max()
|
|
126
|
+
|
|
127
|
+
return xgrid, torch.tensor(image, dtype=torch.float32)
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
# =============================================================================
|
|
131
|
+
# Core Registration Functions
|
|
132
|
+
# =============================================================================
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def _register_paste(
|
|
136
|
+
adata_list: list["ad.AnnData"],
|
|
137
|
+
params: RegistrationParameters,
|
|
138
|
+
spatial_key: str = "spatial",
|
|
139
|
+
) -> list["ad.AnnData"]:
|
|
140
|
+
"""Register slices using PASTE optimal transport."""
|
|
141
|
+
import paste as pst
|
|
142
|
+
import scanpy as sc
|
|
143
|
+
|
|
144
|
+
reference_idx = params.reference_idx or 0
|
|
145
|
+
registered = [adata.copy() for adata in adata_list]
|
|
146
|
+
common_genes = _get_common_genes(registered)
|
|
147
|
+
|
|
148
|
+
if len(registered) == 2:
|
|
149
|
+
# Pairwise alignment
|
|
150
|
+
|
|
151
|
+
slice1 = registered[0][:, common_genes].copy()
|
|
152
|
+
slice2 = registered[1][:, common_genes].copy()
|
|
153
|
+
|
|
154
|
+
# Normalize
|
|
155
|
+
sc.pp.normalize_total(slice1, target_sum=1e4)
|
|
156
|
+
sc.pp.log1p(slice1)
|
|
157
|
+
sc.pp.normalize_total(slice2, target_sum=1e4)
|
|
158
|
+
sc.pp.log1p(slice2)
|
|
159
|
+
|
|
160
|
+
# Run PASTE
|
|
161
|
+
pi = pst.pairwise_align(
|
|
162
|
+
slice1,
|
|
163
|
+
slice2,
|
|
164
|
+
alpha=params.paste_alpha,
|
|
165
|
+
numItermax=params.paste_numItermax,
|
|
166
|
+
verbose=True,
|
|
167
|
+
)
|
|
168
|
+
|
|
169
|
+
# Stack and extract aligned coordinates
|
|
170
|
+
aligned = pst.stack_slices_pairwise([slice1, slice2], [pi])
|
|
171
|
+
registered[0].obsm["spatial_registered"] = aligned[0].obsm["spatial"]
|
|
172
|
+
registered[1].obsm["spatial_registered"] = aligned[1].obsm["spatial"]
|
|
173
|
+
|
|
174
|
+
else:
|
|
175
|
+
# Multi-slice center alignment
|
|
176
|
+
|
|
177
|
+
slices = [adata[:, common_genes] for adata in registered]
|
|
178
|
+
backend = get_ot_backend(params.use_gpu)
|
|
179
|
+
|
|
180
|
+
# Initial pairwise alignments to reference
|
|
181
|
+
pis = []
|
|
182
|
+
for i, slice_data in enumerate(slices):
|
|
183
|
+
if i == reference_idx:
|
|
184
|
+
pis.append(np.eye(slices[i].shape[0]))
|
|
185
|
+
else:
|
|
186
|
+
pi = pst.pairwise_align(
|
|
187
|
+
slices[reference_idx],
|
|
188
|
+
slice_data,
|
|
189
|
+
alpha=params.paste_alpha,
|
|
190
|
+
backend=backend,
|
|
191
|
+
use_gpu=params.use_gpu,
|
|
192
|
+
verbose=False,
|
|
193
|
+
gpu_verbose=False,
|
|
194
|
+
)
|
|
195
|
+
pis.append(pi)
|
|
196
|
+
|
|
197
|
+
# Center alignment
|
|
198
|
+
_, pis_new = pst.center_align(
|
|
199
|
+
slices[reference_idx],
|
|
200
|
+
slices,
|
|
201
|
+
pis_init=pis,
|
|
202
|
+
alpha=params.paste_alpha,
|
|
203
|
+
backend=backend,
|
|
204
|
+
use_gpu=params.use_gpu,
|
|
205
|
+
n_components=params.paste_n_components,
|
|
206
|
+
verbose=False,
|
|
207
|
+
gpu_verbose=False,
|
|
208
|
+
)
|
|
209
|
+
|
|
210
|
+
# Apply transformations
|
|
211
|
+
for i, (adata, pi) in enumerate(zip(registered, pis_new, strict=False)):
|
|
212
|
+
if i == reference_idx:
|
|
213
|
+
adata.obsm["spatial_registered"] = adata.obsm[spatial_key].copy()
|
|
214
|
+
else:
|
|
215
|
+
adata.obsm["spatial_registered"] = _transform_coordinates(
|
|
216
|
+
pi, slices[reference_idx].obsm[spatial_key]
|
|
217
|
+
)
|
|
218
|
+
|
|
219
|
+
return registered
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
def _register_stalign(
|
|
223
|
+
adata_list: list["ad.AnnData"],
|
|
224
|
+
params: RegistrationParameters,
|
|
225
|
+
spatial_key: str = "spatial",
|
|
226
|
+
) -> list["ad.AnnData"]:
|
|
227
|
+
"""Register slices using STalign diffeomorphic mapping."""
|
|
228
|
+
import STalign.STalign as ST
|
|
229
|
+
import torch
|
|
230
|
+
|
|
231
|
+
if len(adata_list) != 2:
|
|
232
|
+
raise ParameterError(
|
|
233
|
+
f"STalign only supports pairwise registration, got {len(adata_list)} slices. "
|
|
234
|
+
f"Use PASTE for multi-slice alignment."
|
|
235
|
+
)
|
|
236
|
+
|
|
237
|
+
registered = [adata.copy() for adata in adata_list]
|
|
238
|
+
source, target = registered[0], registered[1]
|
|
239
|
+
|
|
240
|
+
# Prepare coordinates
|
|
241
|
+
source_coords = source.obsm[spatial_key].astype(np.float32)
|
|
242
|
+
target_coords = target.obsm[spatial_key].astype(np.float32)
|
|
243
|
+
|
|
244
|
+
# Prepare intensity
|
|
245
|
+
if params.stalign_use_expression:
|
|
246
|
+
common_genes = _get_common_genes(registered)
|
|
247
|
+
if len(common_genes) < 100:
|
|
248
|
+
logger.warning(f"Only {len(common_genes)} common genes found")
|
|
249
|
+
|
|
250
|
+
# Compute sum intensity (sparse-aware)
|
|
251
|
+
source_expr = source[:, common_genes].X
|
|
252
|
+
target_expr = target[:, common_genes].X
|
|
253
|
+
|
|
254
|
+
def _safe_sum(X):
|
|
255
|
+
if hasattr(X, "toarray"):
|
|
256
|
+
return np.array(X.sum(axis=1)).flatten().astype(np.float32)
|
|
257
|
+
return X.sum(axis=1).astype(np.float32)
|
|
258
|
+
|
|
259
|
+
source_intensity = _safe_sum(source_expr)
|
|
260
|
+
target_intensity = _safe_sum(target_expr)
|
|
261
|
+
else:
|
|
262
|
+
source_intensity = np.ones(len(source_coords), dtype=np.float32)
|
|
263
|
+
target_intensity = np.ones(len(target_coords), dtype=np.float32)
|
|
264
|
+
|
|
265
|
+
# Prepare images
|
|
266
|
+
image_size = params.stalign_image_size
|
|
267
|
+
source_grid, source_image = _prepare_stalign_image(
|
|
268
|
+
source_coords, source_intensity, image_size
|
|
269
|
+
)
|
|
270
|
+
target_grid, target_image = _prepare_stalign_image(
|
|
271
|
+
target_coords, target_intensity, image_size
|
|
272
|
+
)
|
|
273
|
+
|
|
274
|
+
# STalign parameters
|
|
275
|
+
device = get_device(prefer_gpu=params.use_gpu)
|
|
276
|
+
stalign_params = {
|
|
277
|
+
"a": params.stalign_a,
|
|
278
|
+
"p": 2.0,
|
|
279
|
+
"expand": 2.0,
|
|
280
|
+
"nt": 3,
|
|
281
|
+
"niter": params.stalign_niter,
|
|
282
|
+
"diffeo_start": 0,
|
|
283
|
+
"epL": 2e-08,
|
|
284
|
+
"epT": 0.2,
|
|
285
|
+
"epV": 2000.0,
|
|
286
|
+
"sigmaM": 1.0,
|
|
287
|
+
"sigmaB": 2.0,
|
|
288
|
+
"sigmaA": 5.0,
|
|
289
|
+
"sigmaR": 500000.0,
|
|
290
|
+
"sigmaP": 20.0,
|
|
291
|
+
"device": device,
|
|
292
|
+
"dtype": torch.float32,
|
|
293
|
+
}
|
|
294
|
+
|
|
295
|
+
try:
|
|
296
|
+
result = ST.LDDMM(
|
|
297
|
+
xI=source_grid,
|
|
298
|
+
I=source_image,
|
|
299
|
+
xJ=target_grid,
|
|
300
|
+
J=target_image,
|
|
301
|
+
**stalign_params,
|
|
302
|
+
)
|
|
303
|
+
|
|
304
|
+
A = result.get("A")
|
|
305
|
+
v = result.get("v")
|
|
306
|
+
xv = result.get("xv")
|
|
307
|
+
|
|
308
|
+
if A is None or v is None or xv is None:
|
|
309
|
+
raise ProcessingError("STalign did not return valid transformation")
|
|
310
|
+
|
|
311
|
+
# Transform coordinates
|
|
312
|
+
source_points = torch.tensor(source_coords, dtype=torch.float32)
|
|
313
|
+
transformed = ST.transform_points_source_to_target(xv, v, A, source_points)
|
|
314
|
+
|
|
315
|
+
if isinstance(transformed, torch.Tensor):
|
|
316
|
+
transformed = transformed.numpy()
|
|
317
|
+
|
|
318
|
+
source.obsm["spatial_registered"] = transformed
|
|
319
|
+
target.obsm["spatial_registered"] = target_coords.copy()
|
|
320
|
+
|
|
321
|
+
except Exception as e:
|
|
322
|
+
raise ProcessingError(
|
|
323
|
+
f"STalign registration failed: {e}. Consider using PASTE method."
|
|
324
|
+
) from e
|
|
325
|
+
|
|
326
|
+
return registered
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
def _transform_coordinates(
|
|
330
|
+
transport_matrix: np.ndarray,
|
|
331
|
+
reference_coords: np.ndarray,
|
|
332
|
+
) -> np.ndarray:
|
|
333
|
+
"""Transform coordinates via optimal transport matrix."""
|
|
334
|
+
# Normalize rows
|
|
335
|
+
row_sums = transport_matrix.sum(axis=1, keepdims=True)
|
|
336
|
+
row_sums[row_sums == 0] = 1 # Avoid division by zero
|
|
337
|
+
normalized = transport_matrix / row_sums
|
|
338
|
+
|
|
339
|
+
# Weighted average of reference coordinates
|
|
340
|
+
return normalized @ reference_coords
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
# =============================================================================
|
|
344
|
+
# Public API
|
|
345
|
+
# =============================================================================
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
def register_slices(
    adata_list: list["ad.AnnData"],
    params: Optional[RegistrationParameters] = None,
) -> list["ad.AnnData"]:
    """
    Register multiple spatial transcriptomics slices.

    Validation runs in a fixed order: slice count, then spatial coordinates,
    then the method name.

    Args:
        adata_list: List of AnnData objects to register
        params: Registration parameters; a default ``RegistrationParameters()``
            is used when None

    Returns:
        List of registered AnnData objects with 'spatial_registered' in obsm

    Raises:
        ParameterError: Fewer than two slices, missing coordinates, or an
            unknown ``params.method``.
    """
    params = RegistrationParameters() if params is None else params

    if len(adata_list) < 2:
        raise ParameterError("Registration requires at least 2 slices")

    spatial_key = _validate_spatial_coords(adata_list)

    registrar = {
        "paste": _register_paste,
        "stalign": _register_stalign,
    }.get(params.method)
    if registrar is None:
        raise ParameterError(f"Unknown method: {params.method}")

    return registrar(adata_list, params, spatial_key)
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
# =============================================================================
|
|
380
|
+
# MCP Tool Wrapper
|
|
381
|
+
# =============================================================================
|
|
382
|
+
|
|
383
|
+
|
|
384
|
+
async def register_spatial_slices_mcp(
    source_id: str,
    target_id: str,
    ctx: "ToolContext",
    method: str = "paste",
) -> dict:
    """
    MCP wrapper for spatial registration.

    Loads both datasets from ``ctx`` and registers them with
    ``register_slices``. ``obsm["spatial_registered"]`` is then written back
    onto the AnnData objects returned by ``ctx.get_adata``.

    Args:
        source_id: Source dataset ID
        target_id: Target dataset ID
        ctx: Tool context for data access
        method: Registration method ('paste' or 'stalign')

    Returns:
        Registration result dictionary with keys ``method``, ``source_id``,
        ``target_id``, ``n_source_spots``, ``n_target_spots``,
        ``registration_completed`` and ``spatial_key_registered``.

    Raises:
        ProcessingError: Registration failed. A ``ProcessingError`` raised
            during registration is re-raised unchanged, so its message is not
            prefixed twice. Any other exception is wrapped as
            "Registration failed: ...".
    """
    # Check the optional dependency for the chosen method before loading data.
    if method == "paste":
        require("paste", ctx, feature="PASTE spatial registration")
    elif method == "stalign":
        require("STalign", ctx, feature="STalign spatial registration")

    source_adata = await ctx.get_adata(source_id)
    target_adata = await ctx.get_adata(target_id)

    params = RegistrationParameters(method=method)

    try:
        registered = register_slices([source_adata, target_adata], params)

        # Registration works on copies; copy the result back onto the stored
        # objects (in-place modification).
        for stored, aligned in zip((source_adata, target_adata), registered):
            if "spatial_registered" in aligned.obsm:
                stored.obsm["spatial_registered"] = aligned.obsm["spatial_registered"]

        return {
            "method": method,
            "source_id": source_id,
            "target_id": target_id,
            "n_source_spots": source_adata.n_obs,
            "n_target_spots": target_adata.n_obs,
            "registration_completed": True,
            "spatial_key_registered": "spatial_registered",
        }

    except ProcessingError:
        # Already descriptive; wrapping again would duplicate the prefix.
        raise
    except Exception as e:
        raise ProcessingError(f"Registration failed: {e}") from e
|