chatspatial-1.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. chatspatial/__init__.py +11 -0
  2. chatspatial/__main__.py +141 -0
  3. chatspatial/cli/__init__.py +7 -0
  4. chatspatial/config.py +53 -0
  5. chatspatial/models/__init__.py +85 -0
  6. chatspatial/models/analysis.py +513 -0
  7. chatspatial/models/data.py +2462 -0
  8. chatspatial/server.py +1763 -0
  9. chatspatial/spatial_mcp_adapter.py +720 -0
  10. chatspatial/tools/__init__.py +3 -0
  11. chatspatial/tools/annotation.py +1903 -0
  12. chatspatial/tools/cell_communication.py +1603 -0
  13. chatspatial/tools/cnv_analysis.py +605 -0
  14. chatspatial/tools/condition_comparison.py +595 -0
  15. chatspatial/tools/deconvolution/__init__.py +402 -0
  16. chatspatial/tools/deconvolution/base.py +318 -0
  17. chatspatial/tools/deconvolution/card.py +244 -0
  18. chatspatial/tools/deconvolution/cell2location.py +326 -0
  19. chatspatial/tools/deconvolution/destvi.py +144 -0
  20. chatspatial/tools/deconvolution/flashdeconv.py +101 -0
  21. chatspatial/tools/deconvolution/rctd.py +317 -0
  22. chatspatial/tools/deconvolution/spotlight.py +216 -0
  23. chatspatial/tools/deconvolution/stereoscope.py +109 -0
  24. chatspatial/tools/deconvolution/tangram.py +135 -0
  25. chatspatial/tools/differential.py +625 -0
  26. chatspatial/tools/embeddings.py +298 -0
  27. chatspatial/tools/enrichment.py +1863 -0
  28. chatspatial/tools/integration.py +807 -0
  29. chatspatial/tools/preprocessing.py +723 -0
  30. chatspatial/tools/spatial_domains.py +808 -0
  31. chatspatial/tools/spatial_genes.py +836 -0
  32. chatspatial/tools/spatial_registration.py +441 -0
  33. chatspatial/tools/spatial_statistics.py +1476 -0
  34. chatspatial/tools/trajectory.py +495 -0
  35. chatspatial/tools/velocity.py +405 -0
  36. chatspatial/tools/visualization/__init__.py +155 -0
  37. chatspatial/tools/visualization/basic.py +393 -0
  38. chatspatial/tools/visualization/cell_comm.py +699 -0
  39. chatspatial/tools/visualization/cnv.py +320 -0
  40. chatspatial/tools/visualization/core.py +684 -0
  41. chatspatial/tools/visualization/deconvolution.py +852 -0
  42. chatspatial/tools/visualization/enrichment.py +660 -0
  43. chatspatial/tools/visualization/integration.py +205 -0
  44. chatspatial/tools/visualization/main.py +164 -0
  45. chatspatial/tools/visualization/multi_gene.py +739 -0
  46. chatspatial/tools/visualization/persistence.py +335 -0
  47. chatspatial/tools/visualization/spatial_stats.py +469 -0
  48. chatspatial/tools/visualization/trajectory.py +639 -0
  49. chatspatial/tools/visualization/velocity.py +411 -0
  50. chatspatial/utils/__init__.py +115 -0
  51. chatspatial/utils/adata_utils.py +1372 -0
  52. chatspatial/utils/compute.py +327 -0
  53. chatspatial/utils/data_loader.py +499 -0
  54. chatspatial/utils/dependency_manager.py +462 -0
  55. chatspatial/utils/device_utils.py +165 -0
  56. chatspatial/utils/exceptions.py +185 -0
  57. chatspatial/utils/image_utils.py +267 -0
  58. chatspatial/utils/mcp_utils.py +137 -0
  59. chatspatial/utils/path_utils.py +243 -0
  60. chatspatial/utils/persistence.py +78 -0
  61. chatspatial/utils/scipy_compat.py +143 -0
  62. chatspatial-1.1.0.dist-info/METADATA +242 -0
  63. chatspatial-1.1.0.dist-info/RECORD +67 -0
  64. chatspatial-1.1.0.dist-info/WHEEL +5 -0
  65. chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
  66. chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
  67. chatspatial-1.1.0.dist-info/top_level.txt +1 -0
