chatspatial 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatspatial/__init__.py +11 -0
- chatspatial/__main__.py +141 -0
- chatspatial/cli/__init__.py +7 -0
- chatspatial/config.py +53 -0
- chatspatial/models/__init__.py +85 -0
- chatspatial/models/analysis.py +513 -0
- chatspatial/models/data.py +2462 -0
- chatspatial/server.py +1763 -0
- chatspatial/spatial_mcp_adapter.py +720 -0
- chatspatial/tools/__init__.py +3 -0
- chatspatial/tools/annotation.py +1903 -0
- chatspatial/tools/cell_communication.py +1603 -0
- chatspatial/tools/cnv_analysis.py +605 -0
- chatspatial/tools/condition_comparison.py +595 -0
- chatspatial/tools/deconvolution/__init__.py +402 -0
- chatspatial/tools/deconvolution/base.py +318 -0
- chatspatial/tools/deconvolution/card.py +244 -0
- chatspatial/tools/deconvolution/cell2location.py +326 -0
- chatspatial/tools/deconvolution/destvi.py +144 -0
- chatspatial/tools/deconvolution/flashdeconv.py +101 -0
- chatspatial/tools/deconvolution/rctd.py +317 -0
- chatspatial/tools/deconvolution/spotlight.py +216 -0
- chatspatial/tools/deconvolution/stereoscope.py +109 -0
- chatspatial/tools/deconvolution/tangram.py +135 -0
- chatspatial/tools/differential.py +625 -0
- chatspatial/tools/embeddings.py +298 -0
- chatspatial/tools/enrichment.py +1863 -0
- chatspatial/tools/integration.py +807 -0
- chatspatial/tools/preprocessing.py +723 -0
- chatspatial/tools/spatial_domains.py +808 -0
- chatspatial/tools/spatial_genes.py +836 -0
- chatspatial/tools/spatial_registration.py +441 -0
- chatspatial/tools/spatial_statistics.py +1476 -0
- chatspatial/tools/trajectory.py +495 -0
- chatspatial/tools/velocity.py +405 -0
- chatspatial/tools/visualization/__init__.py +155 -0
- chatspatial/tools/visualization/basic.py +393 -0
- chatspatial/tools/visualization/cell_comm.py +699 -0
- chatspatial/tools/visualization/cnv.py +320 -0
- chatspatial/tools/visualization/core.py +684 -0
- chatspatial/tools/visualization/deconvolution.py +852 -0
- chatspatial/tools/visualization/enrichment.py +660 -0
- chatspatial/tools/visualization/integration.py +205 -0
- chatspatial/tools/visualization/main.py +164 -0
- chatspatial/tools/visualization/multi_gene.py +739 -0
- chatspatial/tools/visualization/persistence.py +335 -0
- chatspatial/tools/visualization/spatial_stats.py +469 -0
- chatspatial/tools/visualization/trajectory.py +639 -0
- chatspatial/tools/visualization/velocity.py +411 -0
- chatspatial/utils/__init__.py +115 -0
- chatspatial/utils/adata_utils.py +1372 -0
- chatspatial/utils/compute.py +327 -0
- chatspatial/utils/data_loader.py +499 -0
- chatspatial/utils/dependency_manager.py +462 -0
- chatspatial/utils/device_utils.py +165 -0
- chatspatial/utils/exceptions.py +185 -0
- chatspatial/utils/image_utils.py +267 -0
- chatspatial/utils/mcp_utils.py +137 -0
- chatspatial/utils/path_utils.py +243 -0
- chatspatial/utils/persistence.py +78 -0
- chatspatial/utils/scipy_compat.py +143 -0
- chatspatial-1.1.0.dist-info/METADATA +242 -0
- chatspatial-1.1.0.dist-info/RECORD +67 -0
- chatspatial-1.1.0.dist-info/WHEEL +5 -0
- chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
- chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
- chatspatial-1.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1903 @@
"""
Cell type annotation tools for spatial transcriptomics data.
"""

from __future__ import annotations

import hashlib
import json
from pathlib import Path
from typing import TYPE_CHECKING, Any, NamedTuple, Optional

import numpy as np
import pandas as pd
import scanpy as sc

if TYPE_CHECKING:
    from ..spatial_mcp_adapter import ToolContext

from ..models.analysis import AnnotationResult
from ..models.data import AnnotationParameters
from ..utils.adata_utils import (
    ensure_categorical,
    ensure_counts_layer,
    ensure_unique_var_names_async,
    find_common_genes,
    get_cell_type_key,
    get_cluster_key,
    get_spatial_key,
    to_dense,
    validate_obs_column,
)
from ..utils.dependency_manager import (
    is_available,
    require,
    validate_r_environment,
    validate_scvi_tools,
)
from ..utils.device_utils import cuda_available
from ..utils.exceptions import (
    DataError,
    DataNotFoundError,
    ParameterError,
    ProcessingError,
)


class AnnotationMethodOutput(NamedTuple):
    """Unified output from all annotation methods.

    This provides a consistent return type across all annotation methods,
    improving code clarity and preventing positional argument confusion.

    Attributes:
        cell_types: List of unique cell type names identified (deduplicated)
        counts: Mapping of cell type names to number of cells assigned
        confidence: Mapping of cell type names to confidence scores.
            Empty dict indicates no confidence data available.
        mapping_score: Optional method-specific quality score (e.g., Tangram mapping score)
    """

    cell_types: list[str]
    counts: dict[str, int]
    confidence: dict[str, float]
    mapping_score: Optional[float] = None


# Supported annotation methods
# Confidence behavior by method:
# - singler/tangram/sctype: Real confidence scores (correlation/probability/scoring)
# - scanvi/cellassign: Partial confidence (when soft prediction available)
# - mllmcelltype: No numeric confidence (LLM-based)
SUPPORTED_METHODS = {
    "tangram",
    "scanvi",
    "cellassign",
    "mllmcelltype",
    "sctype",
    "singler",
}
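
# --- Illustrative sketch (editor's addition, not part of the package source) ---
# How a caller might consume AnnotationMethodOutput. The named fields keep call
# sites readable and avoid positional mix-ups between counts and confidence;
# `_sketch_summarize` is a hypothetical helper, not a chatspatial API.
def _sketch_summarize(result: AnnotationMethodOutput) -> str:
    parts = []
    for cell_type in result.cell_types:
        n = result.counts.get(cell_type, 0)
        conf = result.confidence.get(cell_type)  # empty dict => no confidence data
        suffix = f", confidence={conf:.2f}" if conf is not None else ""
        parts.append(f"{cell_type}: {n} cells{suffix}")
    return "; ".join(parts)
# --- end sketch ---
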
async def _annotate_with_singler(
    adata,
    params: AnnotationParameters,
    ctx: "ToolContext",
    output_key: str,
    confidence_key: str,
    reference_adata: Optional[Any] = None,
) -> AnnotationMethodOutput:
    """Annotate cell types using the SingleR reference-based method."""
    # Validate and import dependencies
    require("singler", ctx, feature="SingleR annotation")
    require("singlecellexperiment", ctx, feature="SingleR annotation")
    import singler

    # Optional: check for celldex
    celldex = None
    if is_available("celldex"):
        import celldex

    # Get expression matrix - prefer normalized data
    # IMPORTANT: Ensure test_mat dimensions match adata.var_names (used in test_features)
    if "X_normalized" in adata.layers:
        test_mat = adata.layers["X_normalized"]
    elif adata.X is not None:
        test_mat = adata.X
    else:
        # Fallback: use raw data, but subset to current var_names to ensure dimension match
        # Note: adata.raw may have full genes while adata has HVG subset
        if adata.raw is not None:
            test_mat = adata.raw[:, adata.var_names].X
        else:
            test_mat = adata.X

    # MEMORY OPTIMIZATION: SingleR (singler-py) natively supports sparse matrices.
    # No toarray() needed - both np.log1p() and .T work with sparse matrices.
    # Verified: sparse and dense inputs produce identical results.
    # Memory savings: ~1.3 GB for a typical 10K cells × 20K genes dataset.

    # Ensure log-normalization (SingleR expects log-normalized data)
    if "log1p" not in adata.uns:
        await ctx.warning(
            "Data may not be log-normalized. Applying log1p for SingleR..."
        )
        test_mat = np.log1p(test_mat)

    # Transpose for SingleR (genes x cells)
    test_mat = test_mat.T

    # Ensure gene names are strings
    test_features = [str(x) for x in adata.var_names]

    # Prepare reference
    reference_name = getattr(params, "singler_reference", None)
    reference_data_id = getattr(params, "reference_data_id", None)

    ref_data = None
    ref_labels = None
    ref_features_to_use = None  # Only set when using custom reference (not celldex)

    # Priority: reference_name > reference_data_id > default
    if reference_name and celldex:
        ref = celldex.fetch_reference(reference_name, "2024-02-26", realize_assays=True)
        # Get labels
        for label_col in ["label.main", "label.fine", "cell_type"]:
            try:
                ref_labels = ref.get_column_data().column(label_col)
                break
            except Exception:
                continue  # Try next label column
        if ref_labels is None:
            raise DataNotFoundError(
                f"Could not find labels in reference {reference_name}"
            )
        ref_data = ref

    elif reference_data_id and reference_adata is not None:
        # Use provided reference data (passed from main function via ctx.get_adata())
        # Handle duplicate gene names
        await ensure_unique_var_names_async(reference_adata, ctx, "reference data")
        if await ensure_unique_var_names_async(adata, ctx, "query data") > 0:
            # Update test_features after fixing
            test_features = [str(x) for x in adata.var_names]

        # Get reference expression matrix
        if "X_normalized" in reference_adata.layers:
            ref_mat = reference_adata.layers["X_normalized"]
        else:
            ref_mat = reference_adata.X

        # MEMORY OPTIMIZATION: SingleR (singler-py) natively supports sparse matrices.
        # No toarray() needed - both np.log1p() and .T work with sparse matrices.
        # Verified: sparse and dense inputs produce identical results.
        # Memory savings: ~1.3 GB for a typical 10K cells × 20K genes reference dataset.

        # Ensure log-normalization for reference
        if "log1p" not in reference_adata.uns:
            await ctx.warning(
                "Reference data may not be log-normalized. Applying log1p..."
            )
            ref_mat = np.log1p(ref_mat)

        # Transpose for SingleR (genes x cells)
        ref_mat = ref_mat.T
        ref_features = [str(x) for x in reference_adata.var_names]

        # Check gene overlap
        common_genes = find_common_genes(test_features, ref_features)

        if len(common_genes) < min(50, len(test_features) * 0.1):
            raise DataError(
                f"Insufficient gene overlap for SingleR: only {len(common_genes)} common genes "
                f"(test: {len(test_features)}, reference: {len(ref_features)})"
            )

        # Get labels from reference - check various common column names
        # cell_type_key is now required (no default value)
        cell_type_key = params.cell_type_key

        validate_obs_column(reference_adata, cell_type_key, "Cell type")

        ref_labels = list(reference_adata.obs[cell_type_key])

        # For SingleR, pass the actual expression matrix directly (not SCE)
        # This has been shown to work better in testing
        ref_data = ref_mat
        ref_features_to_use = ref_features  # Keep reference features for gene matching

    elif celldex:
        # Use default reference
        ref = celldex.fetch_reference(
            "blueprint_encode", "2024-02-26", realize_assays=True
        )
        ref_labels = ref.get_column_data().column("label.main")
        ref_data = ref
    else:
        raise DataNotFoundError(
            "No reference data. Provide reference_data_id or singler_reference."
        )

    # Run SingleR annotation
    use_integrated = getattr(params, "singler_integrated", False)
    num_threads = getattr(params, "num_threads", 4)

    if use_integrated and isinstance(ref_data, list):
        single_results, integrated = singler.annotate_integrated(
            test_mat,
            ref_data=ref_data,
            ref_labels=ref_labels,
            test_features=test_features,
            num_threads=num_threads,
        )
        best_labels = integrated.column("best_label")
        scores = integrated.column("scores")
    else:
        # Build kwargs for annotate_single
        annotate_kwargs = {
            "test_data": test_mat,
            "test_features": test_features,
            "ref_data": ref_data,
            "ref_labels": ref_labels,
            "num_threads": num_threads,
        }

        # Add ref_features if we're using custom reference data (not celldex)
        if ref_features_to_use is not None:
            annotate_kwargs["ref_features"] = ref_features_to_use

        results = singler.annotate_single(**annotate_kwargs)
        best_labels = results.column("best")
        scores = results.column("scores")

    # Try to get delta scores for confidence (higher delta = higher confidence)
    try:
        delta_scores = results.column("delta")
        if delta_scores:
            low_delta = sum(1 for d in delta_scores if d and d < 0.05)
            if low_delta > len(delta_scores) * 0.3:
                await ctx.warning(
                    f"{low_delta}/{len(delta_scores)} cells have low confidence scores (delta < 0.05)"
                )
    except Exception:
        delta_scores = None

    # Process results
    cell_types = list(best_labels)
    unique_types = list(set(cell_types))
    counts = pd.Series(cell_types).value_counts().to_dict()

    # Calculate confidence scores - prefer delta scores if available.
    # IMPORTANT: Different scores have different mathematical semantics:
    # - delta scores: gap between best and second-best match, range [0, +∞)
    # - correlation scores: Pearson correlation, range [-1, 1]
    # We apply scientifically appropriate transformations to [0, 1].
    confidence_scores = {}

    # First try to use delta scores (the more meaningful confidence measure);
    # delta_scores is always defined by the try/except block above
    if delta_scores is not None:
        try:
            for cell_type in unique_types:
                type_indices = [i for i, ct in enumerate(cell_types) if ct == cell_type]
                if type_indices:
                    type_deltas = [
                        delta_scores[i] for i in type_indices if i < len(delta_scores)
                    ]
                    if type_deltas:
                        avg_delta = np.mean([d for d in type_deltas if d is not None])
                        # Transform delta to [0, 1] using a saturating function:
                        # delta=0 → 0 (no discrimination = zero confidence)
                        # delta→∞ → 1 (perfect discrimination = full confidence)
                        confidence = 1.0 - np.exp(-avg_delta)
                        confidence_scores[cell_type] = round(float(confidence), 3)
        except Exception:
            # Delta score extraction failed, will fall back to regular scores
            pass

    # Fall back to correlation scores if delta not available
    if not confidence_scores and scores is not None:
        try:
            scores_df = pd.DataFrame(scores.to_dict())
        except AttributeError:
            scores_df = pd.DataFrame(
                scores.to_numpy() if hasattr(scores, "to_numpy") else scores
            )

        for cell_type in unique_types:
            mask = [ct == cell_type for ct in cell_types]
            if cell_type in scores_df.columns and any(mask):
                type_scores = scores_df.loc[mask, cell_type]
                avg_score = type_scores.mean()
                # Use max(0, r) instead of (r+1)/2 for correlation:
                # r<0 (negative correlation) → 0 (opposite pattern = not a match)
                # r=0 → 0 (no correlation = zero confidence)
                # r=1 → 1 (perfect correlation = full confidence)
                confidence = max(0.0, float(avg_score))
                confidence_scores[cell_type] = round(confidence, 3)
            # else: cell type won't have a confidence score (no action needed)

    # Add to AnnData (keys provided by caller for single-point control)
    adata.obs[output_key] = cell_types
    ensure_categorical(adata, output_key)

    # Only add confidence column if we have real confidence values
    if confidence_scores:
        # Use 0.0 for cells without confidence (more honest than an arbitrary 0.5)
        confidence_array = [confidence_scores.get(ct, 0.0) for ct in cell_types]
        adata.obs[confidence_key] = confidence_array

    return AnnotationMethodOutput(
        cell_types=unique_types,
        counts=counts,
        confidence=confidence_scores,
    )
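
# --- Illustrative sketch (editor's addition, not part of the package source) ---
# The two confidence transforms used by _annotate_with_singler, isolated for
# clarity. A delta score (gap between best and second-best reference match,
# range [0, +∞)) saturates toward 1 via 1 - exp(-delta); a Pearson correlation
# r in [-1, 1] is clamped with max(0, r) so anti-correlation maps to zero
# confidence rather than to 0.5. These helpers are hypothetical names.
def _sketch_delta_to_confidence(delta: float) -> float:
    return 1.0 - float(np.exp(-delta))  # 0.0 -> 0.0, 0.05 -> ~0.049, large -> ~1.0


def _sketch_correlation_to_confidence(r: float) -> float:
    return max(0.0, float(r))  # -0.5 -> 0.0, 0.7 -> 0.7
# --- end sketch ---
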
async def _annotate_with_tangram(
    adata,
    params: AnnotationParameters,
    ctx: "ToolContext",
    output_key: str,
    confidence_key: str,
    reference_adata: Optional[Any] = None,
) -> AnnotationMethodOutput:
    """Annotate cell types using the Tangram method"""
    # Validate dependencies with comprehensive error reporting
    require("tangram", ctx, feature="Tangram annotation")
    import tangram as tg

    # Check if reference data is provided
    if reference_adata is None:
        raise ParameterError("Tangram requires reference_data_id parameter.")

    # Use reference single-cell data (passed from main function via ctx.get_adata())
    adata_sc_original = reference_adata

    # ===== CRITICAL FIX: Use raw data for Tangram to preserve gene name case =====
    # Issue: Preprocessed data may have lowercase gene names, while reference has uppercase.
    # This causes 0 overlapping genes and Tangram mapping failure (all NaN results).
    # Solution: Use adata.raw, which preserves original gene names.
    if adata.raw is not None:
        adata_sp = adata.raw.to_adata()
        # Preserve spatial coordinates from preprocessed data
        spatial_key = get_spatial_key(adata)
        if spatial_key:
            adata_sp.obsm[spatial_key] = adata.obsm[spatial_key].copy()
    else:
        adata_sp = adata
        await ctx.warning(
            "No raw data available - using preprocessed data (may have gene name mismatches)"
        )
    # =============================================================================

    # Handle duplicate gene names
    await ensure_unique_var_names_async(adata_sc_original, ctx, "reference data")
    await ensure_unique_var_names_async(adata_sp, ctx, "spatial data")

    # Determine training genes
    training_genes = params.training_genes

    if training_genes is None:
        # Use marker genes if available
        if params.marker_genes:
            # Flatten marker genes dictionary
            training_genes = []
            for genes in params.marker_genes.values():
                training_genes.extend(genes)
            training_genes = list(set(training_genes))  # Remove duplicates
        else:
            # Use highly variable genes
            if "highly_variable" not in adata_sc_original.var:
                raise DataNotFoundError(
                    "HVGs not found in reference data. Run preprocessing first."
                )
            training_genes = list(
                adata_sc_original.var_names[adata_sc_original.var.highly_variable]
            )

    # COW FIX: Create a copy of reference data to avoid modifying the original.
    # Tangram's pp_adatas adds metadata (uns, obs) but doesn't subset genes.
    adata_sc = adata_sc_original.copy()

    # Preprocess data for Tangram
    tg.pp_adatas(adata_sc, adata_sp, genes=training_genes)

    # Set mapping mode
    mode = params.tangram_mode
    cluster_label = params.cluster_label

    if mode == "clusters" and cluster_label is None:
        await ctx.warning(
            "Cluster label not provided for 'clusters' mode. Using default cell type annotation if available."
        )
        # Try to find a cell type or cluster annotation in the reference data
        cluster_label = get_cell_type_key(adata_sc) or get_cluster_key(adata_sc)

        if cluster_label is None:
            raise ParameterError(
                "No cluster label found. Provide cluster_label parameter."
            )

    # Check GPU availability for device selection
    device = params.tangram_device
    if device != "cpu" and not cuda_available():
        await ctx.warning("GPU requested but not available - falling back to CPU")
        device = "cpu"

    # Run Tangram mapping with enhanced parameters
    mapping_args = {
        "mode": mode,
        "num_epochs": params.num_epochs,
        "device": device,
        "density_prior": params.tangram_density_prior,  # Add density prior
        "learning_rate": params.tangram_learning_rate,  # Add learning rate
    }

    # Add optional regularization parameters
    if params.tangram_lambda_r is not None:
        mapping_args["lambda_r"] = params.tangram_lambda_r

    if params.tangram_lambda_neighborhood is not None:
        mapping_args["lambda_neighborhood"] = params.tangram_lambda_neighborhood

    if mode == "clusters":
        mapping_args["cluster_label"] = cluster_label

    ad_map = tg.map_cells_to_space(adata_sc, adata_sp, **mapping_args)

    # Get mapping score from training history
    tangram_mapping_score = 0.0  # Default score
    try:
        if "training_history" in ad_map.uns:
            history = ad_map.uns["training_history"]

            # Extract score from main_loss (which is actually a similarity score, higher is better)
            if (
                isinstance(history, dict)
                and "main_loss" in history
                and len(history["main_loss"]) > 0
            ):
                import re

                last_value = history["main_loss"][-1]

                # Extract value from tensor string if needed
                if isinstance(last_value, str):
                    # Handle tensor string format: 'tensor(0.9050, grad_fn=...)'
                    match = re.search(r"tensor\(([-\d.]+)", last_value)
                    if match:
                        tangram_mapping_score = float(match.group(1))
                    else:
                        # Try direct conversion
                        try:
                            tangram_mapping_score = float(last_value)
                        except Exception:
                            tangram_mapping_score = 0.0
                else:
                    tangram_mapping_score = float(last_value)

            else:
                error_msg = (
                    f"Tangram history format not recognized: {type(history).__name__}. "
                    f"Upgrade tangram-sc: pip install --upgrade tangram-sc"
                )
                raise ProcessingError(error_msg)
    except Exception as score_error:
        raise ProcessingError(
            f"Tangram mapping completed but score extraction failed: {score_error}"
        ) from score_error

    # Compute validation metrics if requested
    if params.tangram_compute_validation:
        try:
            scores = tg.compare_spatial_geneexp(ad_map, adata_sp)
            adata_sp.uns["tangram_validation_scores"] = scores
        except Exception as val_error:
            await ctx.warning(f"Could not compute validation metrics: {val_error}")

    # Project genes if requested
    if params.tangram_project_genes:
        try:
            ad_ge = tg.project_genes(ad_map, adata_sc)
            adata_sp.obsm["tangram_gene_predictions"] = ad_ge.X
        except Exception as gene_error:
            await ctx.warning(f"Could not project genes: {gene_error}")

    # Project cell annotations to space using the proper API function
    try:
        # Determine annotation column
        annotation_col = None
        if mode == "clusters" and cluster_label:
            annotation_col = cluster_label
        else:
            # cell_type_key is now required (no auto-detect)
            if params.cell_type_key not in adata_sc.obs:
                # Improved error message showing available columns
                available_cols = list(adata_sc.obs.columns)
                categorical_cols = [
                    col
                    for col in available_cols
                    if adata_sc.obs[col].dtype.name in ["object", "category"]
                ]

                raise ParameterError(
                    f"Cell type column '{params.cell_type_key}' not found. "
                    f"Available: {categorical_cols[:5]}"
                )

            annotation_col = params.cell_type_key

        # annotation_col is guaranteed to be set (either from cluster_label or cell_type_key)
        tg.project_cell_annotations(ad_map, adata_sp, annotation=annotation_col)
    except Exception as proj_error:
        await ctx.warning(f"Could not project cell annotations: {proj_error}")
        # Continue without projection

    # Get cell type predictions (keys provided by caller for single-point control)
    cell_types = []
    counts = {}
    confidence_scores = {}

    if "tangram_ct_pred" in adata_sp.obsm:
        cell_type_df = adata_sp.obsm["tangram_ct_pred"]

        # Get cell types and counts
        cell_types = list(cell_type_df.columns)

        # ===== CRITICAL FIX: Row normalization for proper probability calculation =====
        # tangram_ct_pred contains unnormalized density/abundance values, NOT probabilities:
        # row sums can be != 1.0 and values can exceed 1.0.
        # We normalize to convert densities → probability distributions.
        cell_type_prob = cell_type_df.div(cell_type_df.sum(axis=1), axis=0)

        # Validation: Ensure normalized values are valid probabilities
        if not (cell_type_prob.values >= 0).all():
            await ctx.warning(
                "Some normalized probabilities are negative - data quality issue"
            )
        if not (cell_type_prob.values <= 1.0).all():
            await ctx.warning(
                "Some normalized probabilities exceed 1.0 - normalization failed"
            )
        if not np.allclose(cell_type_prob.sum(axis=1), 1.0):
            await ctx.warning(
                "Row sums don't equal 1.0 after normalization - numerical issue"
            )

        # Assign cell type based on highest probability (argmax is the same before/after normalization)
        adata_sp.obs[output_key] = cell_type_prob.idxmax(axis=1)
        ensure_categorical(adata_sp, output_key)

        # Get counts
        counts = adata_sp.obs[output_key].value_counts().to_dict()

        # Calculate confidence scores from NORMALIZED probabilities
        confidence_scores = {}
        for cell_type in cell_types:
            cells_of_type = adata_sp.obs[output_key] == cell_type
            if np.sum(cells_of_type) > 0:
                # Use mean PROBABILITY as confidence (now guaranteed to be in [0, 1])
                mean_prob = cell_type_prob.loc[cells_of_type, cell_type].mean()
                confidence_scores[cell_type] = round(float(mean_prob), 3)

    else:
        await ctx.warning("No cell type predictions found in Tangram results")

    # Validate results before returning
    if not cell_types:
        raise ProcessingError(
            "Tangram mapping failed - no cell type predictions generated"
        )

    if tangram_mapping_score <= 0:
        await ctx.warning(
            f"Tangram mapping score is suspiciously low: {tangram_mapping_score}"
        )

    # ===== Copy results from adata_sp back to original adata =====
    # Since adata_sp was created from adata.raw (a different object), we need to
    # transfer the Tangram results back to the original adata for downstream use
    if adata_sp is not adata:
        # Copy cell type assignments
        if output_key in adata_sp.obs:
            adata.obs[output_key] = adata_sp.obs[output_key]

        # Copy tangram_ct_pred from obsm
        if "tangram_ct_pred" in adata_sp.obsm:
            adata.obsm["tangram_ct_pred"] = adata_sp.obsm["tangram_ct_pred"]

        # Copy tangram_gene_predictions if they exist
        if "tangram_gene_predictions" in adata_sp.obsm:
            adata.obsm["tangram_gene_predictions"] = adata_sp.obsm[
                "tangram_gene_predictions"
            ]

    return AnnotationMethodOutput(
        cell_types=cell_types,
        counts=counts,
        confidence=confidence_scores,
        mapping_score=tangram_mapping_score,
    )
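
# --- Illustrative sketch (editor's addition, not part of the package source) ---
# The row-normalization step from _annotate_with_tangram, shown on a toy
# density matrix. Rows of obsm["tangram_ct_pred"] are unnormalized abundances;
# dividing each row by its sum yields per-spot probability distributions. The
# argmax (assigned cell type) is unchanged, but mean probabilities become valid
# confidence values in [0, 1]. The data below are made up for illustration.
def _sketch_normalize_densities() -> None:
    densities = pd.DataFrame(
        [[2.0, 1.0, 1.0], [0.2, 0.1, 0.7]],
        columns=["T cell", "B cell", "Macrophage"],
    )
    probs = densities.div(densities.sum(axis=1), axis=0)
    assert np.allclose(probs.sum(axis=1), 1.0)  # rows now sum to 1
    assert (probs.idxmax(axis=1) == densities.idxmax(axis=1)).all()  # argmax preserved
# --- end sketch ---
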
async def _annotate_with_scanvi(
    adata,
    params: AnnotationParameters,
    ctx: "ToolContext",
    output_key: str,
    confidence_key: str,
    reference_adata: Optional[Any] = None,
) -> AnnotationMethodOutput:
    """Annotate cell types using scANVI (semi-supervised variational inference).

    scANVI (single-cell ANnotation using Variational Inference) is a deep learning
    method for transferring cell type labels from reference to query data using
    semi-supervised learning with variational autoencoders.

    Official implementation: scvi-tools (https://scvi-tools.org)
    Reference: Xu et al. (2021), "Probabilistic harmonization and annotation of
    single-cell transcriptomics data with deep generative models"

    Method Overview:
        1. Trains on reference data with known cell type labels
        2. Learns a shared latent representation between reference and query
        3. Transfers labels to query data via probabilistic predictions
        4. Supports batch correction and semi-supervised training

    Requirements:
        - reference_data_id: Must point to preprocessed single-cell reference data
        - cell_type_key: Column in reference data containing cell type labels
        - Both datasets must have a 'counts' layer (raw counts, not normalized)
        - Sufficient gene overlap between reference and query data

    Parameters (via AnnotationParameters):
        Core Architecture:
            - scanvi_n_latent (default: 10): Latent space dimensions
            - scanvi_n_hidden (default: 128): Hidden layer units
            - scanvi_n_layers (default: 1): Number of layers
            - scanvi_dropout_rate (default: 0.1): Dropout for regularization

        Training Strategy:
            - scanvi_use_scvi_pretrain (default: True): Use SCVI pretraining
            - scanvi_scvi_epochs (default: 200): SCVI pretraining epochs
            - num_epochs (default: 100): SCANVI training epochs
            - scanvi_query_epochs (default: 100): Query data training epochs

        Advanced:
            - scanvi_unlabeled_category (default: "Unknown"): Label for unlabeled cells
            - scanvi_n_samples_per_label (default: 100): Samples per label
            - batch_key: For batch correction (optional)

    Official Recommendations (scvi-tools):
        For large integration tasks:
            - scanvi_n_layers: 2
            - scanvi_n_latent: 30
            - scanvi_scvi_epochs: 300 (SCVI pretraining)
            - num_epochs: 100 (SCANVI training)
            - scanvi_query_epochs: 100
            - Gene selection: 1000-10000 HVGs recommended

    Empirical Adjustments (not official):
        For small datasets (<1000 genes or <1000 cells):
            - scanvi_n_latent: 3-5 (may prevent NaN/gradient explosion)
            - scanvi_dropout_rate: 0.2-0.3 (may improve regularization)
            - scanvi_use_scvi_pretrain: False (may reduce complexity)
            - num_epochs: 50 (may prevent overfitting)
            - scanvi_query_epochs: 50

    Common Issues:
        - NaN errors during training: Try reducing n_latent or increasing dropout_rate
        - Low confidence scores: Try increasing training epochs or check gene overlap
        - Memory issues: Reduce batch size or use GPU

    Returns:
        AnnotationMethodOutput with:
            - cell_types: List of predicted cell type categories
            - counts: Dict mapping cell types to number of cells
            - confidence: Dict mapping cell types to mean prediction probability
            - mapping_score: None (not produced by this method)

    Example:
        params = AnnotationParameters(
            method="scanvi",
            reference_data_id="reference_sc",
            cell_type_key="cell_types",
            scanvi_n_latent=5,  # For small dataset
            scanvi_dropout_rate=0.2,  # Better regularization
            scanvi_use_scvi_pretrain=False,  # Simpler training
            num_epochs=50,  # Prevent overfitting
        )
    """

    # Validate dependencies with comprehensive error reporting
    scvi = validate_scvi_tools(ctx, components=["SCANVI"])

    # Check if reference data is provided
    if reference_adata is None:
        raise ParameterError("scANVI requires reference_data_id parameter.")

    # Use reference single-cell data (passed from main function via ctx.get_adata())
    adata_ref_original = reference_adata

    # Handle duplicate gene names
    await ensure_unique_var_names_async(adata_ref_original, ctx, "reference data")
    await ensure_unique_var_names_async(adata, ctx, "query data")

    # Gene alignment
    common_genes = find_common_genes(adata_ref_original.var_names, adata.var_names)

    if len(common_genes) < min(100, adata_ref_original.n_vars * 0.5):
        raise DataError(
            f"Insufficient gene overlap: Only {len(common_genes)} common genes found. "
            f"Reference has {adata_ref_original.n_vars}, query has {adata.n_vars} genes."
        )

    # COW FIX: Operate on temporary copies for gene subsetting.
    # This prevents loss of HVG information in the original adata.
    if len(common_genes) < adata_ref_original.n_vars:
        await ctx.warning(
            f"Subsetting to {len(common_genes)} common genes for ScanVI training "
            f"(reference: {adata_ref_original.n_vars}, query: {adata.n_vars})"
        )
        # Create subsets for ScanVI (not modifying originals)
        adata_ref = adata_ref_original[:, common_genes].copy()
        adata_subset = adata[:, common_genes].copy()
    else:
        # No subsetting needed
        adata_ref = adata_ref_original.copy()
        adata_subset = adata.copy()

    # Data validation
    if "log1p" not in adata_ref.uns:
        await ctx.warning("Reference data may not be log-normalized")
    if "highly_variable" not in adata_ref.var:
        await ctx.warning("No highly variable genes detected in reference")

    # Get parameters
    cell_type_key = getattr(params, "cell_type_key", "cell_type")
    batch_key = getattr(params, "batch_key", None)

    # Optional SCVI pretraining
    if params.scanvi_use_scvi_pretrain:
        # Setup for SCVI with labels (required for SCANVI conversion);
        # first ensure the reference has the cell type labels
        validate_obs_column(
            adata_ref, cell_type_key, "Cell type column (reference data)"
        )

        # SCVI needs to know about labels for later SCANVI conversion
        scvi.model.SCVI.setup_anndata(
            adata_ref,
            labels_key=cell_type_key,  # Important: include labels_key
            batch_key=batch_key,
            layer=params.layer,
        )

        # Train SCVI
        scvi_model = scvi.model.SCVI(
            adata_ref,
            n_latent=params.scanvi_n_latent,
            n_hidden=params.scanvi_n_hidden,
            n_layers=params.scanvi_n_layers,
            dropout_rate=params.scanvi_dropout_rate,
        )

        scvi_model.train(
            max_epochs=params.scanvi_scvi_epochs,
            early_stopping=True,
            check_val_every_n_epoch=params.scanvi_check_val_every_n_epoch,
        )

        # Convert to SCANVI (no need for setup_anndata; it uses SCVI's setup)
        model = scvi.model.SCANVI.from_scvi_model(
            scvi_model, params.scanvi_unlabeled_category
        )

        # Train SCANVI (fewer epochs needed after pretraining).
        # Use configurable epochs (default: 20, official recommendation after pretraining).
        model.train(
            max_epochs=params.scanvi_scanvi_epochs,
            n_samples_per_label=params.scanvi_n_samples_per_label,
            early_stopping=True,
        )

    else:
        # Direct SCANVI training (existing approach).
        # Ensure counts layer exists (create from adata.raw if needed).
        ensure_counts_layer(
            adata_ref,
            error_message="scANVI requires raw counts in layers['counts'].",
        )

        # Setup AnnData for scANVI
        scvi.model.SCANVI.setup_anndata(
            adata_ref,
            labels_key=cell_type_key,
            unlabeled_category=params.scanvi_unlabeled_category,
            batch_key=batch_key,
            layer="counts",
        )

        # Create scANVI model
        model = scvi.model.SCANVI(
            adata_ref,
            n_hidden=params.scanvi_n_hidden,
            n_latent=params.scanvi_n_latent,
            n_layers=params.scanvi_n_layers,
            dropout_rate=params.scanvi_dropout_rate,
        )

        model.train(
            max_epochs=params.num_epochs,
            n_samples_per_label=params.scanvi_n_samples_per_label,
            early_stopping=True,
            check_val_every_n_epoch=params.scanvi_check_val_every_n_epoch,
        )

    # Query data preparation
    adata_subset.obs[cell_type_key] = params.scanvi_unlabeled_category

    # Setup query data (batch handling)
    if batch_key and batch_key not in adata_subset.obs:
        adata_subset.obs[batch_key] = "query_batch"

    # Ensure counts layer exists for query data (create from adata.raw if needed)
    ensure_counts_layer(
        adata_subset,
        error_message="scANVI requires raw counts in layers['counts'].",
    )

    scvi.model.SCANVI.setup_anndata(
        adata_subset,
        labels_key=cell_type_key,
        unlabeled_category=params.scanvi_unlabeled_category,
        batch_key=batch_key,
        layer="counts",
    )

    # Transfer model to spatial data with proper parameters
    spatial_model = scvi.model.SCANVI.load_query_data(adata_subset, model)

    # ===== Improved query training =====
    spatial_model.train(
        max_epochs=params.scanvi_query_epochs,  # Default: 100 (was 50)
        early_stopping=True,
        plan_kwargs=dict(weight_decay=0.0),  # Critical: preserve reference space
        check_val_every_n_epoch=params.scanvi_check_val_every_n_epoch,
    )

    # COW FIX: Get predictions from adata_subset, then add to the original adata
    predictions = spatial_model.predict()
    adata_subset.obs[cell_type_key] = predictions
    ensure_categorical(adata_subset, cell_type_key)

    # Extract results from adata_subset
    cell_types = list(adata_subset.obs[cell_type_key].cat.categories)
    counts = adata_subset.obs[cell_type_key].value_counts().to_dict()

    # Get prediction probabilities as confidence scores
    try:
        probs = spatial_model.predict(soft=True)
        confidence_scores = {}
        for i, cell_type in enumerate(cell_types):
            cells_of_type = adata_subset.obs[cell_type_key] == cell_type
            if np.sum(cells_of_type) > 0 and isinstance(probs, pd.DataFrame):
                if cell_type in probs.columns:
                    mean_prob = probs.loc[cells_of_type, cell_type].mean()
                    confidence_scores[cell_type] = round(float(mean_prob), 2)
                # else: No probability column for this cell type - skip confidence
            elif (
                np.sum(cells_of_type) > 0
                and hasattr(probs, "shape")
                and probs.shape[1] > i
            ):
                mean_prob = probs[cells_of_type, i].mean()
                confidence_scores[cell_type] = round(float(mean_prob), 2)
            # else: No cells of this type or no probability data - skip confidence
    except Exception as e:
        await ctx.warning(f"Could not get confidence scores: {e}")
        # Could not extract probabilities; an empty dict clearly indicates
        # that no confidence data is available
        confidence_scores = {}

    # COW FIX: Add prediction results to original adata.obs using output_key
    adata.obs[output_key] = adata_subset.obs[cell_type_key].values
    ensure_categorical(adata, output_key)

    # Store confidence if available
    if confidence_scores:
        confidence_array = [
            confidence_scores.get(ct, 0.0) for ct in adata.obs[output_key]
        ]
        adata.obs[confidence_key] = confidence_array

    return AnnotationMethodOutput(
        cell_types=cell_types,
        counts=counts,
        confidence=confidence_scores,
    )
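
# --- Illustrative sketch (editor's addition, not part of the package source) ---
# The bare scvi-tools reference -> query transfer sequence that
# _annotate_with_scanvi wraps, mirroring its direct-training path. `ref` and
# `query` are hypothetical AnnData objects sharing var_names, each with raw
# counts in layers["counts"] and a "cell_type" column on the reference.
def _sketch_scanvi_transfer(ref, query):
    import scvi

    scvi.model.SCANVI.setup_anndata(
        ref, labels_key="cell_type", unlabeled_category="Unknown", layer="counts"
    )
    model = scvi.model.SCANVI(ref, n_latent=10)
    model.train(max_epochs=100, early_stopping=True)

    query.obs["cell_type"] = "Unknown"  # query starts fully unlabeled
    scvi.model.SCANVI.setup_anndata(
        query, labels_key="cell_type", unlabeled_category="Unknown", layer="counts"
    )
    q_model = scvi.model.SCANVI.load_query_data(query, model)
    q_model.train(max_epochs=100, plan_kwargs=dict(weight_decay=0.0))
    return q_model.predict(), q_model.predict(soft=True)  # hard labels, probabilities
# --- end sketch ---
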
async def _annotate_with_mllmcelltype(
    adata,
    params: AnnotationParameters,
    ctx: "ToolContext",
    output_key: str,
    confidence_key: str,
) -> AnnotationMethodOutput:
    """Annotate cell types using the mLLMCellType (LLM-based) method.

    Supports both single-model and multi-model consensus annotation.

    Single Model Mode (default):
    - Uses one LLM for annotation
    - Fast and cost-effective
    - Providers: openai, anthropic, gemini, deepseek, qwen, zhipu, stepfun, minimax, grok, openrouter
    - Default models: openai="gpt-5", anthropic="claude-sonnet-4-20250514", gemini="gemini-2.5-pro-preview-03-25"
    - Latest recommended: "gpt-5", "claude-sonnet-4-5-20250929", "claude-opus-4-1-20250805", "gemini-2.5-pro"

    Multi-Model Consensus Mode (set mllm_use_consensus=True):
    - Uses multiple LLMs for collaborative annotation
    - Higher accuracy through consensus
    - Provides uncertainty metrics (consensus proportion, entropy)
    - Structured deliberation for controversial clusters

    Parameters (via AnnotationParameters):
    - cluster_label: Required. Cluster column in adata.obs
    - mllm_species: "human" or "mouse"
    - mllm_tissue: Tissue context (optional but recommended)
    - mllm_provider: LLM provider (single model mode)
    - mllm_model: Model name (None = use default for provider)
    - mllm_use_consensus: Enable multi-model consensus
    - mllm_models: List of models for consensus (e.g., ["gpt-5", "claude-sonnet-4-5-20250929"])
    - mllm_additional_context: Additional context for better annotation
    - mllm_base_urls: Custom API endpoints (useful for proxies)
    """

    # Validate dependencies with comprehensive error reporting
    require("mllmcelltype", ctx, feature="mLLMCellType annotation")
    import mllmcelltype

    # Validate that clustering has been performed;
    # cluster_label is required for mLLMCellType (no default value)
    if not params.cluster_label:
        available_cols = list(adata.obs.columns)
        categorical_cols = [
            col
            for col in available_cols
            if adata.obs[col].dtype.name in ["object", "category"]
        ]

        raise ParameterError(
            f"cluster_label parameter is required for mLLMCellType method.\n\n"
            f"Available categorical columns (likely clusters):\n {', '.join(categorical_cols[:15])}\n"
            f"{f' ... and {len(categorical_cols)-15} more' if len(categorical_cols) > 15 else ''}\n\n"
            f"Common cluster column names: leiden, louvain, seurat_clusters, phenograph\n\n"
            f"Example: params = {{'cluster_label': 'leiden', ...}}"
        )

    cluster_key = params.cluster_label
    validate_obs_column(adata, cluster_key, "Cluster")

    # Find differentially expressed genes for each cluster
    sc.tl.rank_genes_groups(adata, cluster_key, method="wilcoxon")

    # Extract top marker genes for each cluster
    marker_genes_dict = {}
    n_genes = params.mllm_n_marker_genes

    for cluster in adata.obs[cluster_key].unique():
        # Get top genes for this cluster
        gene_names = adata.uns["rank_genes_groups"]["names"][str(cluster)][:n_genes]
        marker_genes_dict[f"Cluster_{cluster}"] = list(gene_names)

    # Prepare parameters for mllmcelltype
    species = params.mllm_species
    tissue = params.mllm_tissue
    additional_context = params.mllm_additional_context
    use_cache = params.mllm_use_cache
    base_urls = params.mllm_base_urls
    verbose = params.mllm_verbose
    force_rerun = params.mllm_force_rerun
    clusters_to_analyze = params.mllm_clusters_to_analyze

    # Check whether to use multi-model consensus or a single model
    use_consensus = params.mllm_use_consensus

    try:
        if use_consensus:
            # Use interactive_consensus_annotation with multiple models
            models = params.mllm_models
            if not models:
                raise ParameterError(
                    "mllm_models parameter is required when mllm_use_consensus=True. "
                    "Provide a list of model names, e.g., ['gpt-5', 'claude-sonnet-4-5-20250929', 'gemini-2.5-pro']"
                )

            api_keys = params.mllm_api_keys
            consensus_threshold = params.mllm_consensus_threshold
            entropy_threshold = params.mllm_entropy_threshold
            max_discussion_rounds = params.mllm_max_discussion_rounds
            consensus_model = params.mllm_consensus_model

            # Call interactive_consensus_annotation
            consensus_results = mllmcelltype.interactive_consensus_annotation(
                marker_genes=marker_genes_dict,
                species=species,
                models=models,
                api_keys=api_keys,
                tissue=tissue,
                additional_context=additional_context,
                consensus_threshold=consensus_threshold,
                entropy_threshold=entropy_threshold,
                max_discussion_rounds=max_discussion_rounds,
                use_cache=use_cache,
                verbose=verbose,
                consensus_model=consensus_model,
                base_urls=base_urls,
                clusters_to_analyze=clusters_to_analyze,
                force_rerun=force_rerun,
            )

            # Extract consensus annotations
            annotations = consensus_results.get("consensus", {})

        else:
            # Use single model annotation
            provider = params.mllm_provider
            model = params.mllm_model
            api_key = params.mllm_api_key

            # Call annotate_clusters (single model)
            annotations = mllmcelltype.annotate_clusters(
                marker_genes=marker_genes_dict,
                species=species,
                provider=provider,
                model=model,
                api_key=api_key,
                tissue=tissue,
                additional_context=additional_context,
                use_cache=use_cache,
                base_urls=base_urls,
            )
    except Exception as e:
        raise ProcessingError(f"mLLMCellType annotation failed: {e}") from e

    # Map cluster annotations back to cells
    cluster_to_celltype = {}
    for cluster_name, cell_type in annotations.items():
        # Extract cluster number from "Cluster_X" format
        cluster_id = cluster_name.replace("Cluster_", "")
        cluster_to_celltype[cluster_id] = cell_type

    # Apply cell type annotations to cells (key provided by caller)
    adata.obs[output_key] = adata.obs[cluster_key].astype(str).map(cluster_to_celltype)

    # Handle any unmapped clusters
    unmapped = adata.obs[output_key].isna()
    if unmapped.any():
        await ctx.warning(f"Found {unmapped.sum()} cells in unmapped clusters")
        adata.obs.loc[unmapped, output_key] = "Unknown"

    ensure_categorical(adata, output_key)

    # Get cell types and counts
    cell_types = list(adata.obs[output_key].unique())
    counts = adata.obs[output_key].value_counts().to_dict()

    # LLM-based annotations don't provide numeric confidence scores;
    # we intentionally leave this empty rather than assigning misleading values
    return AnnotationMethodOutput(
        cell_types=cell_types,
        counts=counts,
        confidence={},
    )
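
# --- Illustrative sketch (editor's addition, not part of the package source) ---
# The cluster -> cell broadcast used by _annotate_with_mllmcelltype, on made-up
# data. Annotations come back keyed as "Cluster_<id>"; stripping the prefix and
# mapping over the cluster column assigns each cell its cluster's label, with
# NaN (later replaced by "Unknown") for clusters the LLM did not annotate.
def _sketch_map_clusters_to_cells() -> None:
    annotations = {"Cluster_0": "T cell", "Cluster_1": "B cell"}
    cluster_to_celltype = {
        name.replace("Cluster_", ""): label for name, label in annotations.items()
    }
    clusters = pd.Series(["0", "1", "0", "2"])  # cluster "2" has no annotation
    cell_types = clusters.map(cluster_to_celltype).fillna("Unknown")
    assert cell_types.tolist() == ["T cell", "B cell", "T cell", "Unknown"]
# --- end sketch ---
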
async def _annotate_with_cellassign(
|
|
1101
|
+
adata,
|
|
1102
|
+
params: AnnotationParameters,
|
|
1103
|
+
ctx: "ToolContext",
|
|
1104
|
+
output_key: str,
|
|
1105
|
+
confidence_key: str,
|
|
1106
|
+
) -> AnnotationMethodOutput:
|
|
1107
|
+
"""Annotate cell types using CellAssign method"""
|
|
1108
|
+
|
|
1109
|
+
# Validate dependencies with comprehensive error reporting
|
|
1110
|
+
validate_scvi_tools(ctx, components=["CellAssign"])
|
|
1111
|
+
from scvi.external import CellAssign
|
|
1112
|
+
|
|
1113
|
+
# Check if marker genes are provided
|
|
1114
|
+
if params.marker_genes is None:
|
|
1115
|
+
raise ParameterError(
|
|
1116
|
+
"CellAssign requires marker genes to be provided. "
|
|
1117
|
+
"Please specify marker_genes parameter with a dictionary of cell types and their marker genes."
|
|
1118
|
+
)
|
|
1119
|
+
|
|
1120
|
+
marker_genes = params.marker_genes
|
|
1121
|
+
|
|
1122
|
+
# CRITICAL FIX: Use adata.raw for marker gene validation if available
|
|
1123
|
+
# Preprocessing filters genes to HVGs, but marker genes may not be in HVGs
|
|
1124
|
+
# adata.raw contains all original genes and should be checked first
|
|
1125
|
+
if adata.raw is not None:
|
|
1126
|
+
all_genes = set(adata.raw.var_names)
|
|
1127
|
+
gene_source = "adata.raw"
|
|
1128
|
+
else:
|
|
1129
|
+
all_genes = set(adata.var_names)
|
|
1130
|
+
gene_source = "adata.var_names"
|
|
1131
|
+
await ctx.warning(
|
|
1132
|
+
f"Using filtered gene set for marker gene validation "
|
|
1133
|
+
f"({len(all_genes)} genes). Some marker genes may be missing. "
|
|
1134
|
+
f"Consider using unpreprocessed data for CellAssign."
|
|
1135
|
+
)
|
|
1136
|
+
|
|
1137
|
+
# Validate marker genes exist in dataset
|
|
1138
|
+
valid_marker_genes = {}
|
|
1139
|
+
total_markers = sum(len(g) for g in marker_genes.values())
|
|
1140
|
+
markers_found = 0
|
|
1141
|
+
markers_missing = 0
|
|
1142
|
+
|
|
1143
|
+
for cell_type, genes in marker_genes.items():
|
|
1144
|
+
existing_genes = [gene for gene in genes if gene in all_genes]
|
|
1145
|
+
missing_genes = [gene for gene in genes if gene not in all_genes]
|
|
1146
|
+
|
|
1147
|
+
if existing_genes:
|
|
1148
|
+
valid_marker_genes[cell_type] = existing_genes
|
|
1149
|
+
markers_found += len(existing_genes)
|
|
1150
|
+
if missing_genes and len(missing_genes) > len(existing_genes):
|
|
1151
|
+
await ctx.warning(
|
|
1152
|
+
f"Missing most markers for {cell_type}: {len(missing_genes)}/{len(genes)}"
|
|
1153
|
+
)
|
|
1154
|
+
else:
|
|
1155
|
+
markers_missing += len(genes)
|
|
1156
|
+
await ctx.warning(
|
|
1157
|
+
f"No marker genes found for {cell_type} - all {len(genes)} markers missing!"
|
|
1158
|
+
)
|
|
1159
|
+
|
|
1160
|
+
if not valid_marker_genes:
|
|
1161
|
+
raise DataError(
|
|
1162
|
+
f"No valid marker genes found for any cell type. "
|
|
1163
|
+
f"Checked {total_markers} markers against {len(all_genes)} genes in {gene_source}. "
|
|
1164
|
+
f"If data was preprocessed, marker genes may have been filtered out. "
|
|
1165
|
+
f"Consider using unpreprocessed data or ensure marker genes are highly variable."
|
|
1166
|
+
)
|
|
1167
|
+
valid_cell_types = list(valid_marker_genes)
|
|
1168
|
+
|
|
1169
|
+
# Create marker gene matrix as DataFrame (required by CellAssign API)
|
|
1170
|
+
all_marker_genes = []
|
|
1171
|
+
for genes in valid_marker_genes.values():
|
|
1172
|
+
all_marker_genes.extend(genes)
|
|
1173
|
+
available_marker_genes = list(set(all_marker_genes)) # Remove duplicates
|
|
1174
|
+
|
|
1175
|
+
# Note: available_marker_genes cannot be empty here because valid_marker_genes
|
|
1176
|
+
# is already validated at line 1120 to have at least one cell type with genes
|
|
1177
|
+
|
|
1178
|
+
# Create DataFrame with genes as index, cell types as columns
|
|
1179
|
+
marker_gene_matrix = pd.DataFrame(
|
|
1180
|
+
np.zeros((len(available_marker_genes), len(valid_cell_types))),
|
|
1181
|
+
index=available_marker_genes,
|
|
1182
|
+
columns=valid_cell_types,
|
|
1183
|
+
)
|
|
1184
|
+
|
|
1185
|
+
# Fill marker matrix
|
|
1186
|
+
for cell_type in valid_cell_types:
|
|
1187
|
+
for gene in valid_marker_genes[cell_type]:
|
|
1188
|
+
if gene in available_marker_genes:
|
|
1189
|
+
marker_gene_matrix.loc[gene, cell_type] = 1
|
|
1190
|
+
|
|
1191
|
+
    # Compute size factors BEFORE subsetting (official CellAssign requirement)
    if "size_factors" not in adata.obs:
        # Calculate size factors from FULL dataset
        if hasattr(adata.X, "sum"):
            size_factors = adata.X.sum(axis=1)
            if hasattr(size_factors, "A1"):  # sparse matrix
                size_factors = size_factors.A1
        else:
            size_factors = np.sum(adata.X, axis=1)

        # Normalize and ensure positive
        size_factors = np.maximum(size_factors, 1e-6)
        mean_sf = np.mean(size_factors)
        size_factors_normalized = size_factors / mean_sf

        adata.obs["size_factors"] = pd.Series(
            size_factors_normalized, index=adata.obs.index
        )

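    # Worked toy example of the normalization above: three cells with total
    # counts [100, 200, 300] have mean 200, so their size factors become
    # [0.5, 1.0, 1.5] - i.e. each cell's library size relative to the average.
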
    # Subset data to marker genes (size factors already computed).
    # Use adata.raw if available (contains all genes including markers).
    if adata.raw is not None:
        import anndata as ad_module

        adata_subset = ad_module.AnnData(
            X=adata.raw[:, available_marker_genes].X,
            obs=adata.obs.copy(),
            var=adata.raw.var.loc[available_marker_genes].copy(),
        )
    else:
        adata_subset = adata[:, available_marker_genes].copy()

    # Check for invalid values in the data
    X_array = to_dense(adata_subset.X)

    # Replace any NaN or Inf values with zeros
    if np.any(np.isnan(X_array)) or np.any(np.isinf(X_array)):
        await ctx.warning("Found NaN or Inf values in data, replacing with zeros")
        X_array = np.nan_to_num(X_array, nan=0.0, posinf=0.0, neginf=0.0)
        adata_subset.X = X_array

    # Additional data cleaning for CellAssign compatibility:
    # check for genes with zero variance (which cause numerical issues in CellAssign)
    gene_vars = np.var(X_array, axis=0)
    zero_var_genes = gene_vars == 0
    if np.any(zero_var_genes):
        await ctx.warning(
            f"Found {np.sum(zero_var_genes)} genes with zero variance. "
            f"CellAssign may have numerical issues with these genes."
        )
        # Don't raise error, just warn - CellAssign might handle it

    # Ensure data is non-negative (CellAssign expects count-like data)
    if np.any(X_array < 0):
        await ctx.warning("Found negative values in data, clipping to zero")
        X_array = np.maximum(X_array, 0)
        adata_subset.X = X_array

    # Verify size factors were transferred to subset
    if "size_factors" not in adata_subset.obs:
        raise ProcessingError(
            "Size factors not found in adata.obs. This should not happen - "
            "they should have been computed before subsetting. Please report this bug."
        )

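    # Toy illustration of the zero-variance check above: for a dense matrix
    #     X_array = np.array([[1.0, 5.0], [1.0, 7.0]])
    # np.var(X_array, axis=0) is [0.0, 1.0], so the first gene (constant
    # across all cells) would be flagged and reported in the warning.
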
    # Setup CellAssign on subset data only
    CellAssign.setup_anndata(adata_subset, size_factor_key="size_factors")

    # Train CellAssign model
    model = CellAssign(adata_subset, marker_gene_matrix)

    model.train(
        max_epochs=params.cellassign_max_iter, lr=params.cellassign_learning_rate
    )

    # Get predictions
    predictions = model.predict()

    # Handle different prediction formats (key provided by caller)
    if isinstance(predictions, pd.DataFrame):
        # CellAssign returns DataFrame with probabilities
        predicted_indices = predictions.values.argmax(axis=1)
        adata.obs[output_key] = [valid_cell_types[i] for i in predicted_indices]

        # Get confidence scores from probabilities DataFrame
        confidence_scores = {}
        for i, cell_type in enumerate(valid_cell_types):
            cells_of_type = adata.obs[output_key] == cell_type
            if np.sum(cells_of_type) > 0:
                # Use iloc with positional indices derived from the boolean mask
                cell_indices = np.where(cells_of_type)[0]
                mean_prob = predictions.iloc[cell_indices, i].mean()
                confidence_scores[cell_type] = round(float(mean_prob), 2)
            # else: No cells of this type - skip confidence
    else:
        # Other models return indices directly
        adata.obs[output_key] = [valid_cell_types[i] for i in predictions]
        # CellAssign returned indices, not probabilities - no confidence available
        confidence_scores = {}  # Empty dict indicates no confidence data

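    # Illustrative sketch (toy probabilities, not real model output): with
    #     predictions = pd.DataFrame([[0.9, 0.1], [0.2, 0.8]],
    #                                columns=["T cell", "B cell"])
    # argmax(axis=1) gives [0, 1], so cell 0 maps to "T cell" and cell 1 to
    # "B cell"; each type's confidence is then the mean probability over the
    # cells assigned to it (0.9 and 0.8 here).
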
    ensure_categorical(adata, output_key)

    # Store confidence if available
    if confidence_scores:
        confidence_array = [
            confidence_scores.get(ct, 0.0) for ct in adata.obs[output_key]
        ]
        adata.obs[confidence_key] = confidence_array

    # Get cell types and counts
    counts = adata.obs[output_key].value_counts().to_dict()

    return AnnotationMethodOutput(
        cell_types=valid_cell_types,
        counts=counts,
        confidence=confidence_scores,
    )


async def annotate_cell_types(
    data_id: str,
    ctx: ToolContext,
    params: AnnotationParameters,  # No default - must be provided by caller (LLM)
) -> AnnotationResult:
    """Annotate cell types in spatial transcriptomics data

    Args:
        data_id: Dataset ID
        ctx: Tool context for data access and logging
        params: Annotation parameters

    Returns:
        Annotation result
    """
    # Retrieve the AnnData object via ToolContext
    adata = await ctx.get_adata(data_id)

    # Validate method first - clean and simple
    if params.method not in SUPPORTED_METHODS:
        raise ParameterError(
            f"Unsupported method: {params.method}. Supported: {sorted(SUPPORTED_METHODS)}"
        )

    # Get reference data if needed for methods that require it
    reference_adata = None
    if params.method in ["tangram", "scanvi", "singler"] and params.reference_data_id:
        reference_adata = await ctx.get_adata(params.reference_data_id)

    # Generate output keys in ONE place (single-point control)
    output_key = f"cell_type_{params.method}"
    confidence_key = f"confidence_{params.method}"

    # Route to appropriate annotation method
    try:
        if params.method == "tangram":
            result = await _annotate_with_tangram(
                adata, params, ctx, output_key, confidence_key, reference_adata
            )
        elif params.method == "scanvi":
            result = await _annotate_with_scanvi(
                adata, params, ctx, output_key, confidence_key, reference_adata
            )
        elif params.method == "cellassign":
            result = await _annotate_with_cellassign(
                adata, params, ctx, output_key, confidence_key
            )
        elif params.method == "mllmcelltype":
            result = await _annotate_with_mllmcelltype(
                adata, params, ctx, output_key, confidence_key
            )
        elif params.method == "singler":
            result = await _annotate_with_singler(
                adata, params, ctx, output_key, confidence_key, reference_adata
            )
        else:  # sctype
            result = await _annotate_with_sctype(
                adata, params, ctx, output_key, confidence_key
            )

    except Exception as e:
        raise ProcessingError(f"Annotation failed: {e}") from e

    # Extract values from unified result type
    cell_types = result.cell_types
    counts = result.counts
    confidence_scores = result.confidence
    tangram_mapping_score = result.mapping_score

    # Determine if confidence_key should be reported (only if we have confidence data)
    confidence_key_for_result = confidence_key if confidence_scores else None

    # Store scientific metadata for reproducibility
    from ..utils.adata_utils import store_analysis_metadata

    # Extract results keys
    results_keys_dict = {"obs": [output_key], "obsm": [], "uns": []}
    if confidence_key_for_result:
        results_keys_dict["obs"].append(confidence_key)

    # Add method-specific result keys
    if params.method == "tangram":
        results_keys_dict["obsm"].extend(
            ["tangram_ct_pred", "tangram_gene_predictions"]
        )

    # Prepare parameters dict (only scientifically important ones)
    parameters_dict = {}
    if params.method == "tangram":
        parameters_dict = {
            "device": params.tangram_device,
            "n_epochs": params.num_epochs,  # Fixed: use num_epochs instead of tangram_num_epochs
            "learning_rate": params.tangram_learning_rate,
        }
    elif params.method == "scanvi":
        parameters_dict = {
            "n_latent": params.scanvi_n_latent,
            "n_hidden": params.scanvi_n_hidden,
            "dropout_rate": params.scanvi_dropout_rate,
            "use_scvi_pretrain": params.scanvi_use_scvi_pretrain,
        }
    elif params.method == "mllmcelltype":
        parameters_dict = {
            "n_marker_genes": params.mllm_n_marker_genes,
            "species": params.mllm_species,
            "provider": params.mllm_provider,
            "model": params.mllm_model,
            "use_consensus": params.mllm_use_consensus,
        }
    elif params.method == "sctype":
        parameters_dict = {
            "tissue": params.sctype_tissue,
            "scaled": params.sctype_scaled,
        }
    elif params.method == "singler":
        parameters_dict = {
            "fine_tune": params.singler_fine_tune,
        }

    # Prepare statistics dict
    statistics_dict = {"n_cell_types": len(cell_types)}
    if tangram_mapping_score is not None:
        statistics_dict["mapping_score"] = tangram_mapping_score

    # Prepare reference info if applicable
    reference_info_dict = None
    if params.method in ["tangram", "scanvi", "singler"] and params.reference_data_id:
        reference_info_dict = {"reference_data_id": params.reference_data_id}

    # Store metadata
    store_analysis_metadata(
        adata,
        analysis_name=f"annotation_{params.method}",
        method=params.method,
        parameters=parameters_dict,
        results_keys=results_keys_dict,
        statistics=statistics_dict,
        reference_info=reference_info_dict,
    )

    # Return result
    return AnnotationResult(
        data_id=data_id,
        method=params.method,
        output_key=output_key,
        confidence_key=confidence_key_for_result,
        cell_types=cell_types,
        counts=counts,
        confidence_scores=confidence_scores,
        tangram_mapping_score=tangram_mapping_score,
    )


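# Illustrative usage sketch (hypothetical dataset ID and field values; the
# real ToolContext comes from the MCP server layer, and the constructor call
# is assumed from the fields this module reads):
#
#     params = AnnotationParameters(
#         method="cellassign",
#         marker_genes={"T cell": ["CD3D", "CD3E"], "B cell": ["MS4A1"]},
#     )
#     result = await annotate_cell_types("my_dataset", ctx, params)
#     # result.output_key == "cell_type_cellassign"
#     # result.confidence_key == "confidence_cellassign" (when confidences exist)

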
# ============================================================================
# SC-TYPE IMPLEMENTATION
# ============================================================================

# Cache for sc-type results (memory only, no pickle)
_SCTYPE_CACHE: dict[str, Any] = {}
_SCTYPE_CACHE_DIR = Path.home() / ".chatspatial" / "sctype_cache"

# R code constants for sc-type (extracted for clarity)
_R_INSTALL_PACKAGES = """
required_packages <- c("dplyr", "openxlsx", "HGNChelper")
for (pkg in required_packages) {
    if (!require(pkg, character.only = TRUE, quietly = TRUE)) {
        install.packages(pkg, repos = "https://cran.r-project.org/", quiet = TRUE)
        if (!require(pkg, character.only = TRUE, quietly = TRUE)) {
            stop(paste("Failed to install R package:", pkg))
        }
    }
}
"""

_R_LOAD_SCTYPE = """
source("https://raw.githubusercontent.com/IanevskiAleksandr/sc-type/master/R/gene_sets_prepare.R")
source("https://raw.githubusercontent.com/IanevskiAleksandr/sc-type/master/R/sctype_score_.R")
"""

_R_SCTYPE_SCORING = """
# Set row/column names and convert to dense
rownames(scdata) <- gene_names
colnames(scdata) <- cell_names
if (inherits(scdata, 'sparseMatrix')) scdata <- as.matrix(scdata)

# Extract gene sets
gs_positive <- gs_list$gs_positive
gs_negative <- gs_list$gs_negative

if (length(gs_positive) == 0) stop("No valid positive gene sets found")

# Filter gene sets to genes present in data
available_genes <- rownames(scdata)
filtered_gs_positive <- list()
filtered_gs_negative <- list()

for (celltype in names(gs_positive)) {
    pos_genes <- gs_positive[[celltype]]
    neg_genes <- if (celltype %in% names(gs_negative)) gs_negative[[celltype]] else c()
    pos_overlap <- intersect(toupper(pos_genes), toupper(available_genes))
    if (length(pos_overlap) > 0) {
        filtered_gs_positive[[celltype]] <- pos_overlap
        filtered_gs_negative[[celltype]] <- intersect(toupper(neg_genes), toupper(available_genes))
    }
}

if (length(filtered_gs_positive) == 0) {
    stop("No valid cell type gene sets found after filtering.")
}

# Run sc-type scoring
es_max <- sctype_score(
    scRNAseqData = as.matrix(scdata),
    scaled = TRUE,
    gs = filtered_gs_positive,
    gs2 = filtered_gs_negative
)

if (is.null(es_max) || nrow(es_max) == 0) {
    stop("SC-Type scoring failed to produce results.")
}
"""

# Valid tissue types from sc-type database
SCTYPE_VALID_TISSUES = {
    "Adrenal",
    "Brain",
    "Eye",
    "Heart",
    "Hippocampus",
    "Immune system",
    "Intestine",
    "Kidney",
    "Liver",
    "Lung",
    "Muscle",
    "Pancreas",
    "Placenta",
    "Spleen",
    "Stomach",
    "Thymus",
}


def _get_sctype_cache_key(adata, params: AnnotationParameters) -> str:
    """Generate cache key for sc-type results"""
    # Create a hash based on data and parameters
    data_hash = hashlib.md5()

    # Hash expression data (sample first 1000 cells and 500 genes for efficiency)
    sample_slice = adata.X[: min(1000, adata.n_obs), : min(500, adata.n_vars)]
    sample_data = to_dense(sample_slice)
    data_hash.update(sample_data.tobytes())

    # Hash relevant parameters
    params_dict = {
        "tissue": params.sctype_tissue,
        "db": params.sctype_db_,
        "scaled": params.sctype_scaled,
        "custom_markers": params.sctype_custom_markers,
    }
    data_hash.update(str(params_dict).encode())

    return data_hash.hexdigest()


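# Note on the key above: md5's hexdigest() always yields a 32-character hex
# string, so cache files look like, e.g. (hypothetical digest):
#     ~/.chatspatial/sctype_cache/5e8ff9bf55ba3508199d22e984129be6.json

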
def _load_sctype_functions(ctx: "ToolContext") -> None:
    """Load sc-type R functions and auto-install R packages if needed."""
    robjects, _, _, _, _, default_converter, openrlib, _ = validate_r_environment(ctx)
    from rpy2.robjects import conversion

    with openrlib.rlock:
        with conversion.localconverter(default_converter):
            robjects.r(_R_INSTALL_PACKAGES)
            robjects.r(_R_LOAD_SCTYPE)


def _prepare_sctype_genesets(params: AnnotationParameters, ctx: "ToolContext"):
    """Prepare gene sets for sc-type."""
    if params.sctype_custom_markers:
        return _convert_custom_markers_to_gs(params.sctype_custom_markers, ctx)

    # Use sc-type database
    tissue = params.sctype_tissue
    if not tissue:
        raise ParameterError("sctype_tissue is required when not using custom markers")

    robjects, _, _, _, _, default_converter, openrlib, _ = validate_r_environment(ctx)
    from rpy2.robjects import conversion

    db_path = (
        params.sctype_db_
        or "https://raw.githubusercontent.com/IanevskiAleksandr/sc-type/master/ScTypeDB_full.xlsx"
    )

    with openrlib.rlock:
        with conversion.localconverter(default_converter):
            robjects.r.assign("db_path", db_path)
            robjects.r.assign("tissue_type", tissue)
            robjects.r("gs_list <- gene_sets_prepare(db_path, tissue_type)")
            return robjects.r["gs_list"]


def _convert_custom_markers_to_gs(
    custom_markers: dict[str, dict[str, list[str]]], ctx: "ToolContext"
):
    """Convert custom markers to sc-type gene set format"""
    if not custom_markers:
        raise DataError("Custom markers dictionary is empty")

    gs_positive = {}
    gs_negative = {}

    valid_celltypes = 0

    for cell_type, markers in custom_markers.items():
        if not isinstance(markers, dict):
            continue

        positive_genes = []
        negative_genes = []

        if "positive" in markers and isinstance(markers["positive"], list):
            positive_genes = [
                str(g).strip().upper()
                for g in markers["positive"]
                if g and str(g).strip()
            ]

        if "negative" in markers and isinstance(markers["negative"], list):
            negative_genes = [
                str(g).strip().upper()
                for g in markers["negative"]
                if g and str(g).strip()
            ]

        # Only include cell types that have at least some positive markers
        if positive_genes:
            gs_positive[cell_type] = positive_genes
            gs_negative[cell_type] = negative_genes  # Can be empty list
            valid_celltypes += 1

    if valid_celltypes == 0:
        raise DataError(
            "No valid cell types found in custom markers - all cell types need at least one positive marker"
        )

    # Get robjects and converters from validation
    robjects, pandas2ri, _, _, localconverter, default_converter, openrlib, _ = (
        validate_r_environment(ctx)
    )

    # Wrap R calls in conversion context (FIX for contextvars issue)
    with openrlib.rlock:
        with localconverter(robjects.default_converter + pandas2ri.converter):
            # Convert Python dictionaries to R named lists, handle empty lists properly
            r_gs_positive = robjects.r["list"](
                **{
                    k: robjects.StrVector(v) if v else robjects.StrVector([])
                    for k, v in gs_positive.items()
                }
            )
            r_gs_negative = robjects.r["list"](
                **{
                    k: robjects.StrVector(v) if v else robjects.StrVector([])
                    for k, v in gs_negative.items()
                }
            )

            # Create the final gs_list structure
            gs_list = robjects.r["list"](
                gs_positive=r_gs_positive, gs_negative=r_gs_negative
            )

            return gs_list


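# Illustrative input for _convert_custom_markers_to_gs (hypothetical markers):
#
#     custom = {"Neuron": {"positive": ["Rbfox3", "Map2"], "negative": ["Gfap"]}}
#
# Gene symbols are upper-cased during cleaning, so this becomes an R list with
# gs_positive$Neuron = c("RBFOX3", "MAP2") and gs_negative$Neuron = c("GFAP").

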
def _run_sctype_scoring(
    adata, gs_list, params: AnnotationParameters, ctx: "ToolContext"
) -> pd.DataFrame:
    """Run sc-type scoring algorithm."""
    robjects, pandas2ri, numpy2ri, _, _, default_converter, openrlib, anndata2ri = (
        validate_r_environment(ctx)
    )
    from rpy2.robjects import conversion

    # Prepare expression data
    expr_data = (
        adata.layers["scaled"]
        if params.sctype_scaled and "scaled" in adata.layers
        else adata.X
    )

    with openrlib.rlock:
        with conversion.localconverter(
            default_converter
            + anndata2ri.converter
            + pandas2ri.converter
            + numpy2ri.converter
        ):
            # Transfer data to R (genes × cells for scType)
            robjects.r.assign("scdata", expr_data.T)
            robjects.r.assign("gene_names", list(adata.var_names))
            robjects.r.assign("cell_names", list(adata.obs_names))
            robjects.r.assign("gs_list", gs_list)

            # Run scoring using pre-defined R code
            robjects.r(_R_SCTYPE_SCORING)

            # Get results
            row_names = list(robjects.r("rownames(es_max)"))
            col_names = list(robjects.r("colnames(es_max)"))
            scores_matrix = robjects.r["es_max"]

    # Convert to DataFrame
    if isinstance(scores_matrix, pd.DataFrame):
        scores_df = scores_matrix
        scores_df.index = row_names if row_names else scores_df.index
        scores_df.columns = col_names if col_names else scores_df.columns
    else:
        scores_df = pd.DataFrame(scores_matrix, index=row_names, columns=col_names)

    return scores_df


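# Orientation note with a toy example: es_max (and thus scores_df) carries
# cell types as rows and cells as columns, e.g. (made-up scores):
#
#               cell_1  cell_2
#     T cell       2.0    -0.5
#     B cell       0.3     1.8
#
# _assign_sctype_celltypes below therefore iterates over columns (cells) and
# takes idxmax over rows (cell types).

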
def _softmax(scores_array: np.ndarray) -> np.ndarray:
    """Compute softmax probabilities from raw scores (numerically stable)."""
    shifted = scores_array - np.max(scores_array)
    exp_scores = np.exp(shifted)
    return exp_scores / np.sum(exp_scores)


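# Worked example (toy scores): _softmax(np.array([2.0, 1.0, 0.0])) is
# approximately [0.665, 0.245, 0.090]. Subtracting the maximum changes
# nothing mathematically but prevents overflow for large score values.

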
def _assign_sctype_celltypes(
    scores_df: pd.DataFrame, ctx: "ToolContext"
) -> tuple[list[str], list[float]]:
    """Assign cell types based on sc-type scores using softmax confidence."""
    if scores_df is None or scores_df.empty:
        raise DataError("Scores DataFrame is empty or None")

    cell_types = []
    confidence_scores = []

    for col_name in scores_df.columns:
        cell_scores = scores_df[col_name]
        max_idx = cell_scores.idxmax()
        max_score = cell_scores.loc[max_idx]

        if max_score > 0:
            cell_types.append(str(max_idx))
            # Softmax gives statistically meaningful confidence
            softmax_probs = _softmax(cell_scores.values)
            confidence_scores.append(
                float(softmax_probs[cell_scores.index.get_loc(max_idx)])
            )
        else:
            cell_types.append("Unknown")
            confidence_scores.append(0.0)

    return cell_types, confidence_scores


def _calculate_sctype_stats(cell_types: list[str]) -> dict[str, int]:
    """Calculate cell type counts."""
    from collections import Counter

    return dict(Counter(cell_types))


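# Example: _calculate_sctype_stats(["T cell", "T cell", "B cell"])
# returns {"T cell": 2, "B cell": 1}.

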
async def _cache_sctype_results(
    cache_key: str, results: tuple, ctx: "ToolContext"
) -> None:
    """Cache sc-type results to disk as JSON (secure, no pickle)."""
    try:
        _SCTYPE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
        cache_file = _SCTYPE_CACHE_DIR / f"{cache_key}.json"

        # Convert tuple to serializable dict
        cell_types, counts, confidence_by_celltype, mapping_score = results
        cache_data = {
            "cell_types": cell_types,
            "counts": counts,
            "confidence_by_celltype": confidence_by_celltype,
            "mapping_score": mapping_score,
        }

        with open(cache_file, "w", encoding="utf-8") as f:
            json.dump(cache_data, f)

        _SCTYPE_CACHE[cache_key] = results
    except Exception as e:
        await ctx.warning(f"Failed to cache results: {e}")


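# Illustrative cache file contents (toy values; "cell_types" holds one label
# per cell so the annotation column can be restored from cache):
#     {"cell_types": ["T cell", "T cell", "B cell"],
#      "counts": {"T cell": 2, "B cell": 1},
#      "confidence_by_celltype": {"T cell": 0.71, "B cell": 0.64},
#      "mapping_score": null}

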
def _load_cached_sctype_results(cache_key: str, ctx: "ToolContext") -> Optional[tuple]:
    """Load cached sc-type results from memory or JSON file."""
    # Check memory cache first
    if cache_key in _SCTYPE_CACHE:
        return _SCTYPE_CACHE[cache_key]

    # Check disk cache (JSON)
    cache_file = _SCTYPE_CACHE_DIR / f"{cache_key}.json"
    if cache_file.exists():
        try:
            with open(cache_file, "r", encoding="utf-8") as f:
                cache_data = json.load(f)

            results = (
                cache_data["cell_types"],
                cache_data["counts"],
                cache_data["confidence_by_celltype"],
                cache_data.get("mapping_score"),
            )
            _SCTYPE_CACHE[cache_key] = results
            return results
        except Exception:
            # Cache corrupted or incompatible, will recompute
            pass

    return None


async def _annotate_with_sctype(
    adata: sc.AnnData,
    params: AnnotationParameters,
    ctx: "ToolContext",
    output_key: str,
    confidence_key: str,
) -> AnnotationMethodOutput:
    """Annotate cell types using sc-type method."""
    # Validate R environment
    validate_r_environment(ctx)

    # Validate parameters
    if not params.sctype_tissue and not params.sctype_custom_markers:
        raise ParameterError(
            "Either sctype_tissue or sctype_custom_markers must be specified"
        )

    if params.sctype_tissue and params.sctype_tissue not in SCTYPE_VALID_TISSUES:
        raise ParameterError(
            f"Tissue '{params.sctype_tissue}' not supported. "
            f"Valid: {', '.join(sorted(SCTYPE_VALID_TISSUES))}"
        )

    # Check cache
    cache_key = None
    if params.sctype_use_cache:
        cache_key = _get_sctype_cache_key(adata, params)
        cached = _load_cached_sctype_results(cache_key, ctx)
        if cached:
            # Convert cached tuple to AnnotationMethodOutput. The cache stores
            # one label per cell so the annotation column can be restored
            # without rerunning sc-type.
            per_cell_labels, counts, confidence, _ = cached
            # Still need to store in adata.obs when using cache
            adata.obs[output_key] = pd.Categorical(per_cell_labels)
            return AnnotationMethodOutput(
                cell_types=sorted(set(per_cell_labels)),
                counts=counts,
                confidence=confidence,
            )

    # Run sc-type pipeline
    _load_sctype_functions(ctx)
    gs_list = _prepare_sctype_genesets(params, ctx)
    scores_df = _run_sctype_scoring(adata, gs_list, params, ctx)
    per_cell_types, per_cell_confidence = _assign_sctype_celltypes(scores_df, ctx)

    # Calculate statistics
    counts = _calculate_sctype_stats(per_cell_types)

    # Average confidence per cell type (for return value)
    confidence_by_celltype = {}
    for ct in set(per_cell_types):
        ct_confs = [
            c for i, c in enumerate(per_cell_confidence) if per_cell_types[i] == ct
        ]
        confidence_by_celltype[ct] = sum(ct_confs) / len(ct_confs) if ct_confs else 0.0

    # Store in adata.obs (keys provided by caller)
    adata.obs[output_key] = pd.Categorical(per_cell_types)
    adata.obs[confidence_key] = per_cell_confidence

    unique_cell_types = list(set(per_cell_types))

    # Cache results (as tuple for compatibility). Cache the per-cell labels,
    # not the deduplicated list: the cached branch above assigns them back to
    # adata.obs, which requires one label per observation.
    if params.sctype_use_cache and cache_key:
        cache_tuple = (per_cell_types, counts, confidence_by_celltype, None)
        await _cache_sctype_results(cache_key, cache_tuple, ctx)

    return AnnotationMethodOutput(
        cell_types=unique_cell_types,
        counts=counts,
        confidence=confidence_by_celltype,
    )