chatspatial 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. chatspatial/__init__.py +11 -0
  2. chatspatial/__main__.py +141 -0
  3. chatspatial/cli/__init__.py +7 -0
  4. chatspatial/config.py +53 -0
  5. chatspatial/models/__init__.py +85 -0
  6. chatspatial/models/analysis.py +513 -0
  7. chatspatial/models/data.py +2462 -0
  8. chatspatial/server.py +1763 -0
  9. chatspatial/spatial_mcp_adapter.py +720 -0
  10. chatspatial/tools/__init__.py +3 -0
  11. chatspatial/tools/annotation.py +1903 -0
  12. chatspatial/tools/cell_communication.py +1603 -0
  13. chatspatial/tools/cnv_analysis.py +605 -0
  14. chatspatial/tools/condition_comparison.py +595 -0
  15. chatspatial/tools/deconvolution/__init__.py +402 -0
  16. chatspatial/tools/deconvolution/base.py +318 -0
  17. chatspatial/tools/deconvolution/card.py +244 -0
  18. chatspatial/tools/deconvolution/cell2location.py +326 -0
  19. chatspatial/tools/deconvolution/destvi.py +144 -0
  20. chatspatial/tools/deconvolution/flashdeconv.py +101 -0
  21. chatspatial/tools/deconvolution/rctd.py +317 -0
  22. chatspatial/tools/deconvolution/spotlight.py +216 -0
  23. chatspatial/tools/deconvolution/stereoscope.py +109 -0
  24. chatspatial/tools/deconvolution/tangram.py +135 -0
  25. chatspatial/tools/differential.py +625 -0
  26. chatspatial/tools/embeddings.py +298 -0
  27. chatspatial/tools/enrichment.py +1863 -0
  28. chatspatial/tools/integration.py +807 -0
  29. chatspatial/tools/preprocessing.py +723 -0
  30. chatspatial/tools/spatial_domains.py +808 -0
  31. chatspatial/tools/spatial_genes.py +836 -0
  32. chatspatial/tools/spatial_registration.py +441 -0
  33. chatspatial/tools/spatial_statistics.py +1476 -0
  34. chatspatial/tools/trajectory.py +495 -0
  35. chatspatial/tools/velocity.py +405 -0
  36. chatspatial/tools/visualization/__init__.py +155 -0
  37. chatspatial/tools/visualization/basic.py +393 -0
  38. chatspatial/tools/visualization/cell_comm.py +699 -0
  39. chatspatial/tools/visualization/cnv.py +320 -0
  40. chatspatial/tools/visualization/core.py +684 -0
  41. chatspatial/tools/visualization/deconvolution.py +852 -0
  42. chatspatial/tools/visualization/enrichment.py +660 -0
  43. chatspatial/tools/visualization/integration.py +205 -0
  44. chatspatial/tools/visualization/main.py +164 -0
  45. chatspatial/tools/visualization/multi_gene.py +739 -0
  46. chatspatial/tools/visualization/persistence.py +335 -0
  47. chatspatial/tools/visualization/spatial_stats.py +469 -0
  48. chatspatial/tools/visualization/trajectory.py +639 -0
  49. chatspatial/tools/visualization/velocity.py +411 -0
  50. chatspatial/utils/__init__.py +115 -0
  51. chatspatial/utils/adata_utils.py +1372 -0
  52. chatspatial/utils/compute.py +327 -0
  53. chatspatial/utils/data_loader.py +499 -0
  54. chatspatial/utils/dependency_manager.py +462 -0
  55. chatspatial/utils/device_utils.py +165 -0
  56. chatspatial/utils/exceptions.py +185 -0
  57. chatspatial/utils/image_utils.py +267 -0
  58. chatspatial/utils/mcp_utils.py +137 -0
  59. chatspatial/utils/path_utils.py +243 -0
  60. chatspatial/utils/persistence.py +78 -0
  61. chatspatial/utils/scipy_compat.py +143 -0
  62. chatspatial-1.1.0.dist-info/METADATA +242 -0
  63. chatspatial-1.1.0.dist-info/RECORD +67 -0
  64. chatspatial-1.1.0.dist-info/WHEEL +5 -0
  65. chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
  66. chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
  67. chatspatial-1.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1372 @@
1
+ """
2
+ AnnData utilities for ChatSpatial.
3
+
4
+ This module provides:
5
+ 1. Standard field name constants
6
+ 2. Field discovery functions (get_*_key)
7
+ 3. Data access functions (get_*)
8
+ 4. Validation functions (validate_*)
9
+ 5. Ensure functions (ensure_*)
10
+
11
+ One file for all AnnData-related utilities. No duplication.
12
+
13
+ Naming Conventions (MUST follow across codebase):
14
+ -------------------------------------------------
15
+ - validate_*(adata, ...) -> None
16
+ Check-only. Raises exception if validation fails.
17
+ Does NOT modify data. Use for precondition checks.
18
+ Example: validate_obs_column(adata, "leiden")
19
+
20
+ - ensure_*(adata, ...) -> bool
21
+ Check-and-fix. Returns True if action was taken, False if already OK.
22
+ MAY modify data in-place. Idempotent (safe to call multiple times).
23
+ Example: ensure_categorical(adata, "leiden")
24
+
25
+ - require(name, ctx, feature) -> module
26
+ Dependency check. Raises ImportError with install instructions if missing.
27
+ Used in dependency_manager.py only.
28
+
29
+ Async variants: Add '_async' suffix (e.g., ensure_unique_var_names_async).
30
+ """
31
+
32
+ from typing import TYPE_CHECKING, Any, Literal, Optional
33
+
34
+ import numpy as np
35
+ import pandas as pd
36
+
37
+ if TYPE_CHECKING:
38
+ import anndata as ad
39
+
40
+ from scipy import sparse
41
+
42
+ from .exceptions import DataError
43
+
44
+ # =============================================================================
45
+ # Constants: Standard Field Names
46
+ # =============================================================================
47
# Canonical field names that ChatSpatial itself reads and writes.
SPATIAL_KEY = "spatial"
CELL_TYPE_KEY = "cell_type"
CLUSTER_KEY = "leiden"
BATCH_KEY = "batch"

# Synonyms seen in third-party datasets. Each set also contains the canonical
# name so a single membership test covers both cases.
ALTERNATIVE_SPATIAL_KEYS: set[str] = {
    "spatial", "X_spatial", "coordinates",
    "coords", "spatial_coords", "positions",
}
ALTERNATIVE_CELL_TYPE_KEYS: set[str] = {
    "cell_type", "celltype", "cell_types",
    "annotation", "cell_annotation", "predicted_celltype",
}
ALTERNATIVE_CLUSTER_KEYS: set[str] = {
    "leiden", "louvain", "clusters", "cluster",
    "clustering", "cluster_labels", "spatial_domains",
}
ALTERNATIVE_BATCH_KEYS: set[str] = {
    "batch", "sample", "dataset", "experiment",
    "replicate", "batch_id", "sample_id",
}
87
+
88
+
89
+ # =============================================================================
90
+ # Field Discovery: Find keys in AnnData
91
+ # =============================================================================
92
def get_spatial_key(adata: "ad.AnnData") -> Optional[str]:
    """Find the spatial coordinate key in ``adata.obsm``.

    The canonical ``SPATIAL_KEY`` is preferred. Otherwise the names in
    ``ALTERNATIVE_SPATIAL_KEYS`` are tried in sorted order, so the answer is
    the same on every run even when several candidate keys are present
    (iterating the set directly depends on string hash randomization).

    Args:
        adata: AnnData object whose ``obsm`` is inspected.

    Returns:
        The first matching key, or None when no known key exists.
    """
    if SPATIAL_KEY in adata.obsm:
        return SPATIAL_KEY
    for key in sorted(ALTERNATIVE_SPATIAL_KEYS):
        if key in adata.obsm:
            return key
    return None
