chatspatial 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatspatial/__init__.py +11 -0
- chatspatial/__main__.py +141 -0
- chatspatial/cli/__init__.py +7 -0
- chatspatial/config.py +53 -0
- chatspatial/models/__init__.py +85 -0
- chatspatial/models/analysis.py +513 -0
- chatspatial/models/data.py +2462 -0
- chatspatial/server.py +1763 -0
- chatspatial/spatial_mcp_adapter.py +720 -0
- chatspatial/tools/__init__.py +3 -0
- chatspatial/tools/annotation.py +1903 -0
- chatspatial/tools/cell_communication.py +1603 -0
- chatspatial/tools/cnv_analysis.py +605 -0
- chatspatial/tools/condition_comparison.py +595 -0
- chatspatial/tools/deconvolution/__init__.py +402 -0
- chatspatial/tools/deconvolution/base.py +318 -0
- chatspatial/tools/deconvolution/card.py +244 -0
- chatspatial/tools/deconvolution/cell2location.py +326 -0
- chatspatial/tools/deconvolution/destvi.py +144 -0
- chatspatial/tools/deconvolution/flashdeconv.py +101 -0
- chatspatial/tools/deconvolution/rctd.py +317 -0
- chatspatial/tools/deconvolution/spotlight.py +216 -0
- chatspatial/tools/deconvolution/stereoscope.py +109 -0
- chatspatial/tools/deconvolution/tangram.py +135 -0
- chatspatial/tools/differential.py +625 -0
- chatspatial/tools/embeddings.py +298 -0
- chatspatial/tools/enrichment.py +1863 -0
- chatspatial/tools/integration.py +807 -0
- chatspatial/tools/preprocessing.py +723 -0
- chatspatial/tools/spatial_domains.py +808 -0
- chatspatial/tools/spatial_genes.py +836 -0
- chatspatial/tools/spatial_registration.py +441 -0
- chatspatial/tools/spatial_statistics.py +1476 -0
- chatspatial/tools/trajectory.py +495 -0
- chatspatial/tools/velocity.py +405 -0
- chatspatial/tools/visualization/__init__.py +155 -0
- chatspatial/tools/visualization/basic.py +393 -0
- chatspatial/tools/visualization/cell_comm.py +699 -0
- chatspatial/tools/visualization/cnv.py +320 -0
- chatspatial/tools/visualization/core.py +684 -0
- chatspatial/tools/visualization/deconvolution.py +852 -0
- chatspatial/tools/visualization/enrichment.py +660 -0
- chatspatial/tools/visualization/integration.py +205 -0
- chatspatial/tools/visualization/main.py +164 -0
- chatspatial/tools/visualization/multi_gene.py +739 -0
- chatspatial/tools/visualization/persistence.py +335 -0
- chatspatial/tools/visualization/spatial_stats.py +469 -0
- chatspatial/tools/visualization/trajectory.py +639 -0
- chatspatial/tools/visualization/velocity.py +411 -0
- chatspatial/utils/__init__.py +115 -0
- chatspatial/utils/adata_utils.py +1372 -0
- chatspatial/utils/compute.py +327 -0
- chatspatial/utils/data_loader.py +499 -0
- chatspatial/utils/dependency_manager.py +462 -0
- chatspatial/utils/device_utils.py +165 -0
- chatspatial/utils/exceptions.py +185 -0
- chatspatial/utils/image_utils.py +267 -0
- chatspatial/utils/mcp_utils.py +137 -0
- chatspatial/utils/path_utils.py +243 -0
- chatspatial/utils/persistence.py +78 -0
- chatspatial/utils/scipy_compat.py +143 -0
- chatspatial-1.1.0.dist-info/METADATA +242 -0
- chatspatial-1.1.0.dist-info/RECORD +67 -0
- chatspatial-1.1.0.dist-info/WHEEL +5 -0
- chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
- chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
- chatspatial-1.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1372 @@
|
|
|
1
|
+
"""
|
|
2
|
+
AnnData utilities for ChatSpatial.
|
|
3
|
+
|
|
4
|
+
This module provides:
|
|
5
|
+
1. Standard field name constants
|
|
6
|
+
2. Field discovery functions (get_*_key)
|
|
7
|
+
3. Data access functions (get_*)
|
|
8
|
+
4. Validation functions (validate_*)
|
|
9
|
+
5. Ensure functions (ensure_*)
|
|
10
|
+
|
|
11
|
+
One file for all AnnData-related utilities. No duplication.
|
|
12
|
+
|
|
13
|
+
Naming Conventions (MUST follow across codebase):
|
|
14
|
+
-------------------------------------------------
|
|
15
|
+
- validate_*(adata, ...) -> None
|
|
16
|
+
Check-only. Raises exception if validation fails.
|
|
17
|
+
Does NOT modify data. Use for precondition checks.
|
|
18
|
+
Example: validate_obs_column(adata, "leiden")
|
|
19
|
+
|
|
20
|
+
- ensure_*(adata, ...) -> bool
|
|
21
|
+
Check-and-fix. Returns True if action was taken, False if already OK.
|
|
22
|
+
MAY modify data in-place. Idempotent (safe to call multiple times).
|
|
23
|
+
Example: ensure_categorical(adata, "leiden")
|
|
24
|
+
|
|
25
|
+
- require(name, ctx, feature) -> module
|
|
26
|
+
Dependency check. Raises ImportError with install instructions if missing.
|
|
27
|
+
Used in dependency_manager.py only.
|
|
28
|
+
|
|
29
|
+
Async variants: Add '_async' suffix (e.g., ensure_unique_var_names_async).
|
|
30
|
+
"""
|
|
31
|
+
|
|
32
|
+
from typing import TYPE_CHECKING, Any, Literal, Optional
|
|
33
|
+
|
|
34
|
+
import numpy as np
|
|
35
|
+
import pandas as pd
|
|
36
|
+
|
|
37
|
+
if TYPE_CHECKING:
|
|
38
|
+
import anndata as ad
|
|
39
|
+
|
|
40
|
+
from scipy import sparse
|
|
41
|
+
|
|
42
|
+
from .exceptions import DataError
|
|
43
|
+
|
|
44
|
+
# =============================================================================
|
|
45
|
+
# Constants: Standard Field Names
|
|
46
|
+
# =============================================================================
|
|
47
|
+
# Canonical field names used throughout ChatSpatial.
SPATIAL_KEY = "spatial"  # adata.obsm key holding spatial coordinates
CELL_TYPE_KEY = "cell_type"  # adata.obs column holding cell-type labels
CLUSTER_KEY = "leiden"  # adata.obs column holding cluster labels
BATCH_KEY = "batch"  # adata.obs column holding batch/sample labels