chatspatial/tools/integration.py
@@ -0,0 +1,807 @@
+"""
+Integration tools for spatial transcriptomics data.
+"""
+
+import logging
+from typing import TYPE_CHECKING, Optional
+
+import anndata as ad
+import numpy as np
+import scanpy as sc
+
+from ..models.analysis import IntegrationResult
+from ..models.data import IntegrationParameters
+from ..utils.dependency_manager import require
+from ..utils.exceptions import (
+    DataError,
+    DataNotFoundError,
+    ParameterError,
+    ProcessingError,
+)
+
+if TYPE_CHECKING:
+    from ..spatial_mcp_adapter import ToolContext
+
+from ..utils.adata_utils import (
+    get_spatial_key,
+    store_analysis_metadata,
+    validate_adata_basics,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def integrate_multiple_samples(
+    adatas,
+    batch_key="batch",
+    method="harmony",
+    n_pcs=30,
+    params: Optional[IntegrationParameters] = None,
+):
+    """Integrate multiple spatial transcriptomics samples
+
+    This function expects preprocessed data (normalized, log-transformed, with
+    HVGs marked). Use preprocessing.py or preprocess_data() before calling it.
+
+    Args:
+        adatas: List of preprocessed AnnData objects or a single combined AnnData object
+        batch_key: Batch information key
+        method: Integration method, options: 'harmony', 'bbknn', 'scanorama', 'scvi'
+        n_pcs: Number of principal components for integration
+        params: Optional IntegrationParameters for method-specific settings (scVI, etc.)
+
+    Returns:
+        Integrated AnnData object with batch correction applied
+
+    Raises:
+        ParameterError: If fewer than two datasets are given or batch info is missing
+        DataError: If data is not properly preprocessed
+    """
+
+    # Merge datasets
+    if isinstance(adatas, list):
+        # Validate list has at least 2 datasets for integration
+        if len(adatas) < 2:
+            raise ParameterError(
+                f"Integration requires at least 2 datasets, got {len(adatas)}. "
+                "Use preprocess_data for single dataset processing."
+            )
+
+        # Check if datasets have batch labels
+        has_batch_labels = all(batch_key in adata.obs for adata in adatas)
+
+        if not has_batch_labels:
+            # Auto-create batch labels for multi-sample integration.
+            # Each sample becomes its own batch (scientifically correct for
+            # independent samples).
+            for i, adata in enumerate(adatas):
+                if batch_key not in adata.obs:
+                    adata.obs[batch_key] = f"sample_{i}"
+
+        # Merge datasets
+        combined = adatas[0].concatenate(
+            adatas[1:],
+            batch_key=batch_key,
+            join="outer",  # Use outer join to keep all genes
+        )
+
+        # FIX: Clean var columns with NA values in object dtype.
+        # Problem: the outer join creates NA values in var columns when genes
+        # don't exist in all samples. When object columns contain NA, the H5AD
+        # save corrupts var.index (it becomes 0,1,2...) and moves gene names
+        # to the _index column.
+        # Solution: fill NA with appropriate values or convert types.
+        import pandas as pd
+
+        for col in combined.var.columns:
+            if combined.var[col].dtype == "object" and combined.var[col].isna().any():
+                # For boolean-like columns (highly_variable, etc.), fill NA with False
+                unique_vals = combined.var[col].dropna().unique()
+                if set(unique_vals).issubset({True, False, "True", "False"}):
+                    combined.var[col] = combined.var[col].fillna(False).astype(bool)
+                else:
+                    # For string columns, fill NA with empty string
+                    combined.var[col] = combined.var[col].fillna("").astype(str)
+
+        # FIX: Remove incomplete diffmap artifacts created by concatenation
+        # (scanpy issue #1021).
+        # Problem: concatenate() copies obsm['X_diffmap'] but NOT
+        # uns['diffmap_evals'], an incomplete state that causes a KeyError in
+        # sc.tl.umap().
+        # Solution: delete the incomplete artifacts so UMAP falls back to its
+        # default initialization.
+        if "X_diffmap" in combined.obsm:
+            del combined.obsm["X_diffmap"]
+        if "diffmap_evals" in combined.uns:
+            del combined.uns["diffmap_evals"]
+
+    else:
+        # If already a merged dataset, ensure it has batch information
+        combined = adatas
+        if batch_key not in combined.obs:
+            raise ParameterError(
+                f"Merged dataset is missing batch information key '{batch_key}'"
+            )
+
+    # Validate that the input data is preprocessed:
+    # check whether it appears to be raw (high values without log transformation)
+    max_val = combined.X.max() if hasattr(combined.X, "max") else np.max(combined.X)
+    min_val = combined.X.min() if hasattr(combined.X, "min") else np.min(combined.X)
+
+    # Raw count data typically has high integer values and no negative values.
+    # Properly preprocessed data should be either:
+    # 1. Log-transformed (positive values, typically 0-15 range)
+    # 2. Scaled (centered around 0, can have negative values)
+    if min_val >= 0 and max_val > 100:
+        raise DataError("Data appears to be raw counts. Run preprocessing first.")
+
+    # Check if data appears to be normalized (reasonable range after preprocessing)
+    if max_val > 50:
+        logger.warning(
+            f"Data has very high values (max={max_val:.1f}). "
+            "Consider log transformation if not already applied."
+        )
+
+    # Validate data quality before processing
+    validate_adata_basics(combined, min_obs=10, min_vars=10, check_empty_ratio=True)
+
+    # Check that highly variable genes are marked (should be done in preprocessing)
+    if "highly_variable" not in combined.var.columns:
+        logger.warning(
+            "No highly variable genes marked after merge. Recalculating HVGs with batch correction."
+        )
+        # Recalculate HVGs with batch correction
+        sc.pp.highly_variable_genes(
+            combined,
+            min_mean=0.0125,
+            max_mean=3,
+            min_disp=0.5,
+            batch_key=batch_key,
+            n_top_genes=2000,
+        )
+        n_hvg = combined.var["highly_variable"].sum()
+    else:
+        n_hvg = combined.var["highly_variable"].sum()
+        if n_hvg == 0:
+            logger.warning(
+                "No genes marked as highly variable after merge, recalculating"
+            )
+            # Recalculate HVGs with batch correction
+            sc.pp.highly_variable_genes(
+                combined,
+                min_mean=0.0125,
+                max_mean=3,
+                min_disp=0.5,
+                batch_key=batch_key,
+                n_top_genes=2000,
+            )
+            n_hvg = combined.var["highly_variable"].sum()
+        elif n_hvg < 50:
+            logger.warning(
+                f"Very few HVGs ({n_hvg}), recalculating with batch correction"
+            )
+            sc.pp.highly_variable_genes(
+                combined,
+                min_mean=0.0125,
+                max_mean=3,
+                min_disp=0.5,
+                batch_key=batch_key,
+                n_top_genes=2000,
+            )
+            n_hvg = combined.var["highly_variable"].sum()
+
+    # Save raw data if not already saved.
+    # IMPORTANT: Create a proper frozen copy for .raw to preserve counts;
+    # `combined.raw = combined` would create a view that gets modified during
+    # normalization.
+    if combined.raw is None:
+        combined.raw = ad.AnnData(
+            X=combined.X.copy(),  # Must copy - will be modified during normalization
+            var=combined.var,  # No copy needed - AnnData creates an independent copy
+            obs=combined.obs.copy(),  # Must copy - modified by clustering/annotation
+            uns={},  # Empty dict - raw doesn't need uns metadata
+        )
+
+    # ========================================================================
+    # EARLY BRANCH FOR scVI-TOOLS METHODS
+    # scVI requires normalized+log data WITHOUT scaling/PCA;
+    # it generates its own latent representation.
+    # NOTE: scVI-tools methods work better with ALL genes, not just HVGs.
+    # ========================================================================
+    if method == "scvi":
+        # Use user-configurable parameters if provided, otherwise use defaults.
+        # This ensures scientific reproducibility and user control.
+        scvi_n_hidden = params.scvi_n_hidden if params else 128
+        scvi_n_latent = params.scvi_n_latent if params else 10
+        scvi_n_layers = params.scvi_n_layers if params else 1
+        scvi_dropout_rate = params.scvi_dropout_rate if params else 0.1
+        scvi_gene_likelihood = params.scvi_gene_likelihood if params else "zinb"
+        scvi_n_epochs = params.n_epochs if params else None
+        scvi_use_gpu = params.use_gpu if params else False
+
+        try:
+            combined = integrate_with_scvi(
+                combined,
+                batch_key=batch_key,
+                n_hidden=scvi_n_hidden,
+                n_latent=scvi_n_latent,
+                n_layers=scvi_n_layers,
+                dropout_rate=scvi_dropout_rate,
+                gene_likelihood=scvi_gene_likelihood,
+                n_epochs=scvi_n_epochs,
+                use_gpu=scvi_use_gpu,
+            )
+        except Exception as e:
+            raise ProcessingError(
+                f"scVI integration failed: {e}. "
+                f"Ensure data is preprocessed and has ≥2 batches."
+            ) from e
+
+        # Calculate UMAP embedding to visualize the integration effect
+        sc.tl.umap(combined)
+
+        # Store metadata for scientific provenance tracking
+        n_batches = combined.obs[batch_key].nunique()
+        batch_sizes = combined.obs[batch_key].value_counts().to_dict()
+
+        # CRITICAL FIX: Convert dict keys to strings for H5AD compatibility
+        batch_sizes = {str(k): int(v) for k, v in batch_sizes.items()}
+
+        store_analysis_metadata(
+            combined,
+            analysis_name="integration_scvi",
+            method="scvi",
+            parameters={
+                "batch_key": batch_key,
+                "n_hidden": scvi_n_hidden,
+                "n_latent": scvi_n_latent,
+                "n_layers": scvi_n_layers,
+                "dropout_rate": scvi_dropout_rate,
+                "gene_likelihood": scvi_gene_likelihood,
+                "n_epochs": scvi_n_epochs,
+                "use_gpu": scvi_use_gpu,
+            },
+            # Key matches integrate_with_scvi(), which stores obsm["X_scvi"]
+            results_keys={"obsm": ["X_scvi"], "uns": ["neighbors"]},
+            statistics={
+                "n_batches": int(n_batches),
+                "batch_sizes": batch_sizes,
+                "n_cells_total": int(combined.n_obs),
+                "n_genes": int(combined.n_vars),
+            },
+        )
+
+        return combined
+
+    # ========================================================================
+    # CLASSICAL METHODS: Continue with scale → PCA → integration
+    # ========================================================================
+
+    # Filter to highly variable genes for classical methods
+    if "highly_variable" in combined.var.columns:
+        n_hvg = combined.var["highly_variable"].sum()
+        if n_hvg == 0:
+            raise DataError(
+                "No highly variable genes found. Check HVG selection parameters."
+            )
+        # Memory optimization: subsetting creates a view; no need to
+        # materialize it with .copy() - the view is materialized on first write
+        combined = combined[:, combined.var["highly_variable"]]
+
+    # Remove genes with zero variance to avoid NaN in scaling
+    # (numpy is already imported at module level)
+    from scipy import sparse
+
+    # MEMORY OPTIMIZATION: Calculate variance without toarray(), using the
+    # E[X²] - E[X]² formula for sparse matrices.
+    # Saves ~80% memory (e.g., 76 MB → 15 MB for 10k cells × 2k genes).
+    if sparse.issparse(combined.X):
+        # Sparse matrix: compute variance via E[X²] - E[X]², which avoids
+        # creating a dense copy (5-10x memory reduction)
+        mean_per_gene = np.array(combined.X.mean(axis=0)).flatten()
+
+        # Calculate E[X²]
+        X_squared = combined.X.copy()
+        # Square the data: np.array() for type safety (handles memoryview, ensures copy)
+        X_squared.data = np.array(X_squared.data) ** 2
+        mean_squared = np.array(X_squared.mean(axis=0)).flatten()
+
+        # Variance = E[X²] - E[X]²
+        gene_var = mean_squared - mean_per_gene**2
+    else:
+        # Dense matrix: use the standard variance calculation
+        gene_var = np.var(combined.X, axis=0)
+    nonzero_var_genes = gene_var > 0
+    if not np.all(nonzero_var_genes):
+        n_removed = np.sum(~nonzero_var_genes)
+        logger.warning(f"Removing {n_removed} genes with zero variance before scaling")
+        # Memory optimization: subsetting creates a view; no need to copy.
+        # The view is materialized when scaling modifies the data.
+        combined = combined[:, nonzero_var_genes]
+
+    # Scale data with proper error handling
+    try:
+        sc.pp.scale(combined, zero_center=True, max_value=10)
+    except Exception as e:
+        logger.warning(f"Scaling with zero centering failed: {e}")
+        try:
+            sc.pp.scale(combined, zero_center=False, max_value=10)
+        except Exception as e2:
+            raise ProcessingError(
+                f"Data scaling failed completely. Zero-center error: {e}. Non-zero-center error: {e2}. "
+                f"This usually indicates the data contains extreme outliers or invalid values. "
+                f"Consider additional quality control or outlier removal."
+            ) from e2
+
+    # PCA with proper error handling:
+    # determine a safe number of components
+    max_possible_components = min(n_pcs, combined.n_vars, combined.n_obs - 1)
+
+    if max_possible_components < 2:
+        raise DataError(
+            f"Cannot perform PCA: only {max_possible_components} components possible. "
+            f"Dataset has {combined.n_obs} cells and {combined.n_vars} genes. "
+            f"Minimum 2 components required for downstream analysis."
+        )
+
+    # Check the data matrix before PCA.
+    # MEMORY OPTIMIZATION: Check sparse .data directly without toarray() -
+    # sparse matrices only store non-zero elements, and implicit zeros cannot
+    # be NaN/Inf.
+    if sparse.issparse(combined.X):
+        # Sparse matrix: only check the non-zero elements stored in .data
+        if np.isnan(combined.X.data).any():
+            raise DataError("Data contains NaN values after scaling")
+        if np.isinf(combined.X.data).any():
+            raise DataError("Data contains infinite values after scaling")
+    else:
+        # Dense matrix: check all elements
+        if np.isnan(combined.X).any():
+            raise DataError("Data contains NaN values after scaling")
+        if np.isinf(combined.X).any():
+            raise DataError("Data contains infinite values after scaling")
+
+    # Variance check removed: zero-variance genes were already filtered above
+
+    # Try PCA with different solvers, but fail properly if none work
+    pca_success = False
+    for solver, max_comps in [
+        ("arpack", min(max_possible_components, 50)),
+        ("randomized", min(max_possible_components, 50)),
+        ("full", min(max_possible_components, 20)),
+    ]:
+        try:
+            sc.tl.pca(combined, n_comps=max_comps, svd_solver=solver, zero_center=False)
+            pca_success = True
+            break
+        except Exception as e:
+            logger.warning(f"PCA with {solver} solver failed: {e}")
+            continue
+
+    if not pca_success:
+        raise ProcessingError(
+            f"PCA failed for {combined.n_obs}×{combined.n_vars} data. Check data quality."
+        )
+
+    # Apply batch correction based on the selected method
+    if method == "harmony":
+        # Use Harmony for batch correction.
+        # BEST PRACTICE: prefer the scanpy.external wrapper for better
+        # integration with the scanpy workflow.
+        require("harmonypy", feature="Harmony integration")
+        try:
+            import scanpy.external as sce
+
+            # Check if harmony_integrate is available in scanpy.external
+            if hasattr(sce.pp, "harmony_integrate"):
+                # Use the scanpy.external wrapper (preferred method)
+                sce.pp.harmony_integrate(
+                    combined,
+                    key=batch_key,
+                    basis="X_pca",  # Use PCA representation
+                    adjusted_basis="X_pca_harmony",  # Store corrected embedding
+                )
+                # Use the corrected embedding for downstream analysis
+                sc.pp.neighbors(combined, use_rep="X_pca_harmony")
+            else:
+                # Fall back to raw harmonypy (same algorithm, different interface)
+                import harmonypy
+                import pandas as pd
+
+                # Get the PCA result
+                X_pca = combined.obsm["X_pca"]
+
+                # Create a DataFrame with batch information
+                meta_data = pd.DataFrame({batch_key: combined.obs[batch_key]})
+
+                # Run Harmony
+                harmony_out = harmonypy.run_harmony(
+                    data_mat=X_pca,
+                    meta_data=meta_data,
+                    vars_use=[batch_key],
+                    sigma=0.1,
+                    nclust=None,
+                    max_iter_harmony=10,
+                    verbose=True,
+                )
+
+                # Save the Harmony-corrected result
+                combined.obsm["X_harmony"] = harmony_out.Z_corr.T
+
+                # Use the corrected result to calculate the neighbor graph
+                sc.pp.neighbors(combined, use_rep="X_harmony")
+
+        except Exception as e:
+            raise ProcessingError(
+                f"Harmony integration failed: {e}. "
+                f"Check batch_key '{batch_key}' has ≥2 valid batches."
+            ) from e
+
+    elif method == "bbknn":
+        # Use BBKNN for batch correction
+        require("bbknn", feature="BBKNN integration")
+        import bbknn
+
+        bbknn.bbknn(combined, batch_key=batch_key, neighbors_within_batch=3)
+
+    elif method == "scanorama":
+        # Use Scanorama for batch correction.
+        # BEST PRACTICE: prefer the scanpy.external wrapper for better
+        # integration with the scanpy workflow.
+        require("scanorama", feature="Scanorama integration")
+        try:
+            import scanpy.external as sce
+
+            # Check if scanorama_integrate is available in scanpy.external
+            if hasattr(sce.pp, "scanorama_integrate"):
+                # Use the scanpy.external wrapper (preferred method)
+                sce.pp.scanorama_integrate(
+                    combined, key=batch_key, basis="X_pca", adjusted_basis="X_scanorama"
+                )
+                # Use the integrated representation for the neighbor graph
+                sc.pp.neighbors(combined, use_rep="X_scanorama")
+            else:
+                # Fall back to raw scanorama (same algorithm, different interface)
+                import scanorama
+
+                # Separate data by batch
+                datasets = []
+                genes_list = []
+                batch_order = []
+
+                for batch in combined.obs[batch_key].unique():
+                    batch_mask = combined.obs[batch_key] == batch
+                    batch_data = combined[batch_mask]
+
+                    # Scanorama natively supports sparse matrices
+                    datasets.append(batch_data.X)
+                    genes_list.append(batch_data.var_names.tolist())
+                    batch_order.append(batch)
+
+                # Run Scanorama integration
+                integrated, corrected_genes = scanorama.integrate(
+                    datasets, genes_list, dimred=100
+                )
+
+                # Stack the integrated results back together
+                integrated_X = np.vstack(integrated)
+
+                # Store the integrated representation in obsm
+                combined.obsm["X_scanorama"] = integrated_X
+
+                # Use the integrated representation for the neighbor graph
+                sc.pp.neighbors(combined, use_rep="X_scanorama")
+
+        except Exception as e:
+            raise ProcessingError(
+                f"Scanorama integration failed: {e}. "
+                f"Check gene overlap between batches."
+            ) from e
+
+    else:
+        # Default: use the uncorrected PCA result
+        logger.warning(
+            f"Integration method '{method}' not recognized. "
+            f"Using uncorrected PCA embedding."
+        )
+        sc.pp.neighbors(combined)
+
+    # Calculate UMAP embedding to visualize the integration effect
+    sc.tl.umap(combined)
+
+    # Store metadata for scientific provenance tracking:
+    # determine which representation was used
+    if method == "harmony":
+        if "X_pca_harmony" in combined.obsm:
+            results_keys = {"obsm": ["X_pca_harmony"], "uns": ["neighbors"]}
+        else:
+            results_keys = {"obsm": ["X_harmony"], "uns": ["neighbors"]}
+    elif method == "bbknn":
+        results_keys = {"uns": ["neighbors"]}
+    elif method == "scanorama":
+        results_keys = {"obsm": ["X_scanorama"], "uns": ["neighbors"]}
+    else:
+        results_keys = {"obsm": ["X_pca"], "uns": ["neighbors"]}
+
+    # Get batch statistics
+    n_batches = combined.obs[batch_key].nunique()
+    batch_sizes = combined.obs[batch_key].value_counts().to_dict()
+
+    # CRITICAL FIX: Convert dict keys to strings for H5AD compatibility.
+    # H5AD requires all dictionary keys to be strings; without this,
+    # save_data() fails with "Can't implicitly convert non-string objects to strings".
+    batch_sizes = {str(k): int(v) for k, v in batch_sizes.items()}
+
+    store_analysis_metadata(
+        combined,
+        analysis_name=f"integration_{method}",
+        method=method,
+        parameters={
+            "batch_key": batch_key,
+            "n_pcs": n_pcs,
+            "n_batches": n_batches,
+        },
+        results_keys=results_keys,
+        statistics={
+            "n_batches": int(n_batches),  # Also ensure int types for H5AD
+            "batch_sizes": batch_sizes,
+            "n_cells_total": int(combined.n_obs),
+            "n_genes": int(combined.n_vars),
+        },
+    )
+
+    return combined
+
+
+def align_spatial_coordinates(combined_adata, batch_key="batch", reference_batch=None):
+    """Align spatial coordinates of multiple samples
+
+    Args:
+        combined_adata: Combined AnnData object containing multiple samples
+        batch_key: Batch information key
+        reference_batch: Reference batch; if None, the first batch is used
+
+    Returns:
+        AnnData object with aligned spatial coordinates
+    """
+    from sklearn.preprocessing import StandardScaler
+
+    # Ensure the data contains spatial coordinates
+    spatial_key = get_spatial_key(combined_adata)
+    if not spatial_key:
+        raise DataNotFoundError("Data is missing spatial coordinates")
+
+    # Get batch information
+    batches = combined_adata.obs[batch_key].unique()
+
+    if len(batches) == 0:
+        raise DataError("Dataset is empty, cannot perform spatial registration")
+
+    # If no reference batch is specified, use the first batch
+    if reference_batch is None:
+        reference_batch = batches[0]
+    elif reference_batch not in batches:
+        raise ParameterError(f"Reference batch '{reference_batch}' not found in data")
+
+    # Get the reference batch's spatial coordinates
+    ref_coords = combined_adata[combined_adata.obs[batch_key] == reference_batch].obsm[
+        spatial_key
+    ]
+
+    # Standardize the reference coordinates
+    scaler = StandardScaler()
+    ref_coords_scaled = scaler.fit_transform(ref_coords)
+
+    # Align the spatial coordinates of each batch
+    aligned_coords = []
+
+    for batch in batches:
+        if batch == reference_batch:
+            # The reference batch remains unchanged
+            aligned_coords.append(ref_coords_scaled)
+        else:
+            # Get the current batch's spatial coordinates
+            batch_idx = combined_adata.obs[batch_key] == batch
+            batch_coords = combined_adata[batch_idx].obsm[spatial_key]
+
+            # Standardize them with the reference batch's scaler
+            batch_coords_scaled = scaler.transform(batch_coords)
+
+            # Add to the aligned coordinates list
+            aligned_coords.append(batch_coords_scaled)
+
+    # Merge all aligned coordinates
+    combined_adata.obsm["spatial_aligned"] = np.zeros((combined_adata.n_obs, 2))
+
+    # Fill the aligned coordinates back into the original data. Use the boolean
+    # batch mask rather than contiguous slices, so the assignment is correct
+    # even if observations from different batches are interleaved.
+    for batch, coords in zip(batches, aligned_coords, strict=False):
+        batch_idx = (combined_adata.obs[batch_key] == batch).to_numpy()
+        combined_adata.obsm["spatial_aligned"][batch_idx] = coords
+
+    # Store metadata for scientific provenance tracking
+    n_batches = len(batches)
+    # String keys and plain ints for H5AD compatibility (matches the fix above)
+    batch_sizes = {
+        str(batch): int(np.sum(combined_adata.obs[batch_key] == batch))
+        for batch in batches
+    }
+
+    store_analysis_metadata(
+        combined_adata,
+        analysis_name="spatial_alignment",
+        method="standardization",
+        parameters={
+            "batch_key": batch_key,
+            "reference_batch": reference_batch,
+        },
+        results_keys={"obsm": ["spatial_aligned"]},
+        statistics={
+            "n_batches": n_batches,
+            "batch_sizes": batch_sizes,
+            "reference_batch": reference_batch,
+        },
+    )
+
+    return combined_adata
+
+
+def integrate_with_scvi(
+    combined: sc.AnnData,
+    batch_key: str = "batch",
+    n_hidden: int = 128,
+    n_latent: int = 10,
+    n_layers: int = 1,
+    dropout_rate: float = 0.1,
+    gene_likelihood: str = "zinb",
+    n_epochs: Optional[int] = None,
+    use_gpu: bool = False,
+) -> sc.AnnData:
+    """Integrate data using scVI for batch correction
+
+    scVI is a deep generative model for single-cell RNA-seq that can perform
+    batch correction by learning a low-dimensional latent representation.
+
+    Args:
+        combined: Combined AnnData object with multiple batches
+        batch_key: Column name in obs containing batch labels
+        n_hidden: Number of nodes per hidden layer (default: 128)
+        n_latent: Dimensionality of the latent space (default: 10)
+        n_layers: Number of hidden layers (default: 1)
+        dropout_rate: Dropout rate for neural networks (default: 0.1)
+        gene_likelihood: Distribution for gene expression (default: "zinb")
+        n_epochs: Number of training epochs (None = auto-determine)
+        use_gpu: Whether to use GPU acceleration (default: False)
+
+    Returns:
+        AnnData object with the scVI latent representation in obsm['X_scvi']
+
+    Raises:
+        ImportError: If scvi-tools is not installed
+        DataError: If data is not preprocessed or has fewer than 2 batches
+        ParameterError: If batch_key is not present in obs
+
+    Reference:
+        Lopez et al. (2018) "Deep generative modeling for single-cell transcriptomics"
+        Nature Methods 15, 1053–1058
+    """
+    require("scvi", feature="scVI integration")
+    import scvi
+
+    # Validate that the data is preprocessed
+    max_val = combined.X.max() if hasattr(combined.X, "max") else np.max(combined.X)
+    if max_val > 50:
+        raise DataError(
+            f"scVI requires preprocessed data. Max value {max_val:.1f} too high."
+        )
+
+    # Validate the batch key
+    if batch_key not in combined.obs:
+        raise ParameterError(
+            f"Batch key '{batch_key}' not found in adata.obs. "
+            f"Available columns: {list(combined.obs.columns)}"
+        )
+
+    # Check for batch diversity
+    n_batches = combined.obs[batch_key].nunique()
+    if n_batches < 2:
+        raise DataError(
+            f"scVI requires at least 2 batches, found {n_batches}. "
+            "Check your batch labels."
+        )
+
+    # Set up the AnnData for scVI
+    scvi.model.SCVI.setup_anndata(
+        combined, batch_key=batch_key, layer=None  # Use .X (should be preprocessed)
+    )
+
+    # Initialize the scVI model
+    model = scvi.model.SCVI(
+        combined,
+        n_hidden=n_hidden,
+        n_latent=n_latent,
+        n_layers=n_layers,
+        dropout_rate=dropout_rate,
+        gene_likelihood=gene_likelihood,
+    )
+
+    # Auto-determine epochs based on dataset size if not specified
+    if n_epochs is None:
+        n_cells = combined.n_obs
+        if n_cells < 1000:
+            n_epochs = 400
+        elif n_cells < 10000:
+            n_epochs = 200
+        else:
+            n_epochs = 100
+
+    # Train the model.
+    # Note: scvi-tools 1.x uses `accelerator` instead of `use_gpu`
+    accelerator = "gpu" if use_gpu else "cpu"
+    model.train(max_epochs=n_epochs, early_stopping=True, accelerator=accelerator)
+
+    # Get the latent representation
+    combined.obsm["X_scvi"] = model.get_latent_representation()
+
+    # Compute neighbors using the scVI embedding
+    sc.pp.neighbors(combined, use_rep="X_scvi")
+
+    return combined
+
+
+async def integrate_samples(
+    data_ids: list[str],
+    ctx: "ToolContext",
+    params: IntegrationParameters = IntegrationParameters(),
+) -> IntegrationResult:
+    """Integrate multiple spatial transcriptomics samples and perform batch correction
+
+    Args:
+        data_ids: List of dataset IDs to integrate
+        ctx: ToolContext for unified data access and logging
+        params: Integration parameters
+
+    Returns:
+        Integration result
+    """
+    # Collect all AnnData objects.
+    # Memory note: concatenate() creates a new object without modifying the
+    # sources, so the original datasets remain accessible via ctx references
+    # after integration.
+    adatas = []
+    for data_id in data_ids:
+        adata = await ctx.get_adata(data_id)
+        adatas.append(adata)
+
+    # Integrate samples (pass full params for method-specific settings like scVI)
+    combined_adata = integrate_multiple_samples(
+        adatas,
+        batch_key=params.batch_key,
+        method=params.method,
+        n_pcs=params.n_pcs,
+        params=params,
+    )
+
+    # Align spatial coordinates if requested and available.
+    # Note: spatial alignment is optional - BBKNN, Harmony, MNN, and Scanorama
+    # work in gene expression/PCA space without spatial coordinates.
+    if params.align_spatial and "spatial" in combined_adata.obsm:
+        combined_adata = align_spatial_coordinates(
+            combined_adata,
+            batch_key=params.batch_key,
+            reference_batch=params.reference_batch,
+        )
+
+    # Generate the new integrated dataset ID
+    integrated_id = f"integrated_{'-'.join(data_ids)}"
+
+    # Store the integrated data using the ToolContext
+    await ctx.add_dataset(integrated_id, combined_adata)
+
+    # Return the result
+    return IntegrationResult(
+        data_id=integrated_id,
+        n_samples=len(data_ids),
+        integration_method=params.method,
+        umap_visualization=None,  # Use visualize_data tool instead
+        spatial_visualization=None,  # Use visualize_data tool instead
+    )