98
+
99
+
100
def get_cell_type_key(adata: "ad.AnnData") -> Optional[str]:
    """Find the cell type column in ``adata.obs``.

    The canonical ``CELL_TYPE_KEY`` is preferred. Otherwise the names in
    ``ALTERNATIVE_CELL_TYPE_KEYS`` are tried in sorted order, so the answer
    is the same on every run even when several candidate columns are present
    (iterating the set directly depends on string hash randomization).

    Args:
        adata: AnnData object whose ``obs`` is inspected.

    Returns:
        The first matching column name, or None when no known column exists.
    """
    if CELL_TYPE_KEY in adata.obs:
        return CELL_TYPE_KEY
    for key in sorted(ALTERNATIVE_CELL_TYPE_KEYS):
        if key in adata.obs:
            return key
    return None
106
+
107
+
108
def get_cluster_key(adata: "ad.AnnData") -> Optional[str]:
    """Find the cluster column in ``adata.obs``.

    The canonical ``CLUSTER_KEY`` is preferred. Otherwise the names in
    ``ALTERNATIVE_CLUSTER_KEYS`` are tried in sorted order, so the answer is
    the same on every run even when several candidate columns are present
    (iterating the set directly depends on string hash randomization).

    Args:
        adata: AnnData object whose ``obs`` is inspected.

    Returns:
        The first matching column name, or None when no known column exists.
    """
    if CLUSTER_KEY in adata.obs:
        return CLUSTER_KEY
    for key in sorted(ALTERNATIVE_CLUSTER_KEYS):
        if key in adata.obs:
            return key
    return None
114
+
115
+
116
def get_batch_key(adata: "ad.AnnData") -> Optional[str]:
    """Find the batch/sample column in ``adata.obs``.

    The canonical ``BATCH_KEY`` is preferred. Otherwise the names in
    ``ALTERNATIVE_BATCH_KEYS`` are tried in sorted order, so the answer is
    the same on every run even when several candidate columns are present
    (iterating the set directly depends on string hash randomization).

    Args:
        adata: AnnData object whose ``obs`` is inspected.

    Returns:
        The first matching column name, or None when no known column exists.
    """
    if BATCH_KEY in adata.obs:
        return BATCH_KEY
    for key in sorted(ALTERNATIVE_BATCH_KEYS):
        if key in adata.obs:
            return key
    return None
122
+
123
+
124
+ # =============================================================================
125
+ # Data Access: Get data from AnnData
126
+ # =============================================================================
127
def sample_expression_values(
    adata: "ad.AnnData",
    n_samples: int = 1000,
    layer: Optional[str] = None,
) -> np.ndarray:
    """
    Sample expression values from the data matrix for validation checks.

    Takes a prefix of the stored values without densifying the matrix. Used
    for data type detection (integer vs float, negative values, etc.).

    Args:
        adata: AnnData object
        n_samples: Maximum number of values to sample (default: 1000)
        layer: Optional layer name. If None, uses adata.X

    Returns:
        1D numpy array with at most ``n_samples`` values. For sparse input
        these are stored (typically non-zero) values; when nothing is stored
        the result is all zeros.

    Examples:
        # Check for negative values (indicates log-normalized data)
        sample = sample_expression_values(adata)
        if np.any(sample < 0):
            raise ValueError("Log normalization requires non-negative data")

        # Check for non-integer values (indicates normalized data)
        sample = sample_expression_values(adata)
        if np.any((sample % 1) != 0):
            raise ValueError("Method requires raw count data (integers)")
    """
    X = adata.layers[layer] if layer is not None else adata.X

    if sparse.issparse(X):
        stored = getattr(X, "data", None)
        # Not every format keeps a flat numeric ``.data``: LIL holds an
        # object array of per-row lists and DOK is dict-backed. Go through
        # CSR for those, which stays sparse.
        if not isinstance(stored, np.ndarray) or stored.dtype == object:
            stored = X.tocsr().data
        # BSR/DIA keep multi-dimensional ``.data``; ravel to honour the 1D contract.
        stored = np.ravel(stored)
        if stored.size > 0:
            return stored[:n_samples]
        # Nothing stored means every entry is zero. Build the zeros directly
        # instead of slicing, which COO does not support and which could
        # yield far more than n_samples values.
        n_zeros = max(0, min(n_samples, X.shape[0] * X.shape[1]))
        return np.zeros(n_zeros, dtype=X.dtype)

    # ravel() is a view for contiguous arrays, so only the returned prefix is
    # touched; flatten() would copy the entire matrix first.
    return np.asarray(X).ravel()[:n_samples]
174
+
175
+
176
def _check_spatial_coords(coords: np.ndarray) -> None:
    """Raise DataError if coords are not a usable (n_cells, >=2) finite array."""
    n_dims = coords.shape[1] if coords.ndim >= 2 else 1
    if n_dims < 2:
        raise DataError(
            f"Spatial coordinates should have at least 2 dimensions, "
            f"found {n_dims}"
        )
    if np.any(np.isnan(coords)):
        raise DataError("Spatial coordinates contain NaN values")
    if np.any(np.isinf(coords)):
        raise DataError("Spatial coordinates contain infinite values")
    if np.std(coords[:, 0]) == 0 and np.std(coords[:, 1]) == 0:
        raise DataError("All spatial coordinates are identical")