# Alternative names accepted for compatibility with data produced by other
# tools. Each set also contains its canonical key. NOTE: these are unordered
# sets; code that needs a stable priority must impose its own ordering.
ALTERNATIVE_SPATIAL_KEYS: set[str] = {
    SPATIAL_KEY,
    "X_spatial",
    "coordinates",
    "coords",
    "spatial_coords",
    "positions",
}
ALTERNATIVE_CELL_TYPE_KEYS: set[str] = {
    CELL_TYPE_KEY,
    "celltype",
    "cell_types",
    "annotation",
    "cell_annotation",
    "predicted_celltype",
}
ALTERNATIVE_CLUSTER_KEYS: set[str] = {
    CLUSTER_KEY,
    "louvain",
    "clusters",
    "cluster",
    "clustering",
    "cluster_labels",
    "spatial_domains",
}
ALTERNATIVE_BATCH_KEYS: set[str] = {
    BATCH_KEY,
    "sample",
    "dataset",
    "experiment",
    "replicate",
    "batch_id",
    "sample_id",
}
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
# =============================================================================
|
|
90
|
+
# Field Discovery: Find keys in AnnData
|
|
91
|
+
# =============================================================================
|
|
92
|
+
def get_spatial_key(adata: "ad.AnnData") -> Optional[str]:
    """Find the spatial coordinate key in adata.obsm.

    The canonical ``SPATIAL_KEY`` is preferred. The remaining candidates from
    ``ALTERNATIVE_SPATIAL_KEYS`` are then checked in sorted order. Iterating
    the set directly would depend on string-hash randomization, so with
    several candidates present a different key could be chosen in each run.

    Args:
        adata: AnnData object to inspect.

    Returns:
        The first matching obsm key, or None if no candidate is present.
    """
    candidates = [SPATIAL_KEY, *sorted(ALTERNATIVE_SPATIAL_KEYS - {SPATIAL_KEY})]
    for key in candidates:
        if key in adata.obsm:
            return key
    return None
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def get_cell_type_key(adata: "ad.AnnData") -> Optional[str]:
    """Find the cell type column in adata.obs.

    The canonical ``CELL_TYPE_KEY`` is preferred. The remaining candidates
    from ``ALTERNATIVE_CELL_TYPE_KEYS`` are then checked in sorted order, so
    the result does not depend on set iteration order (string-hash
    randomization) when several candidate columns exist.

    Args:
        adata: AnnData object to inspect.

    Returns:
        The first matching obs column name, or None if none is present.
    """
    candidates = [
        CELL_TYPE_KEY,
        *sorted(ALTERNATIVE_CELL_TYPE_KEYS - {CELL_TYPE_KEY}),
    ]
    for key in candidates:
        if key in adata.obs:
            return key
    return None
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def get_cluster_key(adata: "ad.AnnData") -> Optional[str]:
    """Find the cluster column in adata.obs.

    The canonical ``CLUSTER_KEY`` is preferred. The remaining candidates from
    ``ALTERNATIVE_CLUSTER_KEYS`` are then checked in sorted order. This keeps
    the choice deterministic (e.g. 'leiden' vs 'louvain') instead of
    depending on set iteration order, which varies with string-hash
    randomization.

    Args:
        adata: AnnData object to inspect.

    Returns:
        The first matching obs column name, or None if none is present.
    """
    candidates = [CLUSTER_KEY, *sorted(ALTERNATIVE_CLUSTER_KEYS - {CLUSTER_KEY})]
    for key in candidates:
        if key in adata.obs:
            return key
    return None
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def get_batch_key(adata: "ad.AnnData") -> Optional[str]:
    """Find the batch/sample column in adata.obs.

    The canonical ``BATCH_KEY`` is preferred. The remaining candidates from
    ``ALTERNATIVE_BATCH_KEYS`` are then checked in sorted order, so the
    result does not depend on set iteration order (string-hash
    randomization) when several candidate columns exist.

    Args:
        adata: AnnData object to inspect.

    Returns:
        The first matching obs column name, or None if none is present.
    """
    candidates = [BATCH_KEY, *sorted(ALTERNATIVE_BATCH_KEYS - {BATCH_KEY})]
    for key in candidates:
        if key in adata.obs:
            return key
    return None
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
# =============================================================================
|
|
125
|
+
# Data Access: Get data from AnnData
|
|
126
|
+
# =============================================================================
|
|
127
|
+
def sample_expression_values(
    adata: "ad.AnnData",
    n_samples: int = 1000,
    layer: Optional[str] = None,
) -> np.ndarray:
    """
    Sample expression values from data matrix for validation checks.

    Returns at most the first ``n_samples`` values of the matrix, not a
    random sample. The result is deterministic and cheap to obtain without
    materializing a copy of the full matrix. Used for data type detection
    (integer vs float, negative values, etc.).

    - Sparse input: values come from the stored entries. The matrix is never
      densified; formats without a flat ``.data`` array are converted to
      CSR first. A matrix with no stored entries yields zeros.
    - Dense input (ndarray, np.matrix, array-like): values come from a
      row-major ravel. The ravel is a view, not a copy, for contiguous
      arrays.

    Args:
        adata: AnnData object
        n_samples: Maximum number of values to return (default: 1000).
            Values <= 0 yield an empty array.
        layer: Optional layer name. If None, uses adata.X

    Returns:
        1D numpy array with at most ``n_samples`` expression values

    Examples:
        # Check for negative values (indicates log-normalized data)
        sample = sample_expression_values(adata)
        if np.any(sample < 0):
            raise ValueError("Log normalization requires non-negative data")

        # Check for non-integer values (indicates normalized data)
        sample = sample_expression_values(adata)
        if np.any((sample % 1) != 0):
            raise ValueError("Method requires raw count data (integers)")
    """
    X = adata.layers[layer] if layer is not None else adata.X
    # Guard against negative counts, which would otherwise slice "all but last k".
    n_samples = max(n_samples, 0)

    if sparse.issparse(X):
        # Only CSR/CSC/COO expose a flat 1D array of stored values in `.data`
        # (LIL stores per-row lists, BSR stores blocks, DIA stores padded diagonals).
        if X.format not in ("csr", "csc", "coo"):
            X = X.tocsr()
        if X.data.size > 0:
            return X.data[:n_samples]
        # No stored entries: every value in the matrix is zero.
        total = int(np.prod(X.shape))
        return np.zeros(min(n_samples, total), dtype=X.dtype)

    # np.asarray normalizes np.matrix (whose flatten() stays 2-D) and array-likes;
    # ravel avoids the unconditional full copy that flatten() makes.
    return np.asarray(X).ravel()[:n_samples]
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def require_spatial_coords(
    adata: "ad.AnnData",
    spatial_key: Optional[str] = None,
    validate: bool = True,
) -> np.ndarray:
    """
    Get validated spatial coordinates array from AnnData.

    This is the primary function for accessing spatial coordinates.
    Returns the full coordinates array with optional validation.

    Lookup when ``spatial_key`` is None:
        1. The obsm key returned by ``get_spatial_key``
        2. Otherwise, numeric ``obs['x']`` / ``obs['y']`` columns stacked into
           a float64 (n_obs, 2) array

    Args:
        adata: AnnData object
        spatial_key: Optional key in obsm. If None, auto-detects (see above).
        validate: If True (default), coordinates from either source are
            checked for:
            - At least 2 dimensions
            - No NaN values
            - No infinite values
            - Not all identical

    Returns:
        Spatial coordinates as 2D numpy array (n_cells, n_dims). Array-like
        obsm entries (e.g. DataFrames) are converted with ``np.asarray``.
        ndarray entries are returned as-is, without copying.

    Raises:
        DataError: If spatial coordinates not found or validation fails

    Examples:
        # Auto-detect spatial key
        coords = require_spatial_coords(adata)

        # Use specific key without validation
        coords = require_spatial_coords(adata, spatial_key="X_spatial", validate=False)
    """

    def _check_values(arr: np.ndarray, nan_message: str) -> None:
        # Shared NaN / inf / zero-spread checks so both sources get the same validation.
        if np.any(np.isnan(arr)):
            raise DataError(nan_message)
        if np.any(np.isinf(arr)):
            raise DataError("Spatial coordinates contain infinite values")
        if np.std(arr[:, 0]) == 0 and np.std(arr[:, 1]) == 0:
            raise DataError("All spatial coordinates are identical")

    if spatial_key is None:
        spatial_key = get_spatial_key(adata)

    if spatial_key is None:
        # No obsm candidate found: fall back to obs x/y columns.
        if "x" in adata.obs and "y" in adata.obs:
            # to_numpy(na_value=nan) also handles nullable (e.g. Int64) columns,
            # whose `.values` is not a plain float array and breaks np.isnan.
            x = pd.to_numeric(adata.obs["x"], errors="coerce").to_numpy(
                dtype="float64", na_value=np.nan
            )
            y = pd.to_numeric(adata.obs["y"], errors="coerce").to_numpy(
                dtype="float64", na_value=np.nan
            )
            coords = np.column_stack([x, y])
            if validate:
                _check_values(
                    coords, "Spatial coordinates in obs['x'/'y'] contain NaN"
                )
            return coords

        raise DataError(
            "No spatial coordinates found. Expected in adata.obsm['spatial'] "
            "or similar key. Available obsm keys: "
            f"{list(adata.obsm.keys()) if adata.obsm else 'none'}"
        )

    if spatial_key not in adata.obsm:
        raise DataError(
            f"Spatial coordinates '{spatial_key}' not found in adata.obsm. "
            f"Available keys: {list(adata.obsm.keys())}"
        )

    # obsm entries may be DataFrames; np.asarray makes positional indexing work
    # and is a no-op (same object) for ndarray entries.
    coords = np.asarray(adata.obsm[spatial_key])

    if validate:
        n_dims = coords.shape[1] if coords.ndim >= 2 else 1
        if n_dims < 2:
            raise DataError(
                "Spatial coordinates should have at least 2 dimensions, "
                f"found {n_dims}"
            )
        _check_values(coords, "Spatial coordinates contain NaN values")

    return coords
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
# =============================================================================
|
|
255
|
+
# Validation: Check and validate AnnData
|
|
256
|
+
# =============================================================================
|
|
257
|
+
def validate_obs_column(
    adata: "ad.AnnData",
    column: str,
    friendly_name: Optional[str] = None,
) -> None:
    """
    Check that ``column`` is present in adata.obs (check-only, no mutation).

    Args:
        adata: AnnData object to check.
        column: Name of the obs column that must exist.
        friendly_name: Optional human-readable label used in the error
            message instead of ``Column '<column>'``.

    Raises:
        DataError: If the column is missing. The message lists up to the
            first 10 existing obs columns.
    """
    existing = adata.obs.columns
    if column in existing:
        return

    label = friendly_name if friendly_name else f"Column '{column}'"
    shown = list(existing)[:10]
    ellipsis = "..." if len(existing) > 10 else ""
    raise DataError(
        f"{label} not found in adata.obs. Available: {', '.join(shown)}{ellipsis}"
    )