def require_spatial_coords(
    adata: "ad.AnnData",
    spatial_key: Optional[str] = None,
    validate: bool = True,
) -> np.ndarray:
    """
    Get validated spatial coordinates array from AnnData.

    This is the primary function for accessing spatial coordinates.
    Returns the full coordinates array with optional validation.

    Args:
        adata: AnnData object
        spatial_key: Optional key in obsm. If None, auto-detects using
            get_spatial_key() and then falls back to obs['x'] / obs['y'].
        validate: If True (default), validates coordinates for:
            - At least 2 dimensions
            - No NaN or infinite values
            - Not all identical
            The same checks apply to the obs x/y fallback.

    Returns:
        Spatial coordinates as a numpy array (n_cells, n_dims). Values stored
        in obsm as a DataFrame or sparse matrix are converted to ndarray.

    Raises:
        DataError: If spatial coordinates not found or validation fails

    Examples:
        # Auto-detect spatial key
        coords = require_spatial_coords(adata)

        # Use specific key without validation
        coords = require_spatial_coords(adata, spatial_key="X_spatial", validate=False)
    """
    if spatial_key is None:
        spatial_key = get_spatial_key(adata)

    if spatial_key is None:
        # No obsm entry: fall back to per-cell x/y columns in obs.
        if "x" in adata.obs and "y" in adata.obs:
            x = pd.to_numeric(adata.obs["x"], errors="coerce").values
            y = pd.to_numeric(adata.obs["y"], errors="coerce").values
            coords = np.column_stack([x, y])
            if validate:
                # Non-numeric entries were coerced to NaN above; report them
                # with the obs-specific message before the generic checks.
                if np.any(np.isnan(coords)):
                    raise DataError(
                        "Spatial coordinates in obs['x'/'y'] contain NaN"
                    )
                _check_spatial_coords(coords)
            return coords

        raise DataError(
            "No spatial coordinates found. Expected in adata.obsm['spatial'] "
            "or similar key. Available obsm keys: "
            f"{list(adata.obsm.keys()) if adata.obsm else 'none'}"
        )

    if spatial_key not in adata.obsm:
        raise DataError(
            f"Spatial coordinates '{spatial_key}' not found in adata.obsm. "
            f"Available keys: {list(adata.obsm.keys())}"
        )

    coords = adata.obsm[spatial_key]
    # obsm may hold a DataFrame or sparse matrix; normalise so the positional
    # indexing in the checks works and the return type is really an ndarray.
    if sparse.issparse(coords):
        coords = coords.toarray()
    coords = np.asarray(coords)

    if validate:
        _check_spatial_coords(coords)

    return coords
252
+
253
+
254
+ # =============================================================================
255
+ # Validation: Check and validate AnnData
256
+ # =============================================================================
257
def validate_obs_column(
    adata: "ad.AnnData",
    column: str,
    friendly_name: Optional[str] = None,
) -> None:
    """
    Check that ``column`` is present in ``adata.obs``.

    Args:
        adata: AnnData object to inspect
        column: Column name that must exist in adata.obs
        friendly_name: Optional label used at the start of the error message
            instead of "Column '<column>'"

    Raises:
        DataError: If column not found; the message lists up to ten of the
            available column names.
    """
    obs_columns = adata.obs.columns
    if column in obs_columns:
        return

    subject = friendly_name if friendly_name else f"Column '{column}'"
    preview = ", ".join(list(obs_columns)[:10])
    tail = "..." if len(obs_columns) > 10 else ""
    raise DataError(f"{subject} not found in adata.obs. Available: {preview}{tail}")
275
+
276
+
277
def validate_var_column(
    adata: "ad.AnnData",
    column: str,
    friendly_name: Optional[str] = None,
) -> None:
    """
    Check that ``column`` is present in ``adata.var``.

    Args:
        adata: AnnData object to inspect
        column: Column name that must exist in adata.var
        friendly_name: Optional label used at the start of the error message
            instead of "Column '<column>'"

    Raises:
        DataError: If column not found; the message lists up to ten of the
            available column names.
    """
    var_columns = adata.var.columns
    if column in var_columns:
        return

    subject = friendly_name if friendly_name else f"Column '{column}'"
    preview = ", ".join(list(var_columns)[:10])
    tail = "..." if len(var_columns) > 10 else ""
    raise DataError(f"{subject} not found in adata.var. Available: {preview}{tail}")
295
+
296
+
297
def validate_adata_basics(
    adata: "ad.AnnData",
    min_obs: int = 1,
    min_vars: int = 1,
    check_empty_ratio: bool = False,
    max_empty_obs_ratio: float = 0.1,
    max_empty_vars_ratio: float = 0.5,
) -> None:
    """Validate basic AnnData structure.

    Args:
        adata: AnnData object to validate
        min_obs: Minimum number of observations (cells/spots) required
        min_vars: Minimum number of variables (genes) required
        check_empty_ratio: If True, also check for empty cells/genes
        max_empty_obs_ratio: Max fraction of cells with zero expression (default 10%)
        max_empty_vars_ratio: Max fraction of genes with zero expression (default 50%)

    Raises:
        DataError: If validation fails
    """
    if adata is None:
        raise DataError("AnnData object cannot be None")
    if adata.n_obs < min_obs:
        raise DataError(f"Dataset has {adata.n_obs} observations, need {min_obs}")
    if adata.n_vars < min_vars:
        raise DataError(f"Dataset has {adata.n_vars} variables, need {min_vars}")

    if not check_empty_ratio:
        return

    # Count non-zero entries per cell/gene. "Empty" means all zeros, so the
    # dense path must test != 0 (not > 0): scaled or centred data holds
    # negative values that still carry signal, and the sparse path already
    # counts every stored entry regardless of sign.
    X = adata.X
    if sparse.issparse(X):
        cell_nnz = np.asarray(X.getnnz(axis=1)).ravel()
        gene_nnz = np.asarray(X.getnnz(axis=0)).ravel()
    else:
        dense = np.asarray(X)
        cell_nnz = np.count_nonzero(dense, axis=1)
        gene_nnz = np.count_nonzero(dense, axis=0)

    empty_cells = int(np.sum(cell_nnz == 0))
    empty_genes = int(np.sum(gene_nnz == 0))

    if empty_cells > adata.n_obs * max_empty_obs_ratio:
        pct = empty_cells / adata.n_obs * 100
        raise DataError(
            f"{empty_cells} cells ({pct:.1f}%) have zero expression. "
            f"Check data quality and consider filtering."
        )

    if empty_genes > adata.n_vars * max_empty_vars_ratio:
        pct = empty_genes / adata.n_vars * 100
        raise DataError(
            f"{empty_genes} genes ({pct:.1f}%) have zero expression. "
            f"Consider gene filtering."
        )
350
+
351
+
352
def ensure_categorical(adata: "ad.AnnData", column: str) -> bool:
    """Ensure an obs column has categorical dtype, converting in place if needed.

    Args:
        adata: AnnData object (``adata.obs[column]`` may be replaced)
        column: Name of the obs column

    Returns:
        True if the column was converted, False if it was already
        categorical (the ensure_* convention used across this module).

    Raises:
        DataError: If column is not in adata.obs
    """
    if column not in adata.obs.columns:
        raise DataError(f"Column '{column}' not found in adata.obs")
    # pd.api.types.is_categorical_dtype is deprecated; pandas recommends the
    # isinstance check on the dtype instead.
    if isinstance(adata.obs[column].dtype, pd.CategoricalDtype):
        return False
    adata.obs[column] = adata.obs[column].astype("category")
    return True
358
+
359
+
360
+ # =============================================================================
361
+ # Standardization
362
+ # =============================================================================
363
def standardize_adata(adata: "ad.AnnData", copy: bool = True) -> "ad.AnnData":
    """
    Bring an AnnData object in line with ChatSpatial conventions.

    Steps performed:
    1. Spatial coordinates are made available under obsm['spatial']
    2. Gene names are made unique
    3. obs columns whose name is a known cell-type / cluster / batch alias
       are converted to category dtype

    Not performed here:
    - HVG computation (use preprocessing)
    - Spatial neighbor computation (done by analysis tools)

    Args:
        adata: AnnData object to standardize
        copy: If True (default) work on ``adata.copy()``; otherwise modify
            the given object in place

    Returns:
        The standardized object (the copy, or ``adata`` itself when copy=False)
    """
    target = adata.copy() if copy else adata

    _move_spatial_to_standard(target)
    ensure_unique_var_names(target)

    categorical_names = (
        ALTERNATIVE_CELL_TYPE_KEYS | ALTERNATIVE_CLUSTER_KEYS | ALTERNATIVE_BATCH_KEYS
    )
    matching_columns = [name for name in target.obs.columns if name in categorical_names]
    for name in matching_columns:
        ensure_categorical(target, name)

    return target
394
+
395
+
396
def _move_spatial_to_standard(adata: "ad.AnnData") -> None:
    """Make spatial coordinates available at obsm[SPATIAL_KEY] (in place).

    Sources are tried in this order:
    1. obsm[SPATIAL_KEY] already exists: nothing to do
    2. An alternative obsm key (tried in sorted order so the choice is
       reproducible when several exist): assigned under SPATIAL_KEY
    3. Numeric obs['x'] / obs['y'] columns without missing values: stacked
       into a float64 (n_obs, 2) array

    This is best-effort: if no usable source exists the object is left
    unchanged and no error is raised.
    """
    if SPATIAL_KEY in adata.obsm:
        return

    # sorted(): iterating the set directly would pick an arbitrary key that
    # can differ between interpreter runs (string hash randomization).
    for key in sorted(ALTERNATIVE_SPATIAL_KEYS):
        if key in adata.obsm:
            adata.obsm[SPATIAL_KEY] = adata.obsm[key]
            return

    if "x" in adata.obs and "y" in adata.obs:
        try:
            x = pd.to_numeric(adata.obs["x"], errors="coerce").values
            y = pd.to_numeric(adata.obs["y"], errors="coerce").values
            if not (np.any(np.isnan(x)) or np.any(np.isnan(y))):
                adata.obsm[SPATIAL_KEY] = np.column_stack([x, y]).astype("float64")
        except (TypeError, ValueError):
            # Columns that cannot be interpreted as numbers are not
            # coordinates; skip silently as this step is optional. Other
            # exception types indicate real bugs and are left to propagate.
            pass
416
+
417
+
418
+ # =============================================================================
419
+ # Advanced Validation: validate_adata with optional checks
420
+ # =============================================================================
421
def validate_adata(
    adata: "ad.AnnData",
    required_keys: dict,
    check_spatial: bool = False,
    check_velocity: bool = False,
    spatial_key: str = "spatial",
) -> None:
    """
    Check that an AnnData object carries the required keys, with optional
    integrity checks on spatial coordinates and velocity layers.

    Args:
        adata: AnnData object to validate
        required_keys: Maps an attribute name of adata (obs, var, obsm, ...)
            to a key or list of keys that must exist there
        check_spatial: Whether to validate spatial coordinates
        check_velocity: Whether to validate velocity data layers
        spatial_key: Key for spatial coordinates in adata.obsm

    Raises:
        DataError: If required keys are missing ("Missing required keys: ...")
            or an enabled integrity check fails ("Validation failed: ...")
    """
    problems: list[str] = []

    for category, keys in required_keys.items():
        wanted = [keys] if isinstance(keys, str) else keys

        container = getattr(adata, category, None)
        if container is None:
            problems.extend(f"{category}.{k}" for k in wanted)
            continue

        # DataFrames are matched on column names, mappings on their keys;
        # anything else cannot contain the key, so everything is missing.
        if hasattr(container, "columns"):
            known = container.columns
        elif hasattr(container, "keys"):
            known = container.keys()
        else:
            known = ()
        problems.extend(f"{category}.{k}" for k in wanted if k not in known)

    if problems:
        raise DataError(f"Missing required keys: {', '.join(problems)}")

    # The helpers append their findings to the (currently empty) list.
    if check_spatial:
        _validate_spatial_data(adata, spatial_key, problems)
    if check_velocity:
        _validate_velocity_data(adata, problems)

    if problems:
        raise DataError(f"Validation failed: {', '.join(problems)}")
474
+
475
+
476
+ def _validate_spatial_data(
477
+ adata: "ad.AnnData", spatial_key: str, issues: list[str]
478
+ ) -> None:
479
+ """Internal helper for spatial data validation."""
480
+ if spatial_key not in adata.obsm:
481
+ issues.append(f"Missing '{spatial_key}' coordinates in adata.obsm")
482
+ return
483
+
484
+ spatial_coords = adata.obsm[spatial_key]
485
+
486
+ if spatial_coords.shape[1] < 2:
487
+ issues.append(
488
+ f"Spatial coordinates should have at least 2 dimensions, "
489
+ f"found {spatial_coords.shape[1]}"
490
+ )
491
+
492
+ if np.any(np.isnan(spatial_coords)):
493
+ issues.append("Spatial coordinates contain NaN values")
494
+
495
+ if np.std(spatial_coords[:, 0]) == 0 and np.std(spatial_coords[:, 1]) == 0:
496
+ issues.append("All spatial coordinates are identical")
497
+
498
+
499
+ def _validate_velocity_data(adata: "ad.AnnData", issues: list[str]) -> None:
500
+ """Internal helper for velocity data validation."""
501
+ if "spliced" not in adata.layers:
502
+ issues.append("Missing 'spliced' layer required for RNA velocity")
503
+ if "unspliced" not in adata.layers:
504
+ issues.append("Missing 'unspliced' layer required for RNA velocity")
505
+
506
+ if "spliced" in adata.layers and "unspliced" in adata.layers:
507
+ for layer_name in ["spliced", "unspliced"]:
508
+ layer_data = adata.layers[layer_name]
509
+
510
+ if hasattr(layer_data, "nnz"): # Sparse matrix
511
+ if layer_data.nnz == 0:
512
+ issues.append(f"'{layer_name}' layer is empty (all zeros)")
513
+ else: # Dense matrix
514
+ if np.all(layer_data == 0):
515
+ issues.append(f"'{layer_name}' layer is empty (all zeros)")
516
+
517
+ if hasattr(layer_data, "data"): # Sparse matrix
518
+ if np.any(np.isnan(layer_data.data)):
519
+ issues.append(f"'{layer_name}' layer contains NaN values")
520
+ else: # Dense matrix
521
+ if np.any(np.isnan(layer_data)):
522
+ issues.append(f"'{layer_name}' layer contains NaN values")
523
+
524
+
525
+ # =============================================================================
526
+ # Metadata Storage: Scientific Provenance Tracking
527
+ # =============================================================================
528
def store_analysis_metadata(
    adata: "ad.AnnData",
    analysis_name: str,
    method: str,
    parameters: dict[str, Any],
    results_keys: dict[str, list[str]],
    statistics: Optional[dict[str, Any]] = None,
    species: Optional[str] = None,
    database: Optional[str] = None,
    reference_info: Optional[dict[str, Any]] = None,
) -> None:
    """Record provenance for an analysis under adata.uns.

    The record is written to ``adata.uns["<analysis_name>_metadata"]`` and
    replaces any existing entry with that key. It always holds "method",
    "parameters" and "results_keys"; each optional field is included only
    when a non-None value is given.

    Args:
        adata: AnnData object to store metadata in
        analysis_name: Name of the analysis (e.g., "annotation_tangram")
        method: Method name (e.g., "tangram", "liana", "cellrank")
        parameters: Dictionary of analysis parameters
        results_keys: Dictionary mapping storage location to list of keys
            Example: {"obs": ["cell_type_tangram"], "obsm": ["tangram_ct_pred"]}
        statistics: Optional dictionary of quality/summary statistics
        species: Optional species identifier
        database: Optional database/resource name
        reference_info: Optional reference dataset information
    """
    record: dict[str, Any] = {
        "method": method,
        "parameters": parameters,
        "results_keys": results_keys,
    }

    optional_fields = (
        ("statistics", statistics),
        ("species", species),
        ("database", database),
        ("reference_info", reference_info),
    )
    for field_name, value in optional_fields:
        if value is not None:
            record[field_name] = value

    adata.uns[f"{analysis_name}_metadata"] = record