|
|
275
|
+
|
|
276
|
+
|
|
277
|
+
def validate_var_column(
    adata: "ad.AnnData",
    column: str,
    friendly_name: Optional[str] = None,
) -> None:
    """
    Check that ``column`` is present in adata.var (check-only, no mutation).

    Args:
        adata: AnnData object to check.
        column: Name of the var column that must exist.
        friendly_name: Optional human-readable label used in the error
            message instead of ``Column '<column>'``.

    Raises:
        DataError: If the column is missing. The message lists up to the
            first 10 existing var columns.
    """
    existing = adata.var.columns
    if column in existing:
        return

    label = friendly_name if friendly_name else f"Column '{column}'"
    shown = list(existing)[:10]
    ellipsis = "..." if len(existing) > 10 else ""
    raise DataError(
        f"{label} not found in adata.var. Available: {', '.join(shown)}{ellipsis}"
    )
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def validate_adata_basics(
    adata: "ad.AnnData",
    min_obs: int = 1,
    min_vars: int = 1,
    check_empty_ratio: bool = False,
    max_empty_obs_ratio: float = 0.1,
    max_empty_vars_ratio: float = 0.5,
) -> None:
    """Validate basic AnnData structure.

    Args:
        adata: AnnData object to validate
        min_obs: Minimum number of observations (cells/spots) required
        min_vars: Minimum number of variables (genes) required
        check_empty_ratio: If True, also check for empty cells/genes. A
            cell/gene counts as empty when all of its values are exactly
            zero. Negative values (e.g. scaled data) count as expression, and
            explicitly stored sparse zeros do not.
        max_empty_obs_ratio: Max fraction of cells with zero expression (default 10%)
        max_empty_vars_ratio: Max fraction of genes with zero expression (default 50%)

    Raises:
        DataError: If validation fails
    """
    if adata is None:
        raise DataError("AnnData object cannot be None")
    if adata.n_obs < min_obs:
        raise DataError(f"Dataset has {adata.n_obs} observations, need {min_obs}")
    if adata.n_vars < min_vars:
        raise DataError(f"Dataset has {adata.n_vars} variables, need {min_vars}")

    if check_empty_ratio:
        X = adata.X
        # Count entries != 0 along each axis. Both branches use the same
        # definition: a `> 0` test would treat negative (scaled) values as
        # empty, and sparse getnnz() would count explicitly stored zeros.
        if sparse.issparse(X):
            nonzero = X != 0  # stays sparse; comparison with 0 is efficient
            cell_nnz = np.asarray(nonzero.sum(axis=1)).ravel()
            gene_nnz = np.asarray(nonzero.sum(axis=0)).ravel()
        else:
            nonzero = np.asarray(X) != 0
            cell_nnz = nonzero.sum(axis=1)
            gene_nnz = nonzero.sum(axis=0)

        empty_cells = int(np.sum(cell_nnz == 0))
        empty_genes = int(np.sum(gene_nnz == 0))

        if empty_cells > adata.n_obs * max_empty_obs_ratio:
            pct = empty_cells / adata.n_obs * 100
            raise DataError(
                f"{empty_cells} cells ({pct:.1f}%) have zero expression. "
                "Check data quality and consider filtering."
            )

        if empty_genes > adata.n_vars * max_empty_vars_ratio:
            pct = empty_genes / adata.n_vars * 100
            raise DataError(
                f"{empty_genes} genes ({pct:.1f}%) have zero expression. "
                "Consider gene filtering."
            )
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
def ensure_categorical(adata: "ad.AnnData", column: str) -> bool:
    """Ensure adata.obs[column] has categorical dtype, converting in place if needed.

    Follows the module's ``ensure_*`` convention: it is idempotent and
    reports whether it changed anything.

    Args:
        adata: AnnData object (modified in place when a conversion happens).
        column: Name of the obs column to convert.

    Returns:
        True if the column was converted, False if it was already categorical.

    Raises:
        DataError: If the column does not exist in adata.obs.
    """
    if column not in adata.obs.columns:
        raise DataError(f"Column '{column}' not found in adata.obs")
    # pd.api.types.is_categorical_dtype is deprecated since pandas 2.1;
    # an isinstance check on the dtype is the recommended replacement.
    if isinstance(adata.obs[column].dtype, pd.CategoricalDtype):
        return False
    adata.obs[column] = adata.obs[column].astype("category")
    return True
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
# =============================================================================
|
|
361
|
+
# Standardization
|
|
362
|
+
# =============================================================================
|
|
363
|
+
def standardize_adata(adata: "ad.AnnData", copy: bool = True) -> "ad.AnnData":
    """
    Bring an AnnData object in line with ChatSpatial conventions.

    Steps, in order:
        1. Make spatial coordinates available under obsm['spatial'] (taken
           from an alternative obsm key or from obs 'x'/'y' columns) when
           they are not already there
        2. Make gene names unique
        3. Cast obs columns whose name is a known cell-type, cluster or batch
           key to category dtype

    Out of scope:
        - HVG selection (use preprocessing)
        - Spatial neighbor graphs (computed by analysis tools)

    Args:
        adata: Input AnnData object.
        copy: If True (default), a copy is standardized and returned, leaving
            ``adata`` untouched. Otherwise ``adata`` is modified in place and
            returned.

    Returns:
        The standardized AnnData object.
    """
    target = adata.copy() if copy else adata

    _move_spatial_to_standard(target)
    ensure_unique_var_names(target)

    known_label_keys = (
        ALTERNATIVE_CELL_TYPE_KEYS | ALTERNATIVE_CLUSTER_KEYS | ALTERNATIVE_BATCH_KEYS
    )
    label_columns = [name for name in target.obs.columns if name in known_label_keys]
    for name in label_columns:
        ensure_categorical(target, name)

    return target