584
+
585
+
586
def get_analysis_parameter(
    adata: "ad.AnnData",
    analysis_name: str,
    parameter_name: str,
    default: Any = None,
) -> Any:
    """Look up one parameter in the metadata written by store_analysis_metadata().

    Reads ``adata.uns["<analysis_name>_metadata"]["parameters"]``. Use this to
    reuse an analysis parameter (like cluster_key) instead of re-inferring it.

    Args:
        adata: AnnData object
        analysis_name: Name of the analysis (e.g., "spatial_stats_neighborhood")
        parameter_name: Name of the parameter (e.g., "cluster_key")
        default: Value returned when the metadata entry, its "parameters"
            section, or the parameter itself is absent

    Returns:
        Parameter value or default

    Example:
        # Get cluster_key used in neighborhood analysis
        cluster_key = get_analysis_parameter(
            adata, "spatial_stats_neighborhood", "cluster_key"
        )
    """
    entry_key = f"{analysis_name}_metadata"
    if entry_key in adata.uns:
        stored = adata.uns[entry_key]
        if "parameters" in stored:
            return stored["parameters"].get(parameter_name, default)
    return default
621
+
622
+
623
+ # =============================================================================
624
+ # Gene Selection Utilities
625
+ # =============================================================================
626
def get_highly_variable_genes(
    adata: "ad.AnnData",
    max_genes: int = 500,
    fallback_to_variance: bool = True,
) -> list[str]:
    """
    Get highly variable genes from AnnData.

    Priority order:
    1. Genes flagged in adata.var['highly_variable'] (in var order), if the
       column exists and flags at least one gene
    2. If fallback enabled, the ``max_genes`` genes with the largest
       variance of adata.X, ordered by ascending variance

    Args:
        adata: AnnData object
        max_genes: Maximum number of genes to return; values <= 0 yield []
        fallback_to_variance: If True, compute variance when HVG not available

    Returns:
        List of gene names (may be shorter than max_genes if fewer available)
    """
    # Guard first: with max_genes == 0 the slice [-0:] below would select
    # every gene instead of none.
    if max_genes <= 0:
        return []

    if "highly_variable" in adata.var.columns:
        # Coerce to a plain boolean mask so object/nullable columns and
        # missing flags (treated as "not HVG") cannot break the indexing.
        mask = adata.var["highly_variable"].fillna(False).astype(bool).to_numpy()
        # An all-False column carries no selection; treat it like a missing
        # column and let the variance fallback decide.
        if mask.any():
            return adata.var_names[mask].tolist()[:max_genes]

    if not fallback_to_variance:
        return []

    X = adata.X
    if sparse.issparse(X):
        # Var(X) = E[X^2] - E[X]^2, computed without densifying the matrix.
        mean = np.asarray(X.mean(axis=0)).ravel()
        mean_sq = np.asarray(X.power(2).mean(axis=0)).ravel()
        var_scores = mean_sq - mean**2
    else:
        var_scores = np.asarray(X.var(axis=0)).ravel()

    top_indices = np.argsort(var_scores)[-max_genes:]
    return adata.var_names[top_indices].tolist()
668
+
669
+
670
def select_genes_for_analysis(
    adata: "ad.AnnData",
    genes: Optional[list[str]] = None,
    n_genes: int = 20,
    require_hvg: bool = True,
    analysis_name: str = "analysis",
) -> list[str]:
    """
    Choose the genes a spatial/statistical analysis should run on.

    Shared gene selection logic for the analysis tools.

    Priority:
    1. User-specified genes (only those present in adata.var_names)
    2. Genes flagged in adata.var['highly_variable']

    Args:
        adata: AnnData object
        genes: User-specified gene list. If provided, filters to genes in adata.
        n_genes: Maximum number of genes to return when using HVG.
        require_hvg: If True (default), raise error when HVG not found.
            If False, return empty list when HVG not found.
        analysis_name: Name of analysis for error messages (e.g., "Moran's I").

    Returns:
        List of gene names to analyze.

    Raises:
        DataError: If genes specified but none found, or HVG required but missing.

    Examples:
        # Use user-specified genes
        genes = select_genes_for_analysis(adata, genes=["CD4", "CD8A"])

        # Use top 50 HVGs
        genes = select_genes_for_analysis(adata, n_genes=50)

        # For analysis that can work without HVG
        genes = select_genes_for_analysis(adata, require_hvg=False)
    """
    # Explicit gene list takes precedence over everything else.
    if genes is not None:
        present = [gene for gene in genes if gene in adata.var_names]
        if present:
            return present

        # Nothing matched: offer near-miss suggestions for the first few names.
        from difflib import get_close_matches

        hints = []
        for requested in genes[:3]:
            close = get_close_matches(
                requested, adata.var_names.tolist(), n=1, cutoff=0.6
            )
            if close:
                hints.append(f"'{requested}' → '{close[0]}'?")

        hint_text = f" Did you mean: {', '.join(hints)}" if hints else ""
        raise DataError(
            f"None of the specified genes found in data: {genes[:5]}...{hint_text}"
        )

    # Otherwise use the HVG flags, if any gene is flagged.
    var_table = adata.var
    if "highly_variable" in var_table.columns and var_table["highly_variable"].any():
        flagged = adata.var_names[var_table["highly_variable"]].tolist()
        return flagged[:n_genes]

    if not require_hvg:
        return []

    raise DataError(
        f"Highly variable genes (HVG) required for {analysis_name}.\n\n"
        "Solutions:\n"
        "1. Run preprocess_data() first to compute HVGs\n"
        "2. Specify genes explicitly via 'genes' parameter"
    )
750
+
751
+
752
+ # =============================================================================
753
+ # Gene Name Utilities
754
+ # =============================================================================
755
def ensure_unique_var_names(
    adata: "ad.AnnData",
    label: str = "data",
) -> int:
    """
    Make gene names unique if they are not already.

    Args:
        adata: AnnData object (modified in-place via var_names_make_unique())
        label: Unused here; accepted so the signature matches the async variant

    Returns:
        Number of surplus gene names before the fix, i.e. total names minus
        distinct names (0 if already unique)
    """
    names = adata.var_names
    if names.is_unique:
        return 0

    # Count before renaming; afterwards every name is distinct.
    surplus = len(names) - len(set(names))
    adata.var_names_make_unique()
    return surplus
775
+
776
+
777
async def ensure_unique_var_names_async(
    adata: "ad.AnnData",
    ctx: Any,  # ToolContext, use Any to avoid circular import
    label: str = "data",
) -> int:
    """
    Make gene names unique and tell the user through ctx when a fix was needed.

    Async counterpart of ensure_unique_var_names.

    Args:
        adata: AnnData object (modified in-place)
        ctx: ToolContext; its ``warning`` coroutine is awaited when
            duplicates were fixed
        label: Descriptive label for the data (e.g., "reference data", "query data")

    Returns:
        Number of duplicate gene names that were fixed (0 if already unique)
    """
    fixed_count = ensure_unique_var_names(adata, label)
    if fixed_count <= 0:
        return fixed_count

    await ctx.warning(f"Found {fixed_count} duplicate gene names in {label}, fixed")
    return fixed_count
799
+
800
+
801
+ # =============================================================================
802
+ # Raw Counts Data Access: Unified interface for accessing raw data
803
+ # =============================================================================
804
+
805
+
806
def check_is_integer_counts(X: Any, sample_size: int = 100) -> tuple[bool, bool, bool]:
    """Check if a matrix contains integer counts.

    This is a lightweight utility for checking data format without
    going through the full data source detection logic. Only the top-left
    ``sample_size`` x ``sample_size`` corner is inspected.

    Args:
        X: Data matrix (sparse or dense), 2-dimensional
        sample_size: Number of rows/cols to sample for efficiency

    Returns:
        Tuple of (is_integer, has_negatives, has_decimals), where is_integer
        is True exactly when the other two are False. An empty sample
        returns (True, False, False).
    """
    # COO/DIA/BSR cannot be sliced; CSR can, and the conversion stays sparse.
    if sparse.issparse(X) and X.format not in ("csr", "csc", "lil", "dok"):
        X = X.tocsr()

    n_rows = min(sample_size, X.shape[0])
    n_cols = min(sample_size, X.shape[1])
    sample = X[:n_rows, :n_cols]

    if sparse.issparse(sample):
        sample = sample.toarray()
    sample = np.asarray(sample)

    # min() of a zero-size array raises ValueError; nothing to contradict
    # "integer counts" in an empty sample.
    if sample.size == 0:
        return True, False, False

    has_negatives = float(sample.min()) < 0
    has_decimals = not np.allclose(sample, np.round(sample), atol=1e-6)
    is_integer = not has_negatives and not has_decimals

    return is_integer, has_negatives, has_decimals
831
+
832
+
833
def ensure_counts_layer(
    adata: "ad.AnnData",
    layer_name: str = "counts",
    error_message: Optional[str] = None,
) -> bool:
    """Guarantee that ``adata.layers[layer_name]`` exists, filling it from adata.raw.

    Intended as the one place where a counts layer is prepared for methods
    that read counts from a named layer.

    Args:
        adata: AnnData object (modified in-place)
        layer_name: Name of the layer to ensure (default: "counts")
        error_message: Custom error message if counts cannot be created

    Returns:
        True if the layer was created by this call, False if it already existed

    Raises:
        DataNotFoundError: If the layer is missing and adata.raw is None

    Note:
        adata.raw is indexed with the current adata.var_names, so a raw
        matrix that holds more genes than adata (e.g. after HVG subsetting)
        is reduced to the matching columns.

    Examples:
        # Ensure counts layer exists before scANVI setup
        ensure_counts_layer(adata_ref)
        scvi.model.SCANVI.setup_anndata(adata_ref, layer="counts", ...)

        # With custom error message
        ensure_counts_layer(adata, error_message="scANVI requires raw counts")
    """
    from .exceptions import DataNotFoundError

    if layer_name in adata.layers:
        return False

    if adata.raw is None:
        # No source to build the layer from.
        fallback_message = (
            f"Cannot create '{layer_name}' layer: adata.raw is None. "
            "Load unpreprocessed data or ensure adata.raw is preserved during preprocessing."
        )
        raise DataNotFoundError(error_message or fallback_message)

    # Select raw columns by name so the layer lines up with adata's genes.
    raw_subset = adata.raw[:, adata.var_names]
    adata.layers[layer_name] = raw_subset.X
    return True
884
+
885
+
886
class RawDataResult:
    """Result of raw data extraction.

    Plain value holder pairing a data matrix with the gene names it covers,
    the name of the source it was taken from, and the outcome of the
    integer-count check.
    """

    def __init__(
        self,
        X: Any,  # sparse or dense matrix
        var_names: pd.Index,
        source: str,
        is_integer_counts: bool,
        has_negatives: bool = False,
        has_decimals: bool = False,
    ):
        self.X = X
        self.var_names = var_names
        self.source = source
        self.is_integer_counts = is_integer_counts
        self.has_negatives = has_negatives
        self.has_decimals = has_decimals

    def __repr__(self) -> str:
        # Summarise the matrix by shape only: printing X itself would flood
        # logs and debugger output for real datasets.
        shape = getattr(self.X, "shape", None)
        return (
            f"{type(self).__name__}("
            f"source={self.source!r}, "
            f"shape={shape}, "
            f"n_vars={len(self.var_names)}, "
            f"is_integer_counts={self.is_integer_counts}, "
            f"has_negatives={self.has_negatives}, "
            f"has_decimals={self.has_decimals})"
        )