|
|
394
|
+
|
|
395
|
+
|
|
396
|
+
def _move_spatial_to_standard(adata: "ad.AnnData") -> None:
    """Make spatial coordinates available under the standard obsm['spatial'] key.

    Does nothing if obsm['spatial'] already exists. Otherwise:
        1. The first alternative obsm key found is aliased to obsm['spatial']
           and the original key is kept. Keys are checked in sorted order, so
           the choice does not depend on set iteration order (string-hash
           randomization).
        2. Failing that, obs['x'] / obs['y'] are converted to numbers and
           stacked into a float64 (n_obs, 2) array, but only if neither
           column has missing or non-numeric values.

    Best-effort: if no source applies, adata is left unchanged.
    """
    if SPATIAL_KEY in adata.obsm:
        return

    for key in sorted(ALTERNATIVE_SPATIAL_KEYS - {SPATIAL_KEY}):
        if key in adata.obsm:
            adata.obsm[SPATIAL_KEY] = adata.obsm[key]
            return

    if "x" in adata.obs and "y" in adata.obs:
        try:
            # to_numpy(na_value=nan) also handles nullable (e.g. Int64) columns.
            x = pd.to_numeric(adata.obs["x"], errors="coerce").to_numpy(
                dtype="float64", na_value=np.nan
            )
            y = pd.to_numeric(adata.obs["y"], errors="coerce").to_numpy(
                dtype="float64", na_value=np.nan
            )
        except (TypeError, ValueError):
            # Columns that cannot be represented as floats: leave adata
            # unchanged, consistent with the best-effort contract.
            return
        if not (np.isnan(x).any() or np.isnan(y).any()):
            adata.obsm[SPATIAL_KEY] = np.column_stack([x, y])
|
|
416
|
+
|
|
417
|
+
|
|
418
|
+
# =============================================================================
|
|
419
|
+
# Advanced Validation: validate_adata with optional checks
|
|
420
|
+
# =============================================================================
|
|
421
|
+
def validate_adata(
    adata: "ad.AnnData",
    required_keys: dict,
    check_spatial: bool = False,
    check_velocity: bool = False,
    spatial_key: str = "spatial",
) -> None:
    """
    Check that an AnnData object has the required keys, plus optional integrity checks.

    Args:
        adata: AnnData object to validate
        required_keys: Mapping from an AnnData attribute name (e.g. "obs",
            "var", "obsm", "layers") to a key or list of keys that must exist
            in that attribute. A DataFrame attribute is checked via its
            columns; a dict-like attribute via its keys.
        check_spatial: Whether to validate spatial coordinates
        check_velocity: Whether to validate velocity data layers
        spatial_key: Key for spatial coordinates in adata.obsm

    Raises:
        DataError: If required keys are missing (reported first, before any
            optional check runs) or if an optional check finds problems.
    """
    absent: list[str] = []

    for category, wanted in required_keys.items():
        names = [wanted] if isinstance(wanted, str) else wanted
        container = getattr(adata, category, None)

        if container is None:
            absent.extend(f"{category}.{name}" for name in names)
            continue

        for name in names:
            if hasattr(container, "columns"):  # DataFrame
                present = name in container.columns
            elif hasattr(container, "keys"):  # Dict-like
                present = name in container.keys()
            else:
                present = False
            if not present:
                absent.append(f"{category}.{name}")

    if absent:
        raise DataError(f"Missing required keys: {', '.join(absent)}")

    # Optional deeper checks append their findings to the same list.
    if check_spatial:
        _validate_spatial_data(adata, spatial_key, absent)
    if check_velocity:
        _validate_velocity_data(adata, absent)

    if absent:
        raise DataError(f"Validation failed: {', '.join(absent)}")
|
|
474
|
+
|
|
475
|
+
|
|
476
|
+
def _validate_spatial_data(
|
|
477
|
+
adata: "ad.AnnData", spatial_key: str, issues: list[str]
|
|
478
|
+
) -> None:
|
|
479
|
+
"""Internal helper for spatial data validation."""
|
|
480
|
+
if spatial_key not in adata.obsm:
|
|
481
|
+
issues.append(f"Missing '{spatial_key}' coordinates in adata.obsm")
|
|
482
|
+
return
|
|
483
|
+
|
|
484
|
+
spatial_coords = adata.obsm[spatial_key]
|
|
485
|
+
|
|
486
|
+
if spatial_coords.shape[1] < 2:
|
|
487
|
+
issues.append(
|
|
488
|
+
f"Spatial coordinates should have at least 2 dimensions, "
|
|
489
|
+
f"found {spatial_coords.shape[1]}"
|
|
490
|
+
)
|
|
491
|
+
|
|
492
|
+
if np.any(np.isnan(spatial_coords)):
|
|
493
|
+
issues.append("Spatial coordinates contain NaN values")
|
|
494
|
+
|
|
495
|
+
if np.std(spatial_coords[:, 0]) == 0 and np.std(spatial_coords[:, 1]) == 0:
|
|
496
|
+
issues.append("All spatial coordinates are identical")
|
|
497
|
+
|
|
498
|
+
|
|
499
|
+
def _validate_velocity_data(adata: "ad.AnnData", issues: list[str]) -> None:
|
|
500
|
+
"""Internal helper for velocity data validation."""
|
|
501
|
+
if "spliced" not in adata.layers:
|
|
502
|
+
issues.append("Missing 'spliced' layer required for RNA velocity")
|
|
503
|
+
if "unspliced" not in adata.layers:
|
|
504
|
+
issues.append("Missing 'unspliced' layer required for RNA velocity")
|
|
505
|
+
|
|
506
|
+
if "spliced" in adata.layers and "unspliced" in adata.layers:
|
|
507
|
+
for layer_name in ["spliced", "unspliced"]:
|
|
508
|
+
layer_data = adata.layers[layer_name]
|
|
509
|
+
|
|
510
|
+
if hasattr(layer_data, "nnz"): # Sparse matrix
|
|
511
|
+
if layer_data.nnz == 0:
|
|
512
|
+
issues.append(f"'{layer_name}' layer is empty (all zeros)")
|
|
513
|
+
else: # Dense matrix
|
|
514
|
+
if np.all(layer_data == 0):
|
|
515
|
+
issues.append(f"'{layer_name}' layer is empty (all zeros)")
|
|
516
|
+
|
|
517
|
+
if hasattr(layer_data, "data"): # Sparse matrix
|
|
518
|
+
if np.any(np.isnan(layer_data.data)):
|
|
519
|
+
issues.append(f"'{layer_name}' layer contains NaN values")
|
|
520
|
+
else: # Dense matrix
|
|
521
|
+
if np.any(np.isnan(layer_data)):
|
|
522
|
+
issues.append(f"'{layer_name}' layer contains NaN values")
|
|
523
|
+
|
|
524
|
+
|
|
525
|
+
# =============================================================================
|
|
526
|
+
# Metadata Storage: Scientific Provenance Tracking
|
|
527
|
+
# =============================================================================
|
|
528
|
+
def store_analysis_metadata(
    adata: "ad.AnnData",
    analysis_name: str,
    method: str,
    parameters: dict[str, Any],
    results_keys: dict[str, list[str]],
    statistics: Optional[dict[str, Any]] = None,
    species: Optional[str] = None,
    database: Optional[str] = None,
    reference_info: Optional[dict[str, Any]] = None,
) -> None:
    """Record provenance metadata for an analysis in adata.uns.

    The record is written to ``adata.uns[f"{analysis_name}_metadata"]``,
    replacing any previous record under that key. It always holds
    "method", "parameters" and "results_keys". Each optional field
    ("statistics", "species", "database", "reference_info") is added, in
    that order, only when it is not None. The passed dicts are stored by
    reference, not copied.

    Args:
        adata: AnnData object to store metadata in
        analysis_name: Name of the analysis (e.g., "annotation_tangram")
        method: Method name (e.g., "tangram", "liana", "cellrank")
        parameters: Dictionary of analysis parameters
        results_keys: Dictionary mapping storage location to list of keys
            Example: {"obs": ["cell_type_tangram"], "obsm": ["tangram_ct_pred"]}
        statistics: Optional dictionary of quality/summary statistics
        species: Optional species identifier (critical for communication/enrichment)
        database: Optional database/resource name (critical for communication/enrichment)
        reference_info: Optional reference dataset information
    """
    optional_fields = {
        "statistics": statistics,
        "species": species,
        "database": database,
        "reference_info": reference_info,
    }
    record = {
        "method": method,
        "parameters": parameters,
        "results_keys": results_keys,
        **{name: value for name, value in optional_fields.items() if value is not None},
    }
    adata.uns[f"{analysis_name}_metadata"] = record