904
+
905
+
906
def get_raw_data_source(
    adata: "ad.AnnData",
    prefer_complete_genes: bool = True,
    require_integer_counts: bool = False,
    sample_size: int = 100,
) -> RawDataResult:
    """
    Pick the best available raw-count matrix from an AnnData object.

    Candidates are examined in a fixed order and the first acceptable one
    is returned.

    Order when prefer_complete_genes=True:
        1. adata.raw - Complete gene set, preserved before HVG filtering
        2. adata.layers["counts"] - Raw counts layer
        3. adata.X - Current expression matrix

    Order when prefer_complete_genes=False:
        1. adata.layers["counts"] - Raw counts layer
        2. adata.X - Current expression matrix
        (adata.raw is skipped as it may have different dimensions)

    A candidate is acceptable when it passes the integer-count check, or
    unconditionally when require_integer_counts is False.

    Args:
        adata: AnnData object
        prefer_complete_genes: If True, try adata.raw first for complete gene
            coverage. Set to False when you need data aligned with current
            adata dimensions.
        require_integer_counts: If True, only accept candidates that look
            like integer counts.
        sample_size: Number of cells/genes to sample for validation.

    Returns:
        RawDataResult with data matrix, var_names, source name
        ("raw", "counts_layer" or "current"), and validation info.

    Raises:
        DataError: If require_integer_counts=True and no candidate passes.

    Example:
        result = get_raw_data_source(adata, prefer_complete_genes=True)
        print(f"Using {result.source}: {len(result.var_names)} genes")
        if result.is_integer_counts:
            # Safe to use for deconvolution/velocity
            pass
    """

    def _evaluate(matrix: Any, names: Any, origin: str) -> tuple:
        """Check one candidate; return (result-or-None, has_neg, has_dec)."""
        is_int, neg, dec = check_is_integer_counts(matrix, sample_size)
        accepted = None
        if is_int or not require_integer_counts:
            accepted = RawDataResult(
                X=matrix,
                var_names=names,
                source=origin,
                is_integer_counts=is_int,
                has_negatives=neg,
                has_decimals=dec,
            )
        return accepted, neg, dec

    skipped = []

    # Candidate 1: adata.raw (complete gene set)
    if prefer_complete_genes and adata.raw is not None:
        try:
            full = adata.raw.to_adata()
            candidate, _, _ = _evaluate(full.X, full.var_names, "raw")
            if candidate is not None:
                return candidate
            skipped.append("raw (normalized, skipped)")
        except Exception:
            # Best effort: an unusable adata.raw must not block the fallbacks.
            skipped.append("raw (error, skipped)")

    # Candidate 2: layers["counts"]
    if "counts" in adata.layers:
        candidate, _, _ = _evaluate(
            adata.layers["counts"], adata.var_names, "counts_layer"
        )
        if candidate is not None:
            return candidate
        skipped.append("counts_layer (normalized, skipped)")

    # Candidate 3: current X
    candidate, has_neg, has_dec = _evaluate(adata.X, adata.var_names, "current")
    if candidate is not None:
        return candidate

    # Only reachable with require_integer_counts=True: otherwise adata.X
    # would have been accepted just above.
    raise DataError(
        f"No raw integer counts found. Sources tried: {skipped + ['current (normalized)']}. "
        f"Data appears to be normalized (has_negatives={has_neg}, has_decimals={has_dec}). "
        "Deconvolution and velocity methods require raw integer counts. "
        "Solutions: (1) Load unpreprocessed data, (2) Ensure adata.layers['counts'] "
        "contains raw counts, or (3) Re-run preprocessing with adata.raw preservation."
    )
1008
+
1009
+
1010
+ # =============================================================================
1011
+ # Expression Data Extraction: Unified sparse/dense handling
1012
+ # =============================================================================
1013
def to_dense(X: Any, copy: bool = False) -> np.ndarray:
    """
    Convert sparse matrix to dense numpy array.

    Handles both scipy sparse matrices and already-dense arrays uniformly.
    This is THE single function for sparse-to-dense conversion across ChatSpatial.

    Args:
        X: Expression matrix (sparse or dense)
        copy: If True, always return a copy (safe for modification).
            If False (default), may return view for dense input (read-only use).

    Returns:
        Dense numpy array

    Note:
        - Sparse input: Always returns a new array (toarray() creates copy)
        - Dense input with copy=False: Returns the input / a view when possible,
          and copies only when a conversion is unavoidable (e.g. list input)
        - Dense input with copy=True: Always returns independent copy

    Examples:
        # Read-only use (default, memory efficient)
        dense_X = to_dense(adata.X)

        # When you need to modify the result
        dense_X = to_dense(adata.X, copy=True)
        dense_X[0, 0] = 999  # Safe, won't affect original
    """
    if sparse.issparse(X):
        return X.toarray()
    if copy:
        return np.array(X, copy=True)
    # Do NOT use np.array(X, copy=False) here: since NumPy 2.0 that form
    # raises ValueError whenever a copy cannot be avoided. np.asarray keeps
    # the intended "copy only if needed" semantics on every NumPy version.
    return np.asarray(X)
1046
+
1047
+
1048
def get_gene_expression(
    adata: "ad.AnnData",
    gene: str,
    layer: Optional[str] = None,
) -> np.ndarray:
    """
    Return the expression of one gene across all observations as a flat array.

    Args:
        adata: AnnData object
        gene: Gene name (must exist in adata.var_names)
        layer: Optional layer name. If None, uses adata.X

    Returns:
        1D numpy array of expression values (length = n_obs)

    Raises:
        DataError: If the gene is not in adata.var_names, or if ``layer`` is
            given but not present in adata.layers

    Examples:
        # Basic usage
        cd4_expr = get_gene_expression(adata, "CD4")

        # From specific layer
        counts = get_gene_expression(adata, "CD4", layer="counts")

        # Use in visualization
        adata.obs["_temp_expr"] = get_gene_expression(adata, gene)
    """
    if gene not in adata.var_names:
        raise DataError(
            f"Gene '{gene}' not found in data. "
            f"Available genes (first 5): {adata.var_names[:5].tolist()}"
        )

    if layer is None:
        # Default path: slice the AnnData by gene name and take its X.
        values = adata[:, gene].X
        return to_dense(values).flatten()

    if layer not in adata.layers:
        raise DataError(
            f"Layer '{layer}' not found. Available: {list(adata.layers.keys())}"
        )

    # Layers are plain matrices, so the gene name is first mapped to a position.
    column = adata.var_names.get_loc(gene)
    values = adata.layers[layer][:, column]
    return to_dense(values).flatten()
1097
+
1098
+
1099
def get_genes_expression(
    adata: "ad.AnnData",
    genes: list[str],
    layer: Optional[str] = None,
) -> np.ndarray:
    """
    Return the expression of several genes as a dense 2D array.

    Args:
        adata: AnnData object
        genes: List of gene names (must exist in adata.var_names)
        layer: Optional layer name. If None, uses adata.X

    Returns:
        2D numpy array of shape (n_obs, n_genes), columns in the order of ``genes``

    Raises:
        DataError: If any gene is not in adata.var_names, or if ``layer`` is
            given but not present in adata.layers

    Examples:
        # Get expression matrix for heatmap
        expr_matrix = get_genes_expression(adata, ["CD4", "CD8A", "CD3D"])

        # From counts layer
        counts = get_genes_expression(adata, marker_genes, layer="counts")
    """
    # Reject unknown genes up front, listing at most five of them.
    absent = []
    for name in genes:
        if name not in adata.var_names:
            absent.append(name)
    if absent:
        suffix = "..." if len(absent) > 5 else ""
        raise DataError(
            f"Genes not found: {absent[:5]}{suffix}. "
            f"Available genes (first 5): {adata.var_names[:5].tolist()}"
        )

    if layer is None:
        return to_dense(adata[:, genes].X)

    if layer not in adata.layers:
        raise DataError(
            f"Layer '{layer}' not found. Available: {list(adata.layers.keys())}"
        )

    # Layers are plain matrices, so gene names are mapped to positions first.
    columns = [adata.var_names.get_loc(name) for name in genes]
    return to_dense(adata.layers[layer][:, columns])