|
|
584
|
+
|
|
585
|
+
|
|
586
|
+
def get_analysis_parameter(
    adata: "ad.AnnData",
    analysis_name: str,
    parameter_name: str,
    default: Any = None,
) -> Any:
    """Look up one parameter recorded by store_analysis_metadata().

    Reads ``adata.uns[f"{analysis_name}_metadata"]["parameters"]``, so callers
    can reuse the settings of an earlier analysis (like cluster_key) instead
    of re-inferring them.

    Args:
        adata: AnnData object
        analysis_name: Name of the analysis (e.g., "spatial_stats_neighborhood")
        parameter_name: Name of the parameter (e.g., "cluster_key")
        default: Value returned when the metadata record, its "parameters"
            entry, or the parameter itself is absent

    Returns:
        Parameter value or default

    Example:
        # Get cluster_key used in neighborhood analysis
        cluster_key = get_analysis_parameter(
            adata, "spatial_stats_neighborhood", "cluster_key"
        )
    """
    key = f"{analysis_name}_metadata"
    if key in adata.uns:
        record = adata.uns[key]
        if "parameters" in record:
            return record["parameters"].get(parameter_name, default)
    return default
|
|
621
|
+
|
|
622
|
+
|
|
623
|
+
# =============================================================================
|
|
624
|
+
# Gene Selection Utilities
|
|
625
|
+
# =============================================================================
|
|
626
|
+
def get_highly_variable_genes(
    adata: "ad.AnnData",
    max_genes: int = 500,
    fallback_to_variance: bool = True,
) -> list[str]:
    """
    Get highly variable genes from AnnData.

    Priority order:
        1. Genes flagged True in adata.var['highly_variable'] (the first
           ``max_genes`` in var order)
        2. If no gene is flagged (column missing or all False) and the
           fallback is enabled: the ``max_genes`` genes with the highest
           per-gene variance of adata.X, from most to least variable. Genes
           with NaN variance are ranked last.

    Args:
        adata: AnnData object
        max_genes: Maximum number of genes to return (<= 0 returns [])
        fallback_to_variance: If True, compute variance when HVG not available

    Returns:
        List of gene names (may be shorter than max_genes if fewer available)
    """
    # Without this guard, the tail slice `[-0:]` would have returned every gene.
    if max_genes <= 0:
        return []

    # Try precomputed HVG first
    if "highly_variable" in adata.var.columns:
        hvg_flags = adata.var["highly_variable"]
        if hvg_flags.any():
            return adata.var_names[hvg_flags].tolist()[:max_genes]

    if not fallback_to_variance:
        return []

    X = adata.X
    if sparse.issparse(X):
        # Var(X) = E[X^2] - E[X]^2, computed without densifying the matrix
        mean = np.asarray(X.mean(axis=0)).ravel()
        mean_sq = np.asarray(X.power(2).mean(axis=0)).ravel()
        var_scores = mean_sq - mean**2
    else:
        var_scores = np.asarray(np.asarray(X).var(axis=0)).ravel()

    # Sorting the negated scores gives descending variance; NaN stays at the end.
    top_indices = np.argsort(-var_scores, kind="stable")[:max_genes]
    return adata.var_names[top_indices].tolist()
|
|
668
|
+
|
|
669
|
+
|
|
670
|
+
def select_genes_for_analysis(
    adata: "ad.AnnData",
    genes: Optional[list[str]] = None,
    n_genes: int = 20,
    require_hvg: bool = True,
    analysis_name: str = "analysis",
) -> list[str]:
    """
    Select genes for spatial/statistical analysis.

    Unified gene selection logic for all analysis tools. Replaces duplicated
    code across spatial_statistics.py and other tools.

    Priority:
        1. User-specified genes (filtered to genes present in adata.var_names,
           duplicates removed, first-seen order kept)
        2. Highly variable genes (HVG) from preprocessing

    Args:
        adata: AnnData object
        genes: User-specified gene list. If provided, filters to genes in adata.
        n_genes: Maximum number of genes to return when using HVG.
        require_hvg: If True (default), raise error when HVG not found.
            If False, return empty list when HVG not found.
        analysis_name: Name of analysis for error messages (e.g., "Moran's I").

    Returns:
        List of gene names to analyze.

    Raises:
        DataError: If genes specified but none found, or HVG required but missing.

    Examples:
        # Use user-specified genes
        genes = select_genes_for_analysis(adata, genes=["CD4", "CD8A"])

        # Use top 50 HVGs
        genes = select_genes_for_analysis(adata, n_genes=50)

        # For analysis that can work without HVG
        genes = select_genes_for_analysis(adata, require_hvg=False)
    """
    # Case 1: User specified genes
    if genes is not None:
        # dict.fromkeys drops repeated names while keeping first-seen order,
        # so each gene is analyzed once.
        valid_genes = [g for g in dict.fromkeys(genes) if g in adata.var_names]
        if not valid_genes:
            # Find closest matches for better error message
            from difflib import get_close_matches

            all_names = adata.var_names.tolist()  # computed once, not per gene
            suggestions = []
            for g in genes[:3]:  # Check first 3 genes
                matches = get_close_matches(g, all_names, n=1, cutoff=0.6)
                if matches:
                    suggestions.append(f"'{g}' → '{matches[0]}'?")

            suggestion_str = (
                f" Did you mean: {', '.join(suggestions)}" if suggestions else ""
            )
            raise DataError(
                f"None of the specified genes found in data: {genes[:5]}..."
                f"{suggestion_str}"
            )
        return valid_genes

    # Case 2: Use HVG
    if "highly_variable" in adata.var.columns and adata.var["highly_variable"].any():
        hvg_genes = adata.var_names[adata.var["highly_variable"]].tolist()
        return hvg_genes[:n_genes]

    # Case 3: HVG not available
    if require_hvg:
        raise DataError(
            f"Highly variable genes (HVG) required for {analysis_name}.\n\n"
            "Solutions:\n"
            "1. Run preprocess_data() first to compute HVGs\n"
            "2. Specify genes explicitly via 'genes' parameter"
        )

    return []
|
|
750
|
+
|
|
751
|
+
|
|
752
|
+
# =============================================================================
|
|
753
|
+
# Gene Name Utilities
|
|
754
|
+
# =============================================================================
|
|
755
|
+
def ensure_unique_var_names(
    adata: "ad.AnnData",
    label: str = "data",
) -> int:
    """
    Make ``adata.var_names`` unique in place, reporting how many entries were duplicates.

    Args:
        adata: AnnData object; ``var_names_make_unique()`` is called on it when
            duplicates are present (modified in-place).
        label: Descriptive label for the data. Unused in this synchronous
            variant; accepted so the signature matches the async variant.

    Returns:
        Number of surplus gene-name entries (total names minus distinct names)
        found before deduplication; 0 when the names were already unique.
    """
    names = adata.var_names
    if not names.is_unique:
        # Count surplus entries before the rename rewrites var_names.
        surplus = len(names) - len(set(names))
        adata.var_names_make_unique()
        return surplus
    return 0
|
|
775
|
+
|
|
776
|
+
|
|
777
|
+
async def ensure_unique_var_names_async(
    adata: "ad.AnnData",
    ctx: Any,  # ToolContext, use Any to avoid circular import
    label: str = "data",
) -> int:
    """
    Deduplicate gene names (see ensure_unique_var_names) and warn the user via ``ctx``.

    Args:
        adata: AnnData object (modified in-place)
        ctx: ToolContext; its ``warning`` coroutine is awaited when names were fixed
        label: Descriptive label for the data used in the warning
            (e.g., "reference data", "query data")

    Returns:
        Number of duplicate gene names that were fixed (0 if already unique)
    """
    fixed_count = ensure_unique_var_names(adata, label)
    if fixed_count <= 0:
        return fixed_count
    # Names are already fixed at this point; this only informs the user.
    await ctx.warning(f"Found {fixed_count} duplicate gene names in {label}, fixed")
    return fixed_count
|
|
799
|
+
|
|
800
|
+
|
|
801
|
+
# =============================================================================
|
|
802
|
+
# Raw Counts Data Access: Unified interface for accessing raw data
|
|
803
|
+
# =============================================================================
|
|
804
|
+
|
|
805
|
+
|
|
806
|
+
def check_is_integer_counts(X: Any, sample_size: int = 100) -> tuple[bool, bool, bool]:
    """Check if a matrix contains integer counts.

    This is a lightweight utility for checking data format without
    going through the full data source detection logic.

    Args:
        X: 2-D data matrix (scipy sparse, or dense/array-like supporting 2-D slicing)
        sample_size: Number of leading rows and columns to inspect for efficiency

    Returns:
        Tuple of (is_integer, has_negatives, has_decimals). An empty sample
        (zero rows or zero columns) is reported as (True, False, False), the
        same answer an all-zero sample gives.

    Note:
        Only the top-left ``sample_size x sample_size`` corner is inspected, so
        a corner that happens to be all zeros is reported as integer counts
        regardless of the rest of the matrix.
    """
    n_rows = min(sample_size, X.shape[0])
    n_cols = min(sample_size, X.shape[1])
    sample = X[:n_rows, :n_cols]

    # Nothing to inspect: sample.min() would raise on a zero-size array.
    if 0 in sample.shape:
        return True, False, False

    if sparse.issparse(sample):
        sample = sample.toarray()
    sample = np.asarray(sample)

    has_negatives = float(sample.min()) < 0
    has_decimals = not np.allclose(sample, np.round(sample), atol=1e-6)
    is_integer = not has_negatives and not has_decimals

    return is_integer, has_negatives, has_decimals
|
|
831
|
+
|
|
832
|
+
|
|
833
|
+
def ensure_counts_layer(
    adata: "ad.AnnData",
    layer_name: str = "counts",
    error_message: Optional[str] = None,
) -> bool:
    """Ensure a counts layer exists in AnnData, creating from raw if needed.

    This is the single source of truth for counts layer preparation.
    Used by scVI-tools methods (scANVI, Cell2location, etc.) that require
    raw counts in a specific layer. The values are copied from adata.raw as-is;
    this function does not verify that they are integer counts.

    Args:
        adata: AnnData object (modified in-place)
        layer_name: Name of the layer to ensure (default: "counts")
        error_message: Custom error message if counts cannot be created

    Returns:
        True if layer was created, False if already existed

    Raises:
        DataNotFoundError: If adata.raw is None, or if some of the current
            adata.var_names are absent from adata.raw.

    Note:
        When adata has been subsetted to HVGs, this function subsets adata.raw
        to match the current var_names (in the current order).

    Examples:
        # Ensure counts layer exists before scANVI setup
        ensure_counts_layer(adata_ref)
        scvi.model.SCANVI.setup_anndata(adata_ref, layer="counts", ...)

        # With custom error message
        ensure_counts_layer(adata, error_message="scANVI requires raw counts")
    """
    if layer_name in adata.layers:
        return False

    if adata.raw is not None:
        # adata.raw may hold the full gene set while adata holds an HVG subset.
        # Check alignment first so a mismatch surfaces as the documented
        # DataNotFoundError instead of an opaque KeyError from raw indexing.
        missing = adata.var_names.difference(adata.raw.var_names)
        if len(missing) > 0:
            from .exceptions import DataNotFoundError

            raise DataNotFoundError(
                error_message
                or (
                    f"Cannot create '{layer_name}' layer: {len(missing)} gene(s) in "
                    f"adata.var_names are absent from adata.raw "
                    f"(e.g. {missing[:5].tolist()}). adata.raw and adata.var_names "
                    "appear to be out of sync."
                )
            )
        adata.layers[layer_name] = adata.raw[:, adata.var_names].X
        return True

    # Cannot create counts layer. Import lazily: only the error path needs it.
    from .exceptions import DataNotFoundError

    default_error = (
        f"Cannot create '{layer_name}' layer: adata.raw is None. "
        "Load unpreprocessed data or ensure adata.raw is preserved during preprocessing."
    )
    raise DataNotFoundError(error_message or default_error)
|
|
884
|
+
|
|
885
|
+
|
|
886
|
+
class RawDataResult:
    """Result of raw data extraction.

    Plain container returned by get_raw_data_source. Attributes:
        X: Expression matrix (sparse or dense) taken from the chosen source.
        var_names: Gene names matching the columns of X.
        source: Which source X came from (get_raw_data_source uses
            "raw", "counts_layer" or "current").
        is_integer_counts: Whether the sampled values looked like integer counts.
        has_negatives: Whether the sampled values contained negatives.
        has_decimals: Whether the sampled values contained non-integers.
    """

    def __init__(
        self,
        X: Any,  # sparse or dense matrix
        var_names: pd.Index,
        source: str,
        is_integer_counts: bool,
        has_negatives: bool = False,
        has_decimals: bool = False,
    ):
        self.X = X
        self.var_names = var_names
        self.source = source
        self.is_integer_counts = is_integer_counts
        self.has_negatives = has_negatives
        self.has_decimals = has_decimals

    def __repr__(self) -> str:
        # Summarize X by its shape only: the matrix itself can be very large.
        shape = getattr(self.X, "shape", None)
        return (
            f"{type(self).__name__}(source={self.source!r}, shape={shape}, "
            f"n_vars={len(self.var_names)}, "
            f"is_integer_counts={self.is_integer_counts}, "
            f"has_negatives={self.has_negatives}, has_decimals={self.has_decimals})"
        )
|
|
904
|
+
|
|
905
|
+
|
|
906
|
+
def get_raw_data_source(
    adata: "ad.AnnData",
    prefer_complete_genes: bool = True,
    require_integer_counts: bool = False,
    sample_size: int = 100,
) -> RawDataResult:
    """
    Get raw count data from AnnData using a unified priority order.

    This is THE single source of truth for accessing raw counts data.
    All tools should use this function instead of implementing their own logic.

    Priority order (when prefer_complete_genes=True):
        1. adata.raw - Complete gene set, preserved before HVG filtering
        2. adata.layers["counts"] - Raw counts layer
        3. adata.X - Current expression matrix

    Priority order (when prefer_complete_genes=False):
        1. adata.layers["counts"] - Raw counts layer
        2. adata.X - Current expression matrix
        (adata.raw is skipped as it may have different dimensions)

    A source is returned as soon as it is found, unless require_integer_counts
    is True and its sampled values are not integer counts, in which case the
    next source is tried. If reading adata.raw fails, it is skipped (the error
    type is recorded in the final error message).

    Args:
        adata: AnnData object
        prefer_complete_genes: If True, prefer adata.raw for complete gene coverage.
            Set to False when you need data aligned with current adata dimensions.
        require_integer_counts: If True, validate that data contains integer counts.
            Raises DataError if only normalized data is found.
        sample_size: Number of cells/genes to sample for validation.

    Returns:
        RawDataResult with data matrix, var_names, source name, and validation info.

    Raises:
        DataError: If require_integer_counts=True and no integer counts found.

    Example:
        result = get_raw_data_source(adata, prefer_complete_genes=True)
        print(f"Using {result.source}: {len(result.var_names)} genes")
        if result.is_integer_counts:
            # Safe to use for deconvolution/velocity
            pass
    """
    sources_tried: list[str] = []

    def _inspect(X: Any, var_names: pd.Index, source: str) -> RawDataResult:
        # Sample one candidate source and package it with its validation flags.
        is_int, has_neg, has_dec = check_is_integer_counts(X, sample_size)
        return RawDataResult(
            X=X,
            var_names=var_names,
            source=source,
            is_integer_counts=is_int,
            has_negatives=has_neg,
            has_decimals=has_dec,
        )

    # Source 1: adata.raw (complete gene set)
    if prefer_complete_genes and adata.raw is not None:
        try:
            raw_adata = adata.raw.to_adata()
            raw_result = _inspect(raw_adata.X, raw_adata.var_names, "raw")
        except Exception as exc:
            # Best-effort: an unreadable raw falls through to the next source,
            # but keep the error type so the final message is diagnosable.
            sources_tried.append(f"raw (error: {type(exc).__name__}, skipped)")
        else:
            if raw_result.is_integer_counts or not require_integer_counts:
                return raw_result
            sources_tried.append("raw (normalized, skipped)")

    # Source 2: layers["counts"] (aligned with current var_names)
    if "counts" in adata.layers:
        counts_result = _inspect(
            adata.layers["counts"], adata.var_names, "counts_layer"
        )
        if counts_result.is_integer_counts or not require_integer_counts:
            return counts_result
        sources_tried.append("counts_layer (normalized, skipped)")

    # Source 3: current X (always returned when integer counts are not required)
    current_result = _inspect(adata.X, adata.var_names, "current")
    if current_result.is_integer_counts or not require_integer_counts:
        return current_result

    # Only reachable when require_integer_counts=True and no source held
    # integer counts; the flags reported below describe the current X.
    raise DataError(
        f"No raw integer counts found. Sources tried: {sources_tried + ['current (normalized)']}. "
        f"Data appears to be normalized (has_negatives={current_result.has_negatives}, "
        f"has_decimals={current_result.has_decimals}). "
        "Deconvolution and velocity methods require raw integer counts. "
        "Solutions: (1) Load unpreprocessed data, (2) Ensure adata.layers['counts'] "
        "contains raw counts, or (3) Re-run preprocessing with adata.raw preservation."
    )