1144
+
1145
+
1146
+ # =============================================================================
1147
+ # Metadata Profiling: Extract structure information for LLM understanding
1148
+ # =============================================================================
1149
def get_column_profile(
    adata: "ad.AnnData", layer: Literal["obs", "var"] = "obs"
) -> list[dict[str, Any]]:
    """
    Summarise every metadata column of adata.obs or adata.var.

    Args:
        adata: AnnData object
        layer: Which table to profile; "obs" selects adata.obs, any other
            value selects adata.var

    Returns:
        One dictionary per column, in column order, with keys:
        - name: Column name
        - dtype: "numerical" or "categorical"
        - n_unique: Number of unique values
        - range: (min, max) for numerical columns, None for categorical
        - sample_values: Stringified sample values for categorical columns
          (first 5 unique values, or first 3 when there are more than 100
          unique values), None for numerical
    """
    frame = adata.obs if layer == "obs" else adata.var
    summaries = []

    for name in frame.columns:
        series = frame[name]

        if pd.api.types.is_numeric_dtype(series):
            entry = {
                "name": name,
                "dtype": "numerical",
                "n_unique": int(series.nunique()),
                "range": (float(series.min()), float(series.max())),
                "sample_values": None,
            }
        else:
            # Everything non-numeric is reported as categorical.
            distinct = series.unique()
            n_distinct = len(distinct)
            # Show fewer examples for very high-cardinality columns.
            n_examples = 5 if n_distinct <= 100 else 3
            examples = distinct[:n_examples].tolist()
            entry = {
                "name": name,
                "dtype": "categorical",
                "n_unique": n_distinct,
                "sample_values": [str(value) for value in examples],
                "range": None,
            }

        summaries.append(entry)

    return summaries
1211
+
1212
+
1213
def get_gene_profile(
    adata: "ad.AnnData",
) -> tuple[Optional[list[str]], list[str]]:
    """
    Collect a short gene-level summary: precomputed HVGs and top expressed genes.

    Args:
        adata: AnnData object

    Returns:
        Tuple of (top_highly_variable_genes, top_expressed_genes)
        - top_highly_variable_genes: Up to 10 HVG names, or None when the
          lookup returns nothing
        - top_expressed_genes: Up to 10 gene names ranked by mean of adata.X
          (highest first); falls back to the first 10 var_names if the
          ranking fails
    """
    # HVGs are requested without the variance fallback, so an empty result
    # means "not available" and is reported as None.
    hvgs = get_highly_variable_genes(
        adata, max_genes=10, fallback_to_variance=False
    )

    try:
        gene_means = np.array(adata.X.mean(axis=0)).flatten()
        ranked = np.argsort(gene_means)[-10:][::-1]  # highest mean first
        top_expressed = adata.var_names[ranked].tolist()
    except Exception:
        # Best effort: still return something usable if X cannot be averaged.
        top_expressed = adata.var_names[:10].tolist()

    return (hvgs if hvgs else None), top_expressed
1242
+
1243
+
1244
def get_adata_profile(adata: "ad.AnnData") -> dict[str, Any]:
    """
    Assemble a dataset overview from the column and gene profiling helpers.

    Args:
        adata: AnnData object

    Returns:
        Dictionary containing:
        - obs_columns: Profile of observation metadata columns
        - var_columns: Profile of variable metadata columns
        - obsm_keys: List of keys in obsm (empty if adata has no obsm attribute)
        - uns_keys: List of keys in uns (empty if adata has no uns attribute)
        - top_highly_variable_genes: Top HVGs if computed
        - top_expressed_genes: Top expressed genes
    """
    profile: dict[str, Any] = {
        "obs_columns": get_column_profile(adata, layer="obs"),
        "var_columns": get_column_profile(adata, layer="var"),
    }

    hvg_names, expressed_names = get_gene_profile(adata)

    # Tolerate objects that lack these containers altogether.
    profile["obsm_keys"] = list(adata.obsm.keys()) if hasattr(adata, "obsm") else []
    profile["uns_keys"] = list(adata.uns.keys()) if hasattr(adata, "uns") else []

    profile["top_highly_variable_genes"] = hvg_names
    profile["top_expressed_genes"] = expressed_names
    return profile
1282
+
1283
+
1284
+ # =============================================================================
1285
+ # Gene Overlap: Find and validate common genes between datasets
1286
+ # =============================================================================
1287
def find_common_genes(*gene_collections: Any) -> list[str]:
    """
    Find common genes across multiple gene collections.

    This is THE single function for computing gene intersections across ChatSpatial.
    Supports any number of gene collections (2 or more).

    Args:
        *gene_collections: Two or more gene collections. Each can be:
            - List[str]: Gene name list
            - pd.Index: AnnData var_names
            - Any Iterable[str]: Consumed exactly once

    Returns:
        List of unique gene names present in every collection, in the order
        they first appear in the first collection. The order is therefore
        reproducible from run to run, unlike a bare set-to-list conversion
        whose order depends on string hash randomisation.

    Raises:
        ValueError: If fewer than 2 collections provided

    Examples:
        # Between two AnnData objects
        common = find_common_genes(adata1.var_names, adata2.var_names)

        # Multiple datasets (e.g., spatial registration)
        common = find_common_genes(
            adata1.var_names, adata2.var_names, adata3.var_names
        )

        # With explicit lists
        common = find_common_genes(["GeneA", "GeneB"], ["GeneB", "GeneC"])
    """
    if len(gene_collections) < 2:
        raise ValueError("find_common_genes requires at least 2 gene collections")

    first, *rest = gene_collections

    # Genes shared by every collection after the first.
    shared_by_rest = set.intersection(*(set(genes) for genes in rest))

    # dict.fromkeys de-duplicates while keeping first-seen order.
    return [gene for gene in dict.fromkeys(first) if gene in shared_by_rest]
1329
+
1330
+
1331
def validate_gene_overlap(
    common_genes: list[str],
    source_n_genes: int,
    target_n_genes: int,
    min_genes: int = 100,
    source_name: str = "source",
    target_name: str = "target",
) -> None:
    """
    Raise if two datasets share fewer genes than required.

    Args:
        common_genes: List of common gene names
        source_n_genes: Number of genes in source data (used in the message only)
        target_n_genes: Number of genes in target data (used in the message only)
        min_genes: Minimum required common genes (default: 100)
        source_name: Name of source data for error messages
        target_name: Name of target data for error messages

    Raises:
        DataError: If ``len(common_genes)`` is below ``min_genes``

    Examples:
        # Basic validation
        common = find_common_genes(spatial.var_names, reference.var_names)
        validate_gene_overlap(common, spatial.n_vars, reference.n_vars)

        # With custom threshold and names
        validate_gene_overlap(
            common, spatial.n_vars, reference.n_vars,
            min_genes=50, source_name="spatial", target_name="reference"
        )
    """
    n_common = len(common_genes)
    if n_common >= min_genes:
        return

    raise DataError(
        f"Insufficient gene overlap: {n_common} < {min_genes} required. "
        f"{source_name}: {source_n_genes} genes, {target_name}: {target_n_genes} genes. "
        f"Check species/gene naming convention match."
    )