|
|
1008
|
+
|
|
1009
|
+
|
|
1010
|
+
# =============================================================================
|
|
1011
|
+
# Expression Data Extraction: Unified sparse/dense handling
|
|
1012
|
+
# =============================================================================
|
|
1013
|
+
def to_dense(X: Any, copy: bool = False) -> np.ndarray:
    """
    Convert sparse matrix to dense numpy array.

    Handles both scipy sparse matrices and already-dense arrays uniformly.
    This is THE single function for sparse-to-dense conversion across ChatSpatial.

    Args:
        X: Expression matrix (sparse or dense / array-like)
        copy: If True, always return a copy (safe for modification).
              If False (default), dense ndarray input is returned without
              copying (read-only use); other array-likes are converted.

    Returns:
        Dense numpy array

    Note:
        - Sparse input: Always returns a new array (toarray() creates copy)
        - Dense input with copy=False: No copy for ndarray input (no memory overhead)
        - Dense input with copy=True: Always returns independent copy

    Examples:
        # Read-only use (default, memory efficient)
        dense_X = to_dense(adata.X)

        # When you need to modify the result
        dense_X = to_dense(adata.X, copy=True)
        dense_X[0, 0] = 999  # Safe, won't affect original
    """
    if sparse.issparse(X):
        return X.toarray()
    if copy:
        return np.array(X, copy=True)
    # np.asarray copies only when conversion requires it. np.array(X, copy=False)
    # must not be used here: under NumPy >= 2 it raises ValueError whenever a
    # copy cannot be avoided (e.g. list input).
    return np.asarray(X)
|
|
1046
|
+
|
|
1047
|
+
|
|
1048
|
+
def get_gene_expression(
    adata: "ad.AnnData",
    gene: str,
    layer: Optional[str] = None,
) -> np.ndarray:
    """
    Get expression values of a single gene as 1D array.

    This is THE single function for extracting single-gene expression.
    Replaces 12+ duplicated code patterns across the codebase.

    Args:
        adata: AnnData object
        gene: Gene name (must exist exactly once in adata.var_names)
        layer: Optional layer name. If None, uses adata.X

    Returns:
        1D numpy array of expression values (length = n_obs)

    Raises:
        DataError: If the gene is not found, appears more than once in
            var_names, or the requested layer does not exist.

    Examples:
        # Basic usage
        cd4_expr = get_gene_expression(adata, "CD4")

        # From specific layer
        counts = get_gene_expression(adata, "CD4", layer="counts")

        # Use in visualization
        adata.obs["_temp_expr"] = get_gene_expression(adata, gene)
    """
    if gene not in adata.var_names:
        raise DataError(
            f"Gene '{gene}' not found in data. "
            f"Available genes (first 5): {adata.var_names[:5].tolist()}"
        )

    gene_idx = adata.var_names.get_loc(gene)
    if not isinstance(gene_idx, (int, np.integer)):
        # Non-unique var_names: get_loc returns a slice/mask spanning several
        # columns, and flattening them would silently concatenate genes.
        raise DataError(
            f"Gene '{gene}' appears multiple times in var_names. "
            "Make gene names unique (e.g., adata.var_names_make_unique()) first."
        )

    if layer is not None:
        if layer not in adata.layers:
            raise DataError(
                f"Layer '{layer}' not found. Available: {list(adata.layers.keys())}"
            )
        X = adata.layers[layer][:, gene_idx]
    else:
        X = adata[:, gene].X

    return to_dense(X).flatten()
|
|
1097
|
+
|
|
1098
|
+
|
|
1099
|
+
def get_genes_expression(
    adata: "ad.AnnData",
    genes: list[str],
    layer: Optional[str] = None,
) -> np.ndarray:
    """
    Get expression values of multiple genes as 2D array.

    Args:
        adata: AnnData object
        genes: List of gene names (each must exist exactly once in adata.var_names)
        layer: Optional layer name. If None, uses adata.X

    Returns:
        2D numpy array of shape (n_obs, n_genes), columns in the order of `genes`

    Raises:
        DataError: If any gene is not found, any requested gene appears more
            than once in var_names, or the requested layer does not exist.

    Examples:
        # Get expression matrix for heatmap
        expr_matrix = get_genes_expression(adata, ["CD4", "CD8A", "CD3D"])

        # From counts layer
        counts = get_genes_expression(adata, marker_genes, layer="counts")
    """
    # Validate genes
    missing = [g for g in genes if g not in adata.var_names]
    if missing:
        raise DataError(
            f"Genes not found: {missing[:5]}{'...' if len(missing) > 5 else ''}. "
            f"Available genes (first 5): {adata.var_names[:5].tolist()}"
        )

    # A requested gene that occurs several times in var_names has no single
    # column; get_loc would return a slice/mask and break the column indexing.
    if not adata.var_names.is_unique:
        duplicated_names = set(adata.var_names[adata.var_names.duplicated()])
        ambiguous = [g for g in genes if g in duplicated_names]
        if ambiguous:
            raise DataError(
                f"Genes appear multiple times in var_names: {ambiguous[:5]}"
                f"{'...' if len(ambiguous) > 5 else ''}. "
                "Make gene names unique (e.g., adata.var_names_make_unique()) first."
            )

    if layer is not None:
        if layer not in adata.layers:
            raise DataError(
                f"Layer '{layer}' not found. Available: {list(adata.layers.keys())}"
            )
        gene_indices = [adata.var_names.get_loc(g) for g in genes]
        X = adata.layers[layer][:, gene_indices]
    else:
        X = adata[:, genes].X

    return to_dense(X)
|
|
1144
|
+
|
|
1145
|
+
|
|
1146
|
+
# =============================================================================
|
|
1147
|
+
# Metadata Profiling: Extract structure information for LLM understanding
|
|
1148
|
+
# =============================================================================
|
|
1149
|
+
def get_column_profile(
    adata: "ad.AnnData", layer: Literal["obs", "var"] = "obs"
) -> list[dict[str, Any]]:
    """
    Get metadata column profile for obs or var.

    Returns detailed information about each column to help LLM understand the data.

    Args:
        adata: AnnData object
        layer: Which layer to profile ("obs" or "var"; any other value profiles var)

    Returns:
        List of column information dictionaries with keys:
        - name: Column name
        - dtype: "numerical" or "categorical" (boolean columns are categorical)
        - n_unique: Number of unique values (NaN excluded for numerical columns)
        - range: (min, max) for numerical columns, None for categorical
        - sample_values: Sample values (as strings) for categorical columns,
          None for numerical
    """
    df = adata.obs if layer == "obs" else adata.var
    profiles: list[dict[str, Any]] = []

    for col in df.columns:
        col_data = df[col]

        # pandas counts bool as a numeric dtype; treat flag columns such as
        # "highly_variable" as categorical so they are reported as True/False
        # values instead of a meaningless (0.0, 1.0) range.
        is_numeric = pd.api.types.is_numeric_dtype(
            col_data
        ) and not pd.api.types.is_bool_dtype(col_data)

        if is_numeric:
            # Numerical column
            profiles.append(
                {
                    "name": col,
                    "dtype": "numerical",
                    "n_unique": int(col_data.nunique()),
                    "range": (float(col_data.min()), float(col_data.max())),
                    "sample_values": None,
                }
            )
        else:
            # Categorical column
            unique_vals = col_data.unique()
            n_unique = len(unique_vals)

            # Take first 5 sample values (or 3 if too many unique values)
            if n_unique <= 100:
                sample_vals = unique_vals[:5].tolist()
            else:
                sample_vals = unique_vals[:3].tolist()

            profiles.append(
                {
                    "name": col,
                    "dtype": "categorical",
                    "n_unique": n_unique,
                    "sample_values": [str(v) for v in sample_vals],
                    "range": None,
                }
            )

    return profiles
|
|
1211
|
+
|
|
1212
|
+
|
|
1213
|
+
def get_gene_profile(
    adata: "ad.AnnData",
) -> tuple[Optional[list[str]], list[str]]:
    """
    Summarize genes for the LLM: precomputed HVGs plus the highest-mean genes.

    Args:
        adata: AnnData object

    Returns:
        Tuple of (top_highly_variable_genes, top_expressed_genes)
        - top_highly_variable_genes: up to 10 HVG names, or None when none are
          available (no variance-based fallback is requested)
        - top_expressed_genes: 10 gene names with the highest mean of adata.X,
          highest first; falls back to the first 10 var_names if that fails
    """
    # Only report HVGs that were already computed; an empty result becomes None.
    top_hvg = (
        get_highly_variable_genes(adata, max_genes=10, fallback_to_variance=False)
        or None
    )

    try:
        gene_means = np.asarray(adata.X.mean(axis=0)).ravel()
        ranked = np.argsort(gene_means)[::-1][:10]  # descending by mean
        top_expr = adata.var_names[ranked].tolist()
    except Exception:
        # Best-effort: profiling must not fail just because X is unusual.
        top_expr = adata.var_names[:10].tolist()

    return top_hvg, top_expr
|
|
1242
|
+
|
|
1243
|
+
|
|
1244
|
+
def get_adata_profile(adata: "ad.AnnData") -> dict[str, Any]:
    """
    Build the dataset summary handed to the LLM for choosing analyses.

    Args:
        adata: AnnData object

    Returns:
        Dictionary with, in this order:
        - obs_columns: get_column_profile output for adata.obs
        - var_columns: get_column_profile output for adata.var
        - obsm_keys: keys of adata.obsm ([] if the attribute is absent)
        - uns_keys: keys of adata.uns ([] if the attribute is absent)
        - top_highly_variable_genes: precomputed HVGs, or None
        - top_expressed_genes: genes with the highest mean expression
    """
    profile: dict[str, Any] = {
        "obs_columns": get_column_profile(adata, layer="obs"),
        "var_columns": get_column_profile(adata, layer="var"),
    }

    top_hvg, top_expr = get_gene_profile(adata)

    profile["obsm_keys"] = list(adata.obsm.keys()) if hasattr(adata, "obsm") else []
    profile["uns_keys"] = list(adata.uns.keys()) if hasattr(adata, "uns") else []
    profile["top_highly_variable_genes"] = top_hvg
    profile["top_expressed_genes"] = top_expr
    return profile
|
|
1282
|
+
|
|
1283
|
+
|
|
1284
|
+
# =============================================================================
|
|
1285
|
+
# Gene Overlap: Find and validate common genes between datasets
|
|
1286
|
+
# =============================================================================
|
|
1287
|
+
def find_common_genes(*gene_collections: Any) -> list[str]:
    """
    Find common genes across multiple gene collections.

    This is THE single function for computing gene intersections across ChatSpatial.
    Supports any number of gene collections (2 or more).

    Args:
        *gene_collections: Two or more gene collections. Each can be:
            - List[str]: Gene name list
            - pd.Index: AnnData var_names
            - Any Iterable[str]: Will be converted to set

    Returns:
        List of gene names present in every collection, without duplicates,
        in the order they first appear in the first collection. The order is
        deterministic, so downstream gene subsets are reproducible across runs.

    Raises:
        ValueError: If fewer than 2 collections provided

    Examples:
        # Between two AnnData objects
        common = find_common_genes(adata1.var_names, adata2.var_names)

        # Multiple datasets (e.g., spatial registration)
        common = find_common_genes(
            adata1.var_names, adata2.var_names, adata3.var_names
        )

        # With explicit lists
        common = find_common_genes(["GeneA", "GeneB"], ["GeneB", "GeneC"])
    """
    if len(gene_collections) < 2:
        raise ValueError("find_common_genes requires at least 2 gene collections")

    # dict.fromkeys de-duplicates while keeping first-seen order; returning
    # list(set(...)) would make the order depend on the string hash seed.
    ordered_first = dict.fromkeys(gene_collections[0])

    # Materialize the remaining collections once for O(1) membership tests.
    others = [set(genes) for genes in gene_collections[1:]]

    return [g for g in ordered_first if all(g in s for s in others)]
|
|
1329
|
+
|
|
1330
|
+
|
|
1331
|
+
def validate_gene_overlap(
|
|
1332
|
+
common_genes: list[str],
|
|
1333
|
+
source_n_genes: int,
|
|
1334
|
+
target_n_genes: int,
|
|
1335
|
+
min_genes: int = 100,
|
|
1336
|
+
source_name: str = "source",
|
|
1337
|
+
target_name: str = "target",
|
|
1338
|
+
) -> None:
|
|
1339
|
+
"""
|
|
1340
|
+
Validate that gene overlap meets minimum requirements.
|
|
1341
|
+
|
|
1342
|
+
This is THE single validation function for gene overlap across ChatSpatial.
|
|
1343
|
+
Moved from deconvolution.py._validate_common_genes for reuse.
|
|
1344
|
+
|
|
1345
|
+
Args:
|
|
1346
|
+
common_genes: List of common gene names
|
|
1347
|
+
source_n_genes: Number of genes in source data
|
|
1348
|
+
target_n_genes: Number of genes in target data
|
|
1349
|
+
min_genes: Minimum required common genes (default: 100)
|
|
1350
|
+
source_name: Name of source data for error messages
|
|
1351
|
+
target_name: Name of target data for error messages
|
|
1352
|
+
|
|
1353
|
+
Raises:
|
|
1354
|
+
DataError: If insufficient common genes
|
|
1355
|
+
|
|
1356
|
+
Examples:
|
|
1357
|
+
# Basic validation
|
|
1358
|
+
common = find_common_genes(spatial.var_names, reference.var_names)
|
|
1359
|
+
validate_gene_overlap(common, spatial.n_vars, reference.n_vars)
|
|
1360
|
+
|
|
1361
|
+
# With custom threshold and names
|
|
1362
|
+
validate_gene_overlap(
|
|
1363
|
+
common, spatial.n_vars, reference.n_vars,
|
|
1364
|
+
min_genes=50, source_name="spatial", target_name="reference"
|
|
1365
|
+
)
|
|
1366
|
+
"""
|
|
1367
|
+
if len(common_genes) < min_genes:
|
|
1368
|
+
raise DataError(
|
|
1369
|
+
f"Insufficient gene overlap: {len(common_genes)} < {min_genes} required. "
|
|
1370
|
+
f"{source_name}: {source_n_genes} genes, {target_name}: {target_n_genes} genes. "
|
|
1371
|
+
f"Check species/gene naming convention match."
|
|
1372
|
+
)
|