chatspatial 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. chatspatial/__init__.py +11 -0
  2. chatspatial/__main__.py +141 -0
  3. chatspatial/cli/__init__.py +7 -0
  4. chatspatial/config.py +53 -0
  5. chatspatial/models/__init__.py +85 -0
  6. chatspatial/models/analysis.py +513 -0
  7. chatspatial/models/data.py +2462 -0
  8. chatspatial/server.py +1763 -0
  9. chatspatial/spatial_mcp_adapter.py +720 -0
  10. chatspatial/tools/__init__.py +3 -0
  11. chatspatial/tools/annotation.py +1903 -0
  12. chatspatial/tools/cell_communication.py +1603 -0
  13. chatspatial/tools/cnv_analysis.py +605 -0
  14. chatspatial/tools/condition_comparison.py +595 -0
  15. chatspatial/tools/deconvolution/__init__.py +402 -0
  16. chatspatial/tools/deconvolution/base.py +318 -0
  17. chatspatial/tools/deconvolution/card.py +244 -0
  18. chatspatial/tools/deconvolution/cell2location.py +326 -0
  19. chatspatial/tools/deconvolution/destvi.py +144 -0
  20. chatspatial/tools/deconvolution/flashdeconv.py +101 -0
  21. chatspatial/tools/deconvolution/rctd.py +317 -0
  22. chatspatial/tools/deconvolution/spotlight.py +216 -0
  23. chatspatial/tools/deconvolution/stereoscope.py +109 -0
  24. chatspatial/tools/deconvolution/tangram.py +135 -0
  25. chatspatial/tools/differential.py +625 -0
  26. chatspatial/tools/embeddings.py +298 -0
  27. chatspatial/tools/enrichment.py +1863 -0
  28. chatspatial/tools/integration.py +807 -0
  29. chatspatial/tools/preprocessing.py +723 -0
  30. chatspatial/tools/spatial_domains.py +808 -0
  31. chatspatial/tools/spatial_genes.py +836 -0
  32. chatspatial/tools/spatial_registration.py +441 -0
  33. chatspatial/tools/spatial_statistics.py +1476 -0
  34. chatspatial/tools/trajectory.py +495 -0
  35. chatspatial/tools/velocity.py +405 -0
  36. chatspatial/tools/visualization/__init__.py +155 -0
  37. chatspatial/tools/visualization/basic.py +393 -0
  38. chatspatial/tools/visualization/cell_comm.py +699 -0
  39. chatspatial/tools/visualization/cnv.py +320 -0
  40. chatspatial/tools/visualization/core.py +684 -0
  41. chatspatial/tools/visualization/deconvolution.py +852 -0
  42. chatspatial/tools/visualization/enrichment.py +660 -0
  43. chatspatial/tools/visualization/integration.py +205 -0
  44. chatspatial/tools/visualization/main.py +164 -0
  45. chatspatial/tools/visualization/multi_gene.py +739 -0
  46. chatspatial/tools/visualization/persistence.py +335 -0
  47. chatspatial/tools/visualization/spatial_stats.py +469 -0
  48. chatspatial/tools/visualization/trajectory.py +639 -0
  49. chatspatial/tools/visualization/velocity.py +411 -0
  50. chatspatial/utils/__init__.py +115 -0
  51. chatspatial/utils/adata_utils.py +1372 -0
  52. chatspatial/utils/compute.py +327 -0
  53. chatspatial/utils/data_loader.py +499 -0
  54. chatspatial/utils/dependency_manager.py +462 -0
  55. chatspatial/utils/device_utils.py +165 -0
  56. chatspatial/utils/exceptions.py +185 -0
  57. chatspatial/utils/image_utils.py +267 -0
  58. chatspatial/utils/mcp_utils.py +137 -0
  59. chatspatial/utils/path_utils.py +243 -0
  60. chatspatial/utils/persistence.py +78 -0
  61. chatspatial/utils/scipy_compat.py +143 -0
  62. chatspatial-1.1.0.dist-info/METADATA +242 -0
  63. chatspatial-1.1.0.dist-info/RECORD +67 -0
  64. chatspatial-1.1.0.dist-info/WHEEL +5 -0
  65. chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
  66. chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
  67. chatspatial-1.1.0.dist-info/top_level.txt +1 -0
chatspatial/tools/annotation.py
@@ -0,0 +1,1903 @@
+ """
+ Cell type annotation tools for spatial transcriptomics data.
+ """
+
+ from __future__ import annotations
+
+ import hashlib
+ import json
+ from pathlib import Path
+ from typing import TYPE_CHECKING, Any, NamedTuple, Optional
+
+ import numpy as np
+ import pandas as pd
+ import scanpy as sc
+
+ if TYPE_CHECKING:
+     from ..spatial_mcp_adapter import ToolContext
+
+ from ..models.analysis import AnnotationResult
+ from ..models.data import AnnotationParameters
+ from ..utils.adata_utils import (
+     ensure_categorical,
+     ensure_counts_layer,
+     ensure_unique_var_names_async,
+     find_common_genes,
+     get_cell_type_key,
+     get_cluster_key,
+     get_spatial_key,
+     to_dense,
+     validate_obs_column,
+ )
+ from ..utils.dependency_manager import (
+     is_available,
+     require,
+     validate_r_environment,
+     validate_scvi_tools,
+ )
+ from ..utils.device_utils import cuda_available
+ from ..utils.exceptions import (
+     DataError,
+     DataNotFoundError,
+     ParameterError,
+     ProcessingError,
+ )
+
+
+ class AnnotationMethodOutput(NamedTuple):
+     """Unified output from all annotation methods.
+
+     This provides a consistent return type across all annotation methods,
+     improving code clarity and preventing positional argument confusion.
+
+     Attributes:
+         cell_types: List of unique cell type names identified (deduplicated)
+         counts: Mapping of cell type names to number of cells assigned
+         confidence: Mapping of cell type names to confidence scores.
+             Empty dict indicates no confidence data available.
+         mapping_score: Optional method-specific quality score (e.g., Tangram mapping score)
+     """
+
+     cell_types: list[str]
+     counts: dict[str, int]
+     confidence: dict[str, float]
+     mapping_score: Optional[float] = None
+
+
+ # Supported annotation methods
+ # Confidence behavior by method:
+ #   - singler/tangram/sctype: Real confidence scores (correlation/probability/scoring)
+ #   - scanvi/cellassign: Partial confidence (when soft prediction available)
+ #   - mllmcelltype: No numeric confidence (LLM-based)
+ SUPPORTED_METHODS = {
+     "tangram",
+     "scanvi",
+     "cellassign",
+     "mllmcelltype",
+     "sctype",
+     "singler",
+ }
+
+
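For orientation, a minimal sketch of how a caller might consume this NamedTuple; the class body mirrors the definition above, and the literal values are made up for illustration:

```python
# Hypothetical illustration of the unified return type defined above.
from typing import NamedTuple, Optional

class AnnotationMethodOutput(NamedTuple):
    cell_types: list[str]
    counts: dict[str, int]
    confidence: dict[str, float]
    mapping_score: Optional[float] = None

result = AnnotationMethodOutput(
    cell_types=["B cell", "T cell"],
    counts={"B cell": 120, "T cell": 340},
    confidence={"B cell": 0.91, "T cell": 0.87},  # empty dict => no confidence data
)

# Field access by name avoids the positional-argument confusion the docstring mentions:
for ct in result.cell_types:
    print(ct, result.counts[ct], result.confidence.get(ct, "n/a"))
print(result.mapping_score)  # None unless a method (e.g., Tangram) supplies one
```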
+ async def _annotate_with_singler(
+     adata,
+     params: AnnotationParameters,
+     ctx: "ToolContext",
+     output_key: str,
+     confidence_key: str,
+     reference_adata: Optional[Any] = None,
+ ) -> AnnotationMethodOutput:
+     """Annotate cell types using SingleR reference-based method"""
+     # Validate and import dependencies
+     require("singler", ctx, feature="SingleR annotation")
+     require("singlecellexperiment", ctx, feature="SingleR annotation")
+     import singler
+
+     # Optional: check for celldex
+     celldex = None
+     if is_available("celldex"):
+         import celldex
+
+     # Get expression matrix - prefer normalized data
+     # IMPORTANT: Ensure test_mat dimensions match adata.var_names (used in test_features)
+     if "X_normalized" in adata.layers:
+         test_mat = adata.layers["X_normalized"]
+     elif adata.X is not None:
+         test_mat = adata.X
+     else:
+         # Fallback: use raw data, but subset to current var_names to ensure dimension match
+         # Note: adata.raw may have full genes while adata has HVG subset
+         if adata.raw is not None:
+             test_mat = adata.raw[:, adata.var_names].X
+         else:
+             test_mat = adata.X
+
+     # MEMORY OPTIMIZATION: SingleR (singler-py) natively supports sparse matrices.
+     # No toarray() needed - both np.log1p() and .T work with sparse matrices.
+     # Verified: sparse and dense inputs produce identical results.
+     # Memory savings: ~1.3 GB for a typical 10K cells x 20K genes dataset.
+
+     # Ensure log-normalization (SingleR expects log-normalized data)
+     if "log1p" not in adata.uns:
+         await ctx.warning(
+             "Data may not be log-normalized. Applying log1p for SingleR..."
+         )
+         test_mat = np.log1p(test_mat)
+
+     # Transpose for SingleR (genes x cells)
+     test_mat = test_mat.T
+
+     # Ensure gene names are strings
+     test_features = [str(x) for x in adata.var_names]
+
+     # Prepare reference
+     reference_name = getattr(params, "singler_reference", None)
+     reference_data_id = getattr(params, "reference_data_id", None)
+
+     ref_data = None
+     ref_labels = None
+     ref_features_to_use = None  # Only set when using custom reference (not celldex)
+
+     # Priority: reference_name > reference_data_id > default
+     if reference_name and celldex:
+         ref = celldex.fetch_reference(reference_name, "2024-02-26", realize_assays=True)
+         # Get labels
+         for label_col in ["label.main", "label.fine", "cell_type"]:
+             try:
+                 ref_labels = ref.get_column_data().column(label_col)
+                 break
+             except Exception:
+                 continue  # Try next label column
+         if ref_labels is None:
+             raise DataNotFoundError(
+                 f"Could not find labels in reference {reference_name}"
+             )
+         ref_data = ref
+
+     elif reference_data_id and reference_adata is not None:
+         # Use provided reference data (passed from main function via ctx.get_adata())
+         # Handle duplicate gene names
+         await ensure_unique_var_names_async(reference_adata, ctx, "reference data")
+         if await ensure_unique_var_names_async(adata, ctx, "query data") > 0:
+             # Update test_features after fixing
+             test_features = [str(x) for x in adata.var_names]
+
+         # Get reference expression matrix
+         if "X_normalized" in reference_adata.layers:
+             ref_mat = reference_adata.layers["X_normalized"]
+         else:
+             ref_mat = reference_adata.X
+
+         # MEMORY OPTIMIZATION: the reference matrix is likewise kept sparse
+         # (same rationale and savings as for the query matrix above).
+
+         # Ensure log-normalization for reference
+         if "log1p" not in reference_adata.uns:
+             await ctx.warning(
+                 "Reference data may not be log-normalized. Applying log1p..."
+             )
+             ref_mat = np.log1p(ref_mat)
+
+         # Transpose for SingleR (genes x cells)
+         ref_mat = ref_mat.T
+         ref_features = [str(x) for x in reference_adata.var_names]
+
+         # Check gene overlap
+         common_genes = find_common_genes(test_features, ref_features)
+
+         if len(common_genes) < min(50, len(test_features) * 0.1):
+             raise DataError(
+                 f"Insufficient gene overlap for SingleR: only {len(common_genes)} common genes "
+                 f"(test: {len(test_features)}, reference: {len(ref_features)})"
+             )
+
+         # Get labels from the reference; cell_type_key is required (no default value)
+         cell_type_key = params.cell_type_key
+
+         validate_obs_column(reference_adata, cell_type_key, "Cell type")
+
+         ref_labels = list(reference_adata.obs[cell_type_key])
+
+         # For SingleR, pass the actual expression matrix directly (not an SCE);
+         # this has been shown to work better in testing.
+         ref_data = ref_mat
+         ref_features_to_use = ref_features  # Keep reference features for gene matching
+
+     elif celldex:
+         # Use default reference
+         ref = celldex.fetch_reference(
+             "blueprint_encode", "2024-02-26", realize_assays=True
+         )
+         ref_labels = ref.get_column_data().column("label.main")
+         ref_data = ref
+     else:
+         raise DataNotFoundError(
+             "No reference data. Provide reference_data_id or singler_reference."
+         )
+
+     # Run SingleR annotation
+     use_integrated = getattr(params, "singler_integrated", False)
+     num_threads = getattr(params, "num_threads", 4)
+
+     if use_integrated and isinstance(ref_data, list):
+         single_results, integrated = singler.annotate_integrated(
+             test_mat,
+             ref_data=ref_data,
+             ref_labels=ref_labels,
+             test_features=test_features,
+             num_threads=num_threads,
+         )
+         best_labels = integrated.column("best_label")
+         scores = integrated.column("scores")
+     else:
+         # Build kwargs for annotate_single
+         annotate_kwargs = {
+             "test_data": test_mat,
+             "test_features": test_features,
+             "ref_data": ref_data,
+             "ref_labels": ref_labels,
+             "num_threads": num_threads,
+         }
+
+         # Add ref_features if we're using custom reference data (not celldex)
+         if ref_features_to_use is not None:
+             annotate_kwargs["ref_features"] = ref_features_to_use
+
+         results = singler.annotate_single(**annotate_kwargs)
+         best_labels = results.column("best")
+         scores = results.column("scores")
+
+     # Try to get delta scores for confidence (higher delta = higher confidence)
+     try:
+         delta_scores = results.column("delta")
+         if delta_scores:
+             low_delta = sum(1 for d in delta_scores if d and d < 0.05)
+             if low_delta > len(delta_scores) * 0.3:
+                 await ctx.warning(
+                     f"{low_delta}/{len(delta_scores)} cells have low confidence scores (delta < 0.05)"
+                 )
+     except Exception:
+         delta_scores = None
+
+     # Process results
+     cell_types = list(best_labels)
+     unique_types = list(set(cell_types))
+     counts = pd.Series(cell_types).value_counts().to_dict()
+
+     # Calculate confidence scores - prefer delta scores if available.
+     # IMPORTANT: different scores have different mathematical semantics:
+     #   - delta scores: gap between best and second-best match, range [0, +inf)
+     #   - correlation scores: Pearson correlation, range [-1, 1]
+     # We apply scientifically appropriate transformations to [0, 1].
+     confidence_scores = {}
+
+     # First try to use delta scores (a more meaningful confidence measure);
+     # delta_scores is always defined by the try/except block above.
+     if delta_scores is not None:
+         try:
+             for cell_type in unique_types:
+                 type_indices = [i for i, ct in enumerate(cell_types) if ct == cell_type]
+                 if type_indices:
+                     type_deltas = [
+                         delta_scores[i] for i in type_indices if i < len(delta_scores)
+                     ]
+                     if type_deltas:
+                         avg_delta = np.mean([d for d in type_deltas if d is not None])
+                         # Transform delta to [0, 1] using a saturating function:
+                         #   delta=0    -> 0 (no discrimination = zero confidence)
+                         #   delta->inf -> 1 (perfect discrimination = full confidence)
+                         confidence = 1.0 - np.exp(-avg_delta)
+                         confidence_scores[cell_type] = round(float(confidence), 3)
+         except Exception:
+             # Delta score extraction failed; fall back to regular scores
+             pass
+
+     # Fall back to correlation scores if delta not available
+     if not confidence_scores and scores is not None:
+         try:
+             scores_df = pd.DataFrame(scores.to_dict())
+         except AttributeError:
+             scores_df = pd.DataFrame(
+                 scores.to_numpy() if hasattr(scores, "to_numpy") else scores
+             )
+
+         for cell_type in unique_types:
+             mask = [ct == cell_type for ct in cell_types]
+             if cell_type in scores_df.columns and any(mask):
+                 type_scores = scores_df.loc[mask, cell_type]
+                 avg_score = type_scores.mean()
+                 # Use max(0, r) instead of (r+1)/2 for correlation:
+                 #   r<0 (negative correlation) -> 0 (opposite pattern = not a match)
+                 #   r=0 -> 0 (no correlation = zero confidence)
+                 #   r=1 -> 1 (perfect correlation = full confidence)
+                 confidence = max(0.0, float(avg_score))
+                 confidence_scores[cell_type] = round(confidence, 3)
+             # else: cell type won't have a confidence score (no action needed)
+
+     # Add to AnnData (keys provided by caller for single-point control)
+     adata.obs[output_key] = cell_types
+     ensure_categorical(adata, output_key)
+
+     # Only add a confidence column if we have real confidence values
+     if confidence_scores:
+         # Use 0.0 for cells without confidence (more honest than an arbitrary 0.5)
+         confidence_array = [confidence_scores.get(ct, 0.0) for ct in cell_types]
+         adata.obs[confidence_key] = confidence_array
+
+     return AnnotationMethodOutput(
+         cell_types=unique_types,
+         counts=counts,
+         confidence=confidence_scores,
+     )
+
+
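The delta-to-confidence transform used above is a simple saturating exponential; a standalone sketch of its behavior (the sample deltas are illustrative):

```python
import numpy as np

def delta_to_confidence(avg_delta: float) -> float:
    # Saturating map [0, inf) -> [0, 1): 1 - exp(-delta).
    # delta = 0 (best and runner-up tie) -> confidence 0;
    # large delta (clear winner) -> confidence approaches 1.
    return 1.0 - float(np.exp(-avg_delta))

for d in [0.0, 0.05, 0.2, 1.0, 3.0]:
    print(f"delta={d:<5} confidence={delta_to_confidence(d):.3f}")
# delta=0.05 (the low-confidence warning threshold above) maps to ~0.049
```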
+ async def _annotate_with_tangram(
+     adata,
+     params: AnnotationParameters,
+     ctx: "ToolContext",
+     output_key: str,
+     confidence_key: str,
+     reference_adata: Optional[Any] = None,
+ ) -> AnnotationMethodOutput:
+     """Annotate cell types using Tangram method"""
+     # Validate dependencies with comprehensive error reporting
+     require("tangram", ctx, feature="Tangram annotation")
+     import tangram as tg
+
+     # Check if reference data is provided
+     if reference_adata is None:
+         raise ParameterError("Tangram requires reference_data_id parameter.")
+
+     # Use reference single-cell data (passed from main function via ctx.get_adata())
+     adata_sc_original = reference_adata
+
+     # ===== CRITICAL FIX: Use raw data for Tangram to preserve gene name case =====
+     # Issue: Preprocessed data may have lowercase gene names, while reference has uppercase.
+     # This causes 0 overlapping genes and Tangram mapping failure (all NaN results).
+     # Solution: Use adata.raw which preserves original gene names.
+     if adata.raw is not None:
+         adata_sp = adata.raw.to_adata()
+         # Preserve spatial coordinates from preprocessed data
+         spatial_key = get_spatial_key(adata)
+         if spatial_key:
+             adata_sp.obsm[spatial_key] = adata.obsm[spatial_key].copy()
+     else:
+         adata_sp = adata
+         await ctx.warning(
+             "No raw data available - using preprocessed data (may have gene name mismatches)"
+         )
+     # =============================================================================
+
+     # Handle duplicate gene names
+     await ensure_unique_var_names_async(adata_sc_original, ctx, "reference data")
+     await ensure_unique_var_names_async(adata_sp, ctx, "spatial data")
+
+     # Determine training genes
+     training_genes = params.training_genes
+
+     if training_genes is None:
+         # Use marker genes if available
+         if params.marker_genes:
+             # Flatten marker genes dictionary
+             training_genes = []
+             for genes in params.marker_genes.values():
+                 training_genes.extend(genes)
+             training_genes = list(set(training_genes))  # Remove duplicates
+         else:
+             # Use highly variable genes
+             if "highly_variable" not in adata_sc_original.var:
+                 raise DataNotFoundError(
+                     "HVGs not found in reference data. Run preprocessing first."
+                 )
+             training_genes = list(
+                 adata_sc_original.var_names[adata_sc_original.var.highly_variable]
+             )
+
+     # COW FIX: Create copy of reference data to avoid modifying original.
+     # Tangram's pp_adatas adds metadata (uns, obs) but doesn't subset genes.
+     adata_sc = adata_sc_original.copy()
+
+     # Preprocess data for Tangram
+     tg.pp_adatas(adata_sc, adata_sp, genes=training_genes)
+
+     # Set mapping mode
+     mode = params.tangram_mode
+     cluster_label = params.cluster_label
+
+     if mode == "clusters" and cluster_label is None:
+         await ctx.warning(
+             "Cluster label not provided for 'clusters' mode. Using default cell type annotation if available."
+         )
+         # Try to find a cell type or cluster annotation in the reference data
+         cluster_label = get_cell_type_key(adata_sc) or get_cluster_key(adata_sc)
+
+         if cluster_label is None:
+             raise ParameterError(
+                 "No cluster label found. Provide cluster_label parameter."
+             )
+
+     # Check GPU availability for device selection
+     device = params.tangram_device
+     if device != "cpu" and not cuda_available():
+         await ctx.warning("GPU requested but not available - falling back to CPU")
+         device = "cpu"
+
+     # Run Tangram mapping with enhanced parameters
+     mapping_args = {
+         "mode": mode,
+         "num_epochs": params.num_epochs,
+         "device": device,
+         "density_prior": params.tangram_density_prior,  # Add density prior
+         "learning_rate": params.tangram_learning_rate,  # Add learning rate
+     }
+
+     # Add optional regularization parameters
+     if params.tangram_lambda_r is not None:
+         mapping_args["lambda_r"] = params.tangram_lambda_r
+
+     if params.tangram_lambda_neighborhood is not None:
+         mapping_args["lambda_neighborhood"] = params.tangram_lambda_neighborhood
+
+     if mode == "clusters":
+         mapping_args["cluster_label"] = cluster_label
+
+     ad_map = tg.map_cells_to_space(adata_sc, adata_sp, **mapping_args)
+
+     # Get mapping score from training history
+     tangram_mapping_score = 0.0  # Default score
+     try:
+         if "training_history" in ad_map.uns:
+             history = ad_map.uns["training_history"]
+
+             # Extract score from main_loss (which is actually a similarity score, higher is better)
+             if (
+                 isinstance(history, dict)
+                 and "main_loss" in history
+                 and len(history["main_loss"]) > 0
+             ):
+                 import re
+
+                 last_value = history["main_loss"][-1]
+
+                 # Extract value from tensor string if needed
+                 if isinstance(last_value, str):
+                     # Handle tensor string format: 'tensor(0.9050, grad_fn=...)'
+                     match = re.search(r"tensor\(([-\d.]+)", last_value)
+                     if match:
+                         tangram_mapping_score = float(match.group(1))
+                     else:
+                         # Try direct conversion
+                         try:
+                             tangram_mapping_score = float(last_value)
+                         except Exception:
+                             tangram_mapping_score = 0.0
+                 else:
+                     tangram_mapping_score = float(last_value)
+
+             else:
+                 error_msg = (
+                     f"Tangram history format not recognized: {type(history).__name__}. "
+                     f"Upgrade tangram-sc: pip install --upgrade tangram-sc"
+                 )
+                 raise ProcessingError(error_msg)
+     except Exception as score_error:
+         raise ProcessingError(
+             f"Tangram mapping completed but score extraction failed: {score_error}"
+         ) from score_error
+
+     # Compute validation metrics if requested
+     if params.tangram_compute_validation:
+         try:
+             scores = tg.compare_spatial_geneexp(ad_map, adata_sp)
+             adata_sp.uns["tangram_validation_scores"] = scores
+         except Exception as val_error:
+             await ctx.warning(f"Could not compute validation metrics: {val_error}")
+
+     # Project genes if requested
+     if params.tangram_project_genes:
+         try:
+             ad_ge = tg.project_genes(ad_map, adata_sc)
+             adata_sp.obsm["tangram_gene_predictions"] = ad_ge.X
+         except Exception as gene_error:
+             await ctx.warning(f"Could not project genes: {gene_error}")
+
+     # Project cell annotations to space using proper API function
+     try:
+         # Determine annotation column
+         annotation_col = None
+         if mode == "clusters" and cluster_label:
+             annotation_col = cluster_label
+         else:
+             # cell_type_key is now required (no auto-detect)
+             if params.cell_type_key not in adata_sc.obs:
+                 # Improved error message showing available columns
+                 available_cols = list(adata_sc.obs.columns)
+                 categorical_cols = [
+                     col
+                     for col in available_cols
+                     if adata_sc.obs[col].dtype.name in ["object", "category"]
+                 ]
+
+                 raise ParameterError(
+                     f"Cell type column '{params.cell_type_key}' not found. "
+                     f"Available: {categorical_cols[:5]}"
+                 )
+
+             annotation_col = params.cell_type_key
+
+         # annotation_col is guaranteed to be set (either from cluster_label or cell_type_key)
+         tg.project_cell_annotations(ad_map, adata_sp, annotation=annotation_col)
+     except Exception as proj_error:
+         await ctx.warning(f"Could not project cell annotations: {proj_error}")
+         # Continue without projection
+
+     # Get cell type predictions (keys provided by caller for single-point control)
+     cell_types = []
+     counts = {}
+     confidence_scores = {}
+
+     if "tangram_ct_pred" in adata_sp.obsm:
+         cell_type_df = adata_sp.obsm["tangram_ct_pred"]
+
+         # Get cell types and counts
+         cell_types = list(cell_type_df.columns)
+
+         # ===== CRITICAL FIX: Row normalization for proper probability calculation =====
+         # tangram_ct_pred contains unnormalized density/abundance values, NOT probabilities:
+         # row sums can be != 1.0 and values can exceed 1.0.
+         # We normalize to convert densities -> probability distributions.
+         cell_type_prob = cell_type_df.div(cell_type_df.sum(axis=1), axis=0)
+
+         # Validation: Ensure normalized values are valid probabilities
+         if not (cell_type_prob.values >= 0).all():
+             await ctx.warning(
+                 "Some normalized probabilities are negative - data quality issue"
+             )
+         if not (cell_type_prob.values <= 1.0).all():
+             await ctx.warning(
+                 "Some normalized probabilities exceed 1.0 - normalization failed"
+             )
+         if not np.allclose(cell_type_prob.sum(axis=1), 1.0):
+             await ctx.warning(
+                 "Row sums don't equal 1.0 after normalization - numerical issue"
+             )
+
+         # Assign cell type based on highest probability (argmax is same before/after normalization)
+         adata_sp.obs[output_key] = cell_type_prob.idxmax(axis=1)
+         ensure_categorical(adata_sp, output_key)
+
+         # Get counts
+         counts = adata_sp.obs[output_key].value_counts().to_dict()
+
+         # Calculate confidence scores from NORMALIZED probabilities
+         confidence_scores = {}
+         for cell_type in cell_types:
+             cells_of_type = adata_sp.obs[output_key] == cell_type
+             if np.sum(cells_of_type) > 0:
+                 # Use mean PROBABILITY as confidence (now guaranteed to be in [0, 1])
+                 mean_prob = cell_type_prob.loc[cells_of_type, cell_type].mean()
+                 confidence_scores[cell_type] = round(float(mean_prob), 3)
+
+     else:
+         await ctx.warning("No cell type predictions found in Tangram results")
+
+     # Validate results before returning
+     if not cell_types:
+         raise ProcessingError(
+             "Tangram mapping failed - no cell type predictions generated"
+         )
+
+     if tangram_mapping_score <= 0:
+         await ctx.warning(
+             f"Tangram mapping score is suspiciously low: {tangram_mapping_score}"
+         )
+
+     # ===== Copy results from adata_sp back to original adata =====
+     # Since adata_sp was created from adata.raw (different object), we need to
+     # transfer the Tangram results back to the original adata for downstream use.
+     if adata_sp is not adata:
+         # Copy cell type assignments
+         if output_key in adata_sp.obs:
+             adata.obs[output_key] = adata_sp.obs[output_key]
+
+         # Copy tangram_ct_pred from obsm
+         if "tangram_ct_pred" in adata_sp.obsm:
+             adata.obsm["tangram_ct_pred"] = adata_sp.obsm["tangram_ct_pred"]
+
+         # Copy tangram_gene_predictions if they exist
+         if "tangram_gene_predictions" in adata_sp.obsm:
+             adata.obsm["tangram_gene_predictions"] = adata_sp.obsm[
+                 "tangram_gene_predictions"
+             ]
+
+     return AnnotationMethodOutput(
+         cell_types=cell_types,
+         counts=counts,
+         confidence=confidence_scores,
+         mapping_score=tangram_mapping_score,
+     )
+
+
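The row normalization applied to tangram_ct_pred can be reproduced on a toy DataFrame; this sketch (made-up numbers) shows why argmax is unchanged while the values become valid probabilities:

```python
import numpy as np
import pandas as pd

# Toy stand-in for adata_sp.obsm["tangram_ct_pred"]: unnormalized densities
densities = pd.DataFrame(
    [[2.0, 6.0], [0.3, 0.1]],
    index=["spot1", "spot2"],
    columns=["B cell", "T cell"],
)

# Row-normalize so each spot's values sum to 1 (densities -> probabilities)
prob = densities.div(densities.sum(axis=1), axis=0)

assert np.allclose(prob.sum(axis=1), 1.0)
assert (prob.idxmax(axis=1) == densities.idxmax(axis=1)).all()  # argmax preserved
print(prob)
# spot1: B cell 0.25, T cell 0.75; spot2: B cell 0.75, T cell 0.25
```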
+ async def _annotate_with_scanvi(
+     adata,
+     params: AnnotationParameters,
+     ctx: "ToolContext",
+     output_key: str,
+     confidence_key: str,
+     reference_adata: Optional[Any] = None,
+ ) -> AnnotationMethodOutput:
+     """Annotate cell types using scANVI (semi-supervised variational inference).
+
+     scANVI (single-cell ANnotation using Variational Inference) is a deep learning
+     method for transferring cell type labels from reference to query data using
+     semi-supervised learning with variational autoencoders.
+
+     Official Implementation: scvi-tools (https://scvi-tools.org)
+     Reference: Xu et al. (2021) "Probabilistic harmonization and annotation of
+     single-cell transcriptomics data with deep generative models"
+
+     Method Overview:
+         1. Trains on reference data with known cell type labels
+         2. Learns shared latent representation between reference and query
+         3. Transfers labels to query data via probabilistic predictions
+         4. Supports batch correction and semi-supervised training
+
+     Requirements:
+         - reference_data_id: Must point to preprocessed single-cell reference data
+         - cell_type_key: Column in reference data containing cell type labels
+         - Both datasets must have a 'counts' layer (raw counts, not normalized)
+         - Sufficient gene overlap between reference and query data
+
+     Parameters (via AnnotationParameters):
+         Core Architecture:
+             - scanvi_n_latent (default: 10): Latent space dimensions
+             - scanvi_n_hidden (default: 128): Hidden layer units
+             - scanvi_n_layers (default: 1): Number of layers
+             - scanvi_dropout_rate (default: 0.1): Dropout for regularization
+
+         Training Strategy:
+             - scanvi_use_scvi_pretrain (default: True): Use SCVI pretraining
+             - scanvi_scvi_epochs (default: 200): SCVI pretraining epochs
+             - num_epochs (default: 100): SCANVI training epochs
+             - scanvi_query_epochs (default: 100): Query data training epochs
+
+         Advanced:
+             - scanvi_unlabeled_category (default: "Unknown"): Label for unlabeled cells
+             - scanvi_n_samples_per_label (default: 100): Samples per label
+             - batch_key: For batch correction (optional)
+
+     Official Recommendations (scvi-tools), for large integration tasks:
+         - scanvi_n_layers: 2
+         - scanvi_n_latent: 30
+         - scanvi_scvi_epochs: 300 (SCVI pretraining)
+         - num_epochs: 100 (SCANVI training)
+         - scanvi_query_epochs: 100
+         - Gene selection: 1000-10000 HVGs recommended
+
+     Empirical Adjustments (not official), for small datasets (<1000 genes or <1000 cells):
+         - scanvi_n_latent: 3-5 (may prevent NaN/gradient explosion)
+         - scanvi_dropout_rate: 0.2-0.3 (may improve regularization)
+         - scanvi_use_scvi_pretrain: False (may reduce complexity)
+         - num_epochs: 50 (may prevent overfitting)
+         - scanvi_query_epochs: 50
+
+     Common Issues:
+         - NaN errors during training: Try reducing n_latent or increasing dropout_rate
+         - Low confidence scores: Try increasing training epochs or check gene overlap
+         - Memory issues: Reduce batch size or use GPU
+
+     Returns:
+         AnnotationMethodOutput with:
+             - cell_types: List of predicted cell type categories
+             - counts: Dict mapping cell types to number of cells
+             - confidence: Dict mapping cell types to mean prediction probability
+             - mapping_score: None (not applicable to scANVI)
+
+     Example:
+         params = AnnotationParameters(
+             method="scanvi",
+             reference_data_id="reference_sc",
+             cell_type_key="cell_types",
+             scanvi_n_latent=5,  # For small dataset
+             scanvi_dropout_rate=0.2,  # Better regularization
+             scanvi_use_scvi_pretrain=False,  # Simpler training
+             num_epochs=50,  # Prevent overfitting
+         )
+     """
+
+     # Validate dependencies with comprehensive error reporting
+     scvi = validate_scvi_tools(ctx, components=["SCANVI"])
+
+     # Check if reference data is provided
+     if reference_adata is None:
+         raise ParameterError("scANVI requires reference_data_id parameter.")
+
+     # Use reference single-cell data (passed from main function via ctx.get_adata())
+     adata_ref_original = reference_adata
+
+     # Handle duplicate gene names
+     await ensure_unique_var_names_async(adata_ref_original, ctx, "reference data")
+     await ensure_unique_var_names_async(adata, ctx, "query data")
+
+     # Gene alignment
+     common_genes = find_common_genes(adata_ref_original.var_names, adata.var_names)
+
+     if len(common_genes) < min(100, adata_ref_original.n_vars * 0.5):
+         raise DataError(
+             f"Insufficient gene overlap: Only {len(common_genes)} common genes found. "
+             f"Reference has {adata_ref_original.n_vars}, query has {adata.n_vars} genes."
+         )
+
+     # COW FIX: Operate on temporary copies for gene subsetting.
+     # This prevents loss of HVG information in the original adata.
+     if len(common_genes) < adata_ref_original.n_vars:
+         await ctx.warning(
+             f"Subsetting to {len(common_genes)} common genes for scANVI training "
+             f"(reference: {adata_ref_original.n_vars}, query: {adata.n_vars})"
+         )
+         # Create subsets for scANVI (not modifying originals)
+         adata_ref = adata_ref_original[:, common_genes].copy()
+         adata_subset = adata[:, common_genes].copy()
+     else:
+         # No subsetting needed
+         adata_ref = adata_ref_original.copy()
+         adata_subset = adata.copy()
+
+     # Data validation
+     if "log1p" not in adata_ref.uns:
+         await ctx.warning("Reference data may not be log-normalized")
+     if "highly_variable" not in adata_ref.var:
+         await ctx.warning("No highly variable genes detected in reference")
+
+     # Get parameters
+     cell_type_key = getattr(params, "cell_type_key", "cell_type")
+     batch_key = getattr(params, "batch_key", None)
+
+     # Optional SCVI pretraining
+     if params.scanvi_use_scvi_pretrain:
+         # Setup for SCVI with labels (required for SCANVI conversion);
+         # first ensure the reference has the cell type labels
+         validate_obs_column(
+             adata_ref, cell_type_key, "Cell type column (reference data)"
+         )
+
+         # SCVI needs to know about labels for later SCANVI conversion
+         scvi.model.SCVI.setup_anndata(
+             adata_ref,
+             labels_key=cell_type_key,  # Important: include labels_key
+             batch_key=batch_key,
+             layer=params.layer,
+         )
+
+         # Train SCVI
+         scvi_model = scvi.model.SCVI(
+             adata_ref,
+             n_latent=params.scanvi_n_latent,
+             n_hidden=params.scanvi_n_hidden,
+             n_layers=params.scanvi_n_layers,
+             dropout_rate=params.scanvi_dropout_rate,
+         )
+
+         scvi_model.train(
+             max_epochs=params.scanvi_scvi_epochs,
+             early_stopping=True,
+             check_val_every_n_epoch=params.scanvi_check_val_every_n_epoch,
+         )
+
+         # Convert to SCANVI (no need for setup_anndata, it uses SCVI's setup)
+         model = scvi.model.SCANVI.from_scvi_model(
+             scvi_model, params.scanvi_unlabeled_category
+         )
+
+         # Train SCANVI (fewer epochs needed after pretraining).
+         # Uses configurable epochs (default: 20, official recommendation after pretraining).
+         model.train(
+             max_epochs=params.scanvi_scanvi_epochs,
+             n_samples_per_label=params.scanvi_n_samples_per_label,
+             early_stopping=True,
+         )
+
+     else:
+         # Direct SCANVI training.
+         # Ensure counts layer exists (create from adata.raw if needed)
+         ensure_counts_layer(
+             adata_ref,
+             error_message="scANVI requires raw counts in layers['counts'].",
+         )
+
+         # Setup AnnData for scANVI
+         scvi.model.SCANVI.setup_anndata(
+             adata_ref,
+             labels_key=cell_type_key,
+             unlabeled_category=params.scanvi_unlabeled_category,
+             batch_key=batch_key,
+             layer="counts",
+         )
+
+         # Create scANVI model
+         model = scvi.model.SCANVI(
+             adata_ref,
+             n_hidden=params.scanvi_n_hidden,
+             n_latent=params.scanvi_n_latent,
+             n_layers=params.scanvi_n_layers,
+             dropout_rate=params.scanvi_dropout_rate,
+         )
+
+         model.train(
+             max_epochs=params.num_epochs,
+             n_samples_per_label=params.scanvi_n_samples_per_label,
+             early_stopping=True,
+             check_val_every_n_epoch=params.scanvi_check_val_every_n_epoch,
+         )
+
+     # Query data preparation
+     adata_subset.obs[cell_type_key] = params.scanvi_unlabeled_category
+
+     # Setup query data (batch handling)
+     if batch_key and batch_key not in adata_subset.obs:
+         adata_subset.obs[batch_key] = "query_batch"
+
+     # Ensure counts layer exists for query data (create from adata.raw if needed)
+     ensure_counts_layer(
+         adata_subset,
+         error_message="scANVI requires raw counts in layers['counts'].",
+     )
+
+     scvi.model.SCANVI.setup_anndata(
+         adata_subset,
+         labels_key=cell_type_key,
+         unlabeled_category=params.scanvi_unlabeled_category,
+         batch_key=batch_key,
+         layer="counts",
+     )
+
+     # Transfer model to spatial data with proper parameters
+     spatial_model = scvi.model.SCANVI.load_query_data(adata_subset, model)
+
+     # ===== Improved query training =====
+     spatial_model.train(
+         max_epochs=params.scanvi_query_epochs,  # Default: 100 (was 50)
+         early_stopping=True,
+         plan_kwargs=dict(weight_decay=0.0),  # Critical: preserve reference latent space
+         check_val_every_n_epoch=params.scanvi_check_val_every_n_epoch,
+     )
+
+     # COW FIX: Get predictions from adata_subset, then add to original adata
+     predictions = spatial_model.predict()
+     adata_subset.obs[cell_type_key] = predictions
+     ensure_categorical(adata_subset, cell_type_key)
+
+     # Extract results from adata_subset
+     cell_types = list(adata_subset.obs[cell_type_key].cat.categories)
+     counts = adata_subset.obs[cell_type_key].value_counts().to_dict()
+
+     # Get prediction probabilities as confidence scores
+     try:
+         probs = spatial_model.predict(soft=True)
+         confidence_scores = {}
+         for i, cell_type in enumerate(cell_types):
+             cells_of_type = adata_subset.obs[cell_type_key] == cell_type
+             if np.sum(cells_of_type) > 0 and isinstance(probs, pd.DataFrame):
+                 if cell_type in probs.columns:
+                     mean_prob = probs.loc[cells_of_type, cell_type].mean()
+                     confidence_scores[cell_type] = round(float(mean_prob), 2)
+                 # else: No probability column for this cell type - skip confidence
+             elif (
+                 np.sum(cells_of_type) > 0
+                 and hasattr(probs, "shape")
+                 and probs.shape[1] > i
+             ):
+                 mean_prob = probs[cells_of_type, i].mean()
+                 confidence_scores[cell_type] = round(float(mean_prob), 2)
+             # else: No cells of this type or no probability data - skip confidence
+     except Exception as e:
+         await ctx.warning(f"Could not get confidence scores: {e}")
+         # Could not extract probabilities - return an empty confidence dict,
+         # which clearly indicates no confidence data is available
+         confidence_scores = {}
+
+     # COW FIX: Add prediction results to original adata.obs using output_key
+     adata.obs[output_key] = adata_subset.obs[cell_type_key].values
+     ensure_categorical(adata, output_key)
+
+     # Store confidence if available
+     if confidence_scores:
+         confidence_array = [
+             confidence_scores.get(ct, 0.0) for ct in adata.obs[output_key]
+         ]
+         adata.obs[confidence_key] = confidence_array
+
+     return AnnotationMethodOutput(
+         cell_types=cell_types,
+         counts=counts,
+         confidence=confidence_scores,
+     )
+
+
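The per-type confidence loop above could equally be written as a dict comprehension over the soft-prediction DataFrame; a sketch on fabricated data (assuming, as the code does, that predict(soft=True) returns a DataFrame with one column per label):

```python
import pandas as pd

# Fabricated stand-in for spatial_model.predict(soft=True)
probs = pd.DataFrame(
    {"B cell": [0.9, 0.2, 0.8], "T cell": [0.1, 0.8, 0.2]},
    index=["cell1", "cell2", "cell3"],
)
predicted = probs.idxmax(axis=1)  # hard labels derived from the soft matrix

# Mean probability of the winning label, per cell type
confidence_scores = {
    ct: round(float(probs.loc[predicted == ct, ct].mean()), 2)
    for ct in predicted.unique()
}
print(confidence_scores)  # {'B cell': 0.85, 'T cell': 0.8}
```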
+ async def _annotate_with_mllmcelltype(
+     adata,
+     params: AnnotationParameters,
+     ctx: "ToolContext",
+     output_key: str,
+     confidence_key: str,
+ ) -> AnnotationMethodOutput:
+     """Annotate cell types using mLLMCellType (LLM-based) method.
+
+     Supports both single-model and multi-model consensus annotation.
+
+     Single Model Mode (default):
+         - Uses one LLM for annotation
+         - Fast and cost-effective
+         - Providers: openai, anthropic, gemini, deepseek, qwen, zhipu, stepfun, minimax, grok, openrouter
+         - Default models: openai="gpt-5", anthropic="claude-sonnet-4-20250514", gemini="gemini-2.5-pro-preview-03-25"
+         - Latest recommended: "gpt-5", "claude-sonnet-4-5-20250929", "claude-opus-4-1-20250805", "gemini-2.5-pro"
+
+     Multi-Model Consensus Mode (set mllm_use_consensus=True):
+         - Uses multiple LLMs for collaborative annotation
+         - Higher accuracy through consensus
+         - Provides uncertainty metrics (consensus proportion, entropy)
+         - Structured deliberation for controversial clusters
+
+     Parameters (via AnnotationParameters):
+         - cluster_label: Required. Cluster column in adata.obs
+         - mllm_species: "human" or "mouse"
+         - mllm_tissue: Tissue context (optional but recommended)
+         - mllm_provider: LLM provider (single model mode)
+         - mllm_model: Model name (None = use default for provider)
+         - mllm_use_consensus: Enable multi-model consensus
+         - mllm_models: List of models for consensus (e.g., ["gpt-5", "claude-sonnet-4-5-20250929"])
+         - mllm_additional_context: Additional context for better annotation
+         - mllm_base_urls: Custom API endpoints (useful for proxies)
+     """
+
+     # Validate dependencies with comprehensive error reporting
+     require("mllmcelltype", ctx, feature="mLLMCellType annotation")
+     import mllmcelltype
+
+     # Validate clustering has been performed;
+     # cluster_label is required for mLLMCellType (no default value)
+     if not params.cluster_label:
+         available_cols = list(adata.obs.columns)
+         categorical_cols = [
+             col
+             for col in available_cols
+             if adata.obs[col].dtype.name in ["object", "category"]
+         ]
+
+         raise ParameterError(
+             f"cluster_label parameter is required for mLLMCellType method.\n\n"
+             f"Available categorical columns (likely clusters):\n  {', '.join(categorical_cols[:15])}\n"
+             f"{f'  ... and {len(categorical_cols) - 15} more' if len(categorical_cols) > 15 else ''}\n\n"
+             f"Common cluster column names: leiden, louvain, seurat_clusters, phenograph\n\n"
+             f"Example: params = {{'cluster_label': 'leiden', ...}}"
+         )
+
+     cluster_key = params.cluster_label
+     validate_obs_column(adata, cluster_key, "Cluster")
+
+     # Find differentially expressed genes for each cluster
+     sc.tl.rank_genes_groups(adata, cluster_key, method="wilcoxon")
+
+     # Extract top marker genes for each cluster
+     marker_genes_dict = {}
+     n_genes = params.mllm_n_marker_genes
+
+     for cluster in adata.obs[cluster_key].unique():
+         # Get top genes for this cluster
+         gene_names = adata.uns["rank_genes_groups"]["names"][str(cluster)][:n_genes]
+         marker_genes_dict[f"Cluster_{cluster}"] = list(gene_names)
+
+     # Prepare parameters for mllmcelltype
+     species = params.mllm_species
+     tissue = params.mllm_tissue
+     additional_context = params.mllm_additional_context
+     use_cache = params.mllm_use_cache
+     base_urls = params.mllm_base_urls
+     verbose = params.mllm_verbose
+     force_rerun = params.mllm_force_rerun
+     clusters_to_analyze = params.mllm_clusters_to_analyze
+
+     # Check if using multi-model consensus or single model
+     use_consensus = params.mllm_use_consensus
+
+     try:
+         if use_consensus:
+             # Use interactive_consensus_annotation with multiple models
+             models = params.mllm_models
+             if not models:
+                 raise ParameterError(
+                     "mllm_models parameter is required when mllm_use_consensus=True. "
+                     "Provide a list of model names, e.g., ['gpt-5', 'claude-sonnet-4-5-20250929', 'gemini-2.5-pro']"
+                 )
+
+             api_keys = params.mllm_api_keys
+             consensus_threshold = params.mllm_consensus_threshold
+             entropy_threshold = params.mllm_entropy_threshold
+             max_discussion_rounds = params.mllm_max_discussion_rounds
+             consensus_model = params.mllm_consensus_model
+
+             # Call interactive_consensus_annotation
+             consensus_results = mllmcelltype.interactive_consensus_annotation(
+                 marker_genes=marker_genes_dict,
+                 species=species,
+                 models=models,
+                 api_keys=api_keys,
+                 tissue=tissue,
+                 additional_context=additional_context,
+                 consensus_threshold=consensus_threshold,
+                 entropy_threshold=entropy_threshold,
+                 max_discussion_rounds=max_discussion_rounds,
+                 use_cache=use_cache,
+                 verbose=verbose,
+                 consensus_model=consensus_model,
+                 base_urls=base_urls,
+                 clusters_to_analyze=clusters_to_analyze,
+                 force_rerun=force_rerun,
+             )
+
+             # Extract consensus annotations
+             annotations = consensus_results.get("consensus", {})
+
+         else:
+             # Use single model annotation
+             provider = params.mllm_provider
+             model = params.mllm_model
+             api_key = params.mllm_api_key
+
+             # Call annotate_clusters (single model)
+             annotations = mllmcelltype.annotate_clusters(
+                 marker_genes=marker_genes_dict,
+                 species=species,
+                 provider=provider,
+                 model=model,
+                 api_key=api_key,
+                 tissue=tissue,
+                 additional_context=additional_context,
+                 use_cache=use_cache,
+                 base_urls=base_urls,
+             )
+     except Exception as e:
+         raise ProcessingError(f"mLLMCellType annotation failed: {e}") from e
+
+     # Map cluster annotations back to cells
+     cluster_to_celltype = {}
+     for cluster_name, cell_type in annotations.items():
+         # Extract cluster number from "Cluster_X" format
+         cluster_id = cluster_name.replace("Cluster_", "")
+         cluster_to_celltype[cluster_id] = cell_type
+
+     # Apply cell type annotations to cells (key provided by caller)
+     adata.obs[output_key] = adata.obs[cluster_key].astype(str).map(cluster_to_celltype)
+
+     # Handle any unmapped clusters
+     unmapped = adata.obs[output_key].isna()
+     if unmapped.any():
+         await ctx.warning(f"Found {unmapped.sum()} cells in unmapped clusters")
+         adata.obs.loc[unmapped, output_key] = "Unknown"
+
+     ensure_categorical(adata, output_key)
+
+     # Get cell types and counts
+     cell_types = list(adata.obs[output_key].unique())
+     counts = adata.obs[output_key].value_counts().to_dict()
+
+     # LLM-based annotations don't provide numeric confidence scores;
+     # we intentionally leave this empty rather than assigning misleading values.
+     return AnnotationMethodOutput(
+         cell_types=cell_types,
+         counts=counts,
+         confidence={},
+     )
+
+
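Putting the docstring's parameters together, a hypothetical consensus-mode invocation might look like this; the field values are illustrative, and only fields documented in the docstring above are used:

```python
# Hypothetical parameter set for multi-model consensus annotation.
from chatspatial.models.data import AnnotationParameters

params = AnnotationParameters(
    method="mllmcelltype",
    cluster_label="leiden",    # required: cluster column in adata.obs
    mllm_species="human",
    mllm_tissue="lymph node",  # optional but recommended context
    mllm_use_consensus=True,
    mllm_models=["gpt-5", "claude-sonnet-4-5-20250929", "gemini-2.5-pro"],
)
```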
+ async def _annotate_with_cellassign(
+     adata,
+     params: AnnotationParameters,
+     ctx: "ToolContext",
+     output_key: str,
+     confidence_key: str,
+ ) -> AnnotationMethodOutput:
+     """Annotate cell types using CellAssign method"""
+
+     # Validate dependencies with comprehensive error reporting
+     validate_scvi_tools(ctx, components=["CellAssign"])
+     from scvi.external import CellAssign
+
+     # Check if marker genes are provided
+     if params.marker_genes is None:
+         raise ParameterError(
+             "CellAssign requires marker genes to be provided. "
+             "Please specify marker_genes parameter with a dictionary of cell types and their marker genes."
+         )
+
+     marker_genes = params.marker_genes
+
+     # CRITICAL FIX: Use adata.raw for marker gene validation if available.
+     # Preprocessing filters genes to HVGs, but marker genes may not be in HVGs;
+     # adata.raw contains all original genes and should be checked first.
+     if adata.raw is not None:
+         all_genes = set(adata.raw.var_names)
+         gene_source = "adata.raw"
+     else:
+         all_genes = set(adata.var_names)
+         gene_source = "adata.var_names"
+         await ctx.warning(
+             f"Using filtered gene set for marker gene validation "
+             f"({len(all_genes)} genes). Some marker genes may be missing. "
+             f"Consider using unpreprocessed data for CellAssign."
+         )
+
+     # Validate marker genes exist in dataset
+     valid_marker_genes = {}
+     total_markers = sum(len(g) for g in marker_genes.values())
+     markers_found = 0
+     markers_missing = 0
+
+     for cell_type, genes in marker_genes.items():
+         existing_genes = [gene for gene in genes if gene in all_genes]
+         missing_genes = [gene for gene in genes if gene not in all_genes]
+
+         if existing_genes:
+             valid_marker_genes[cell_type] = existing_genes
+             markers_found += len(existing_genes)
+             if missing_genes and len(missing_genes) > len(existing_genes):
+                 await ctx.warning(
+                     f"Missing most markers for {cell_type}: {len(missing_genes)}/{len(genes)}"
+                 )
+         else:
+             markers_missing += len(genes)
+             await ctx.warning(
+                 f"No marker genes found for {cell_type} - all {len(genes)} markers missing!"
+             )
+
+     if not valid_marker_genes:
+         raise DataError(
+             f"No valid marker genes found for any cell type. "
+             f"Checked {total_markers} markers against {len(all_genes)} genes in {gene_source}. "
+             f"If data was preprocessed, marker genes may have been filtered out. "
+             f"Consider using unpreprocessed data or ensure marker genes are highly variable."
+         )
+     valid_cell_types = list(valid_marker_genes)
+
+     # Create marker gene matrix as DataFrame (required by CellAssign API)
+     all_marker_genes = []
+     for genes in valid_marker_genes.values():
+         all_marker_genes.extend(genes)
+     available_marker_genes = list(set(all_marker_genes))  # Remove duplicates
+
+     # Note: available_marker_genes cannot be empty here because valid_marker_genes
+     # was already validated above to have at least one cell type with genes.
+
+     # Create DataFrame with genes as index, cell types as columns
+     marker_gene_matrix = pd.DataFrame(
+         np.zeros((len(available_marker_genes), len(valid_cell_types))),
+         index=available_marker_genes,
+         columns=valid_cell_types,
+     )
+
+     # Fill marker matrix
+     for cell_type in valid_cell_types:
+         for gene in valid_marker_genes[cell_type]:
+             if gene in available_marker_genes:
+                 marker_gene_matrix.loc[gene, cell_type] = 1
+
+     # Compute size factors BEFORE subsetting (official CellAssign requirement)
+     if "size_factors" not in adata.obs:
+         # Calculate size factors from FULL dataset
+         if hasattr(adata.X, "sum"):
+             size_factors = adata.X.sum(axis=1)
+             if hasattr(size_factors, "A1"):  # sparse matrix
+                 size_factors = size_factors.A1
+         else:
+             size_factors = np.sum(adata.X, axis=1)
+
+         # Normalize and ensure positive
+         size_factors = np.maximum(size_factors, 1e-6)
+         mean_sf = np.mean(size_factors)
+         size_factors_normalized = size_factors / mean_sf
+
+         adata.obs["size_factors"] = pd.Series(
+             size_factors_normalized, index=adata.obs.index
+         )
+
+     # Subset data to marker genes (size factors already computed).
+     # Use adata.raw if available (contains all genes including markers)
+     if adata.raw is not None:
+         import anndata as ad_module
+
+         adata_subset = ad_module.AnnData(
+             X=adata.raw[:, available_marker_genes].X,
+             obs=adata.obs.copy(),
+             var=adata.raw.var.loc[available_marker_genes].copy(),
+         )
+     else:
+         adata_subset = adata[:, available_marker_genes].copy()
+
+     # Check for invalid values in the data
+     X_array = to_dense(adata_subset.X)
+
+     # Replace any NaN or Inf values with zeros
+     if np.any(np.isnan(X_array)) or np.any(np.isinf(X_array)):
+         await ctx.warning("Found NaN or Inf values in data, replacing with zeros")
+         X_array = np.nan_to_num(X_array, nan=0.0, posinf=0.0, neginf=0.0)
+         adata_subset.X = X_array
+
+     # Additional data cleaning for CellAssign compatibility:
+     # check for genes with zero variance (which cause numerical issues in CellAssign)
+     gene_vars = np.var(X_array, axis=0)
+     zero_var_genes = gene_vars == 0
+     if np.any(zero_var_genes):
+         zero_var_gene_names = adata_subset.var_names[zero_var_genes].tolist()
+         await ctx.warning(
+             f"Found {np.sum(zero_var_genes)} genes with zero variance "
+             f"(e.g., {', '.join(zero_var_gene_names[:5])}). "
+             f"CellAssign may have numerical issues with these genes."
+         )
+         # Don't raise error, just warn - CellAssign might handle it
+
+     # Ensure data is non-negative (CellAssign expects count-like data)
+     if np.any(X_array < 0):
+         await ctx.warning("Found negative values in data, clipping to zero")
+         X_array = np.maximum(X_array, 0)
+         adata_subset.X = X_array
+
+     # Verify size factors were transferred to subset
+     if "size_factors" not in adata_subset.obs:
+         raise ProcessingError(
+             "Size factors not found in adata.obs. This should not happen - "
+             "they should have been computed before subsetting. Please report this bug."
+         )
+
+     # Setup CellAssign on subset data only
+     CellAssign.setup_anndata(adata_subset, size_factor_key="size_factors")
+
+     # Train CellAssign model
+     model = CellAssign(adata_subset, marker_gene_matrix)
+
+     model.train(
+         max_epochs=params.cellassign_max_iter, lr=params.cellassign_learning_rate
+     )
+
+     # Get predictions
+     predictions = model.predict()
+
+     # Handle different prediction formats (key provided by caller)
+     if isinstance(predictions, pd.DataFrame):
+         # CellAssign returns DataFrame with probabilities
+         predicted_indices = predictions.values.argmax(axis=1)
+         adata.obs[output_key] = [valid_cell_types[i] for i in predicted_indices]
+
+         # Get confidence scores from probabilities DataFrame
+         confidence_scores = {}
+         for i, cell_type in enumerate(valid_cell_types):
+             cells_of_type = adata.obs[output_key] == cell_type
+             if np.sum(cells_of_type) > 0:
+                 # Use iloc with positional indices derived from the boolean mask
+                 cell_indices = np.where(cells_of_type)[0]
+                 mean_prob = predictions.iloc[cell_indices, i].mean()
+                 confidence_scores[cell_type] = round(float(mean_prob), 2)
+             # else: No cells of this type - skip confidence
+     else:
+         # Other models return indices directly
+         adata.obs[output_key] = [valid_cell_types[i] for i in predictions]
+         # CellAssign returned indices, not probabilities - no confidence available
+         confidence_scores = {}  # Empty dict indicates no confidence data
+
+     ensure_categorical(adata, output_key)
+
+     # Store confidence if available
+     if confidence_scores:
+         confidence_array = [
+             confidence_scores.get(ct, 0.0) for ct in adata.obs[output_key]
+         ]
+         adata.obs[confidence_key] = confidence_array
+
+     # Get cell types and counts
+     counts = adata.obs[output_key].value_counts().to_dict()
+
+     return AnnotationMethodOutput(
+         cell_types=valid_cell_types,
+         counts=counts,
+         confidence=confidence_scores,
+     )
+
+
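The binary marker matrix that CellAssign consumes (genes as rows, cell types as columns) can be built the same way on a toy marker dictionary:

```python
import numpy as np
import pandas as pd

marker_genes = {
    "B cell": ["CD79A", "MS4A1"],
    "T cell": ["CD3D", "CD3E", "CD2"],
}

# Deduplicated union of all marker genes, as in the function above
genes = sorted({g for gs in marker_genes.values() for g in gs})
matrix = pd.DataFrame(
    np.zeros((len(genes), len(marker_genes))),
    index=genes,
    columns=list(marker_genes),
)
for cell_type, gs in marker_genes.items():
    matrix.loc[gs, cell_type] = 1  # 1 marks "gene is a marker for this type"

print(matrix)
#        B cell  T cell
# CD2       0.0     1.0
# CD3D      0.0     1.0
# CD3E      0.0     1.0
# CD79A     1.0     0.0
# MS4A1     1.0     0.0
```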
1311
+ async def annotate_cell_types(
1312
+ data_id: str,
1313
+ ctx: ToolContext,
1314
+ params: AnnotationParameters, # No default - must be provided by caller (LLM)
1315
+ ) -> AnnotationResult:
1316
+ """Annotate cell types in spatial transcriptomics data
1317
+
1318
+ Args:
1319
+ data_id: Dataset ID
1320
+ ctx: Tool context for data access and logging
1321
+ params: Annotation parameters
1322
+
1323
+ Returns:
1324
+ Annotation result
1325
+ """
1326
+    # Retrieve the AnnData object via ToolContext
+    adata = await ctx.get_adata(data_id)
+
+    # Validate the requested method before doing any work
+    if params.method not in SUPPORTED_METHODS:
+        raise ParameterError(
+            f"Unsupported method: {params.method}. Supported: {sorted(SUPPORTED_METHODS)}"
+        )
+
+    # Load reference data for methods that require it
+    reference_adata = None
+    if params.method in ["tangram", "scanvi", "singler"] and params.reference_data_id:
+        reference_adata = await ctx.get_adata(params.reference_data_id)
+
+    # Generate output keys in ONE place (single-point control)
+    output_key = f"cell_type_{params.method}"
+    confidence_key = f"confidence_{params.method}"
+
+    # Route to the appropriate annotation method
+    try:
+        if params.method == "tangram":
+            result = await _annotate_with_tangram(
+                adata, params, ctx, output_key, confidence_key, reference_adata
+            )
+        elif params.method == "scanvi":
+            result = await _annotate_with_scanvi(
+                adata, params, ctx, output_key, confidence_key, reference_adata
+            )
+        elif params.method == "cellassign":
+            result = await _annotate_with_cellassign(
+                adata, params, ctx, output_key, confidence_key
+            )
+        elif params.method == "mllmcelltype":
+            result = await _annotate_with_mllmcelltype(
+                adata, params, ctx, output_key, confidence_key
+            )
+        elif params.method == "singler":
+            result = await _annotate_with_singler(
+                adata, params, ctx, output_key, confidence_key, reference_adata
+            )
+        else:  # sctype
+            result = await _annotate_with_sctype(
+                adata, params, ctx, output_key, confidence_key
+            )
+
+    except Exception as e:
+        raise ProcessingError(f"Annotation failed: {e}") from e
+
+    # Extract values from unified result type
+    cell_types = result.cell_types
+    counts = result.counts
+    confidence_scores = result.confidence
+    tangram_mapping_score = result.mapping_score
+
+    # Determine if confidence_key should be reported (only if we have confidence data)
+    confidence_key_for_result = confidence_key if confidence_scores else None
+
+    # Store scientific metadata for reproducibility
+    from ..utils.adata_utils import store_analysis_metadata
+
+    # Extract results keys
+    results_keys_dict = {"obs": [output_key], "obsm": [], "uns": []}
+    if confidence_key_for_result:
+        results_keys_dict["obs"].append(confidence_key)
+
+    # Add method-specific result keys
+    if params.method == "tangram":
+        results_keys_dict["obsm"].extend(
+            ["tangram_ct_pred", "tangram_gene_predictions"]
+        )
+
+    # Prepare parameters dict (only scientifically important ones)
+    parameters_dict = {}
+    if params.method == "tangram":
+        parameters_dict = {
+            "device": params.tangram_device,
+            "n_epochs": params.num_epochs,  # Fixed: use num_epochs instead of tangram_num_epochs
+            "learning_rate": params.tangram_learning_rate,
+        }
+    elif params.method == "scanvi":
+        parameters_dict = {
+            "n_latent": params.scanvi_n_latent,
+            "n_hidden": params.scanvi_n_hidden,
+            "dropout_rate": params.scanvi_dropout_rate,
+            "use_scvi_pretrain": params.scanvi_use_scvi_pretrain,
+        }
+    elif params.method == "mllmcelltype":
+        parameters_dict = {
+            "n_marker_genes": params.mllm_n_marker_genes,
+            "species": params.mllm_species,
+            "provider": params.mllm_provider,
+            "model": params.mllm_model,
+            "use_consensus": params.mllm_use_consensus,
+        }
+    elif params.method == "sctype":
+        parameters_dict = {
+            "tissue": params.sctype_tissue,
+            "scaled": params.sctype_scaled,
+        }
+    elif params.method == "singler":
+        parameters_dict = {
+            "fine_tune": params.singler_fine_tune,
+        }
+
+    # Prepare statistics dict
+    statistics_dict = {"n_cell_types": len(cell_types)}
+    if tangram_mapping_score is not None:
+        statistics_dict["mapping_score"] = tangram_mapping_score
+
+    # Prepare reference info if applicable
+    reference_info_dict = None
+    if params.method in ["tangram", "scanvi", "singler"] and params.reference_data_id:
+        reference_info_dict = {"reference_data_id": params.reference_data_id}
+
+    # Store metadata
+    store_analysis_metadata(
+        adata,
+        analysis_name=f"annotation_{params.method}",
+        method=params.method,
+        parameters=parameters_dict,
+        results_keys=results_keys_dict,
+        statistics=statistics_dict,
+        reference_info=reference_info_dict,
+    )
+
+    # Return result
+    return AnnotationResult(
+        data_id=data_id,
+        method=params.method,
+        output_key=output_key,
+        confidence_key=confidence_key_for_result,
+        cell_types=cell_types,
+        counts=counts,
+        confidence_scores=confidence_scores,
+        tangram_mapping_score=tangram_mapping_score,
+    )
+
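# Illustrative sketch (not from the package): the key naming convention used
# above, plus a hypothetical call. The data_id and parameter values are
# assumptions for demonstration; see AnnotationParameters for the real schema.
method = "sctype"
output_key = f"cell_type_{method}"        # adata.obs column holding the labels
confidence_key = f"confidence_{method}"   # reported only when confidence exists
# params = AnnotationParameters(method="sctype", sctype_tissue="Brain")
# result = await annotate_cell_types("dataset_1", ctx, params)
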
+
+# ============================================================================
+# SC-TYPE IMPLEMENTATION
+# ============================================================================
+
+# Caches for sc-type results: in-memory dict plus JSON files on disk (no pickle)
+_SCTYPE_CACHE: dict[str, Any] = {}
+_SCTYPE_CACHE_DIR = Path.home() / ".chatspatial" / "sctype_cache"
+
+# R code constants for sc-type (extracted for clarity)
+_R_INSTALL_PACKAGES = """
+required_packages <- c("dplyr", "openxlsx", "HGNChelper")
+for (pkg in required_packages) {
+    if (!require(pkg, character.only = TRUE, quietly = TRUE)) {
+        install.packages(pkg, repos = "https://cran.r-project.org/", quiet = TRUE)
+        if (!require(pkg, character.only = TRUE, quietly = TRUE)) {
+            stop(paste("Failed to install R package:", pkg))
+        }
+    }
+}
+"""
+
+_R_LOAD_SCTYPE = """
+source("https://raw.githubusercontent.com/IanevskiAleksandr/sc-type/master/R/gene_sets_prepare.R")
+source("https://raw.githubusercontent.com/IanevskiAleksandr/sc-type/master/R/sctype_score_.R")
+"""
+
+_R_SCTYPE_SCORING = """
+# Set row/column names and convert to dense
+rownames(scdata) <- gene_names
+colnames(scdata) <- cell_names
+if (inherits(scdata, 'sparseMatrix')) scdata <- as.matrix(scdata)
+
+# Extract gene sets
+gs_positive <- gs_list$gs_positive
+gs_negative <- gs_list$gs_negative
+
+if (length(gs_positive) == 0) stop("No valid positive gene sets found")
+
+# Filter gene sets to genes present in data
+available_genes <- rownames(scdata)
+filtered_gs_positive <- list()
+filtered_gs_negative <- list()
+
+for (celltype in names(gs_positive)) {
+    pos_genes <- gs_positive[[celltype]]
+    neg_genes <- if (celltype %in% names(gs_negative)) gs_negative[[celltype]] else c()
+    pos_overlap <- intersect(toupper(pos_genes), toupper(available_genes))
+    if (length(pos_overlap) > 0) {
+        filtered_gs_positive[[celltype]] <- pos_overlap
+        filtered_gs_negative[[celltype]] <- intersect(toupper(neg_genes), toupper(available_genes))
+    }
+}
+
+if (length(filtered_gs_positive) == 0) {
+    stop("No valid cell type gene sets found after filtering.")
+}
+
+# Run sc-type scoring
+es_max <- sctype_score(
+    scRNAseqData = as.matrix(scdata),
+    scaled = TRUE,
+    gs = filtered_gs_positive,
+    gs2 = filtered_gs_negative
+)
+
+if (is.null(es_max) || nrow(es_max) == 0) {
+    stop("SC-Type scoring failed to produce results.")
+}
+"""
+
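# Illustrative sketch (not from the package): a Python paraphrase of the R
# filtering step above - keep only cell types whose positive markers overlap
# the genes present in the data, matching case-insensitively. Toy names below.
available = {"CD3D", "CD19", "GFAP"}
gs_positive = {"T cell": ["Cd3d", "CD2"], "Astro": ["AQP4"]}
filtered = {
    ct: [g.upper() for g in genes if g.upper() in available]
    for ct, genes in gs_positive.items()
}
filtered = {ct: genes for ct, genes in filtered.items() if genes}
# -> {'T cell': ['CD3D']}; 'Astro' is dropped because no marker is present
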
+# Valid tissue types from sc-type database
+SCTYPE_VALID_TISSUES = {
+    "Adrenal",
+    "Brain",
+    "Eye",
+    "Heart",
+    "Hippocampus",
+    "Immune system",
+    "Intestine",
+    "Kidney",
+    "Liver",
+    "Lung",
+    "Muscle",
+    "Pancreas",
+    "Placenta",
+    "Spleen",
+    "Stomach",
+    "Thymus",
+}
+
+
+def _get_sctype_cache_key(adata, params: AnnotationParameters) -> str:
+    """Generate cache key for sc-type results"""
+    # Create a hash based on data and parameters
+    data_hash = hashlib.md5()
+
+    # Hash expression data (sample first 1000 cells and 500 genes for efficiency)
+    sample_slice = adata.X[: min(1000, adata.n_obs), : min(500, adata.n_vars)]
+    sample_data = to_dense(sample_slice)
+    data_hash.update(sample_data.tobytes())
+
+    # Hash relevant parameters
+    params_dict = {
+        "tissue": params.sctype_tissue,
+        "db": params.sctype_db_,
+        "scaled": params.sctype_scaled,
+        "custom_markers": params.sctype_custom_markers,
+    }
+    data_hash.update(str(params_dict).encode())
+
+    return data_hash.hexdigest()
+
+
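# Illustrative sketch (not from the package): the cache key is an MD5 digest
# over a sample of the expression matrix plus the sc-type parameters, so the
# same data and settings reuse cached results. Toy values shown below.
import hashlib
import numpy as np

h = hashlib.md5()
h.update(np.zeros((2, 3), dtype=np.float32).tobytes())   # stands in for the sampled X
h.update(str({"tissue": "Brain", "scaled": True}).encode())
cache_key = h.hexdigest()  # deterministic hex string
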
+def _load_sctype_functions(ctx: "ToolContext") -> None:
+    """Load sc-type R functions and auto-install R packages if needed."""
+    robjects, _, _, _, _, default_converter, openrlib, _ = validate_r_environment(ctx)
+    from rpy2.robjects import conversion
+
+    with openrlib.rlock:
+        with conversion.localconverter(default_converter):
+            robjects.r(_R_INSTALL_PACKAGES)
+            robjects.r(_R_LOAD_SCTYPE)
+
+
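# Illustrative sketch (not from the package): the rpy2 pattern used throughout
# this module - take rpy2's global lock, activate an explicit converter, then
# evaluate R code. Requires rpy2 and a working R installation.
import rpy2.robjects as robjects
from rpy2.rinterface_lib import openrlib
from rpy2.robjects import conversion, default_converter

with openrlib.rlock:
    with conversion.localconverter(default_converter):
        result = robjects.r("1 + 1")  # an R numeric vector of length 1
        print(result[0])              # 2.0
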
+def _prepare_sctype_genesets(params: AnnotationParameters, ctx: "ToolContext"):
+    """Prepare gene sets for sc-type."""
+    if params.sctype_custom_markers:
+        return _convert_custom_markers_to_gs(params.sctype_custom_markers, ctx)
+
+    # Use sc-type database
+    tissue = params.sctype_tissue
+    if not tissue:
+        raise ParameterError("sctype_tissue is required when not using custom markers")
+
+    robjects, _, _, _, _, default_converter, openrlib, _ = validate_r_environment(ctx)
+    from rpy2.robjects import conversion
+
+    db_path = (
+        params.sctype_db_
+        or "https://raw.githubusercontent.com/IanevskiAleksandr/sc-type/master/ScTypeDB_full.xlsx"
+    )
+
+    with openrlib.rlock:
+        with conversion.localconverter(default_converter):
+            robjects.r.assign("db_path", db_path)
+            robjects.r.assign("tissue_type", tissue)
+            robjects.r("gs_list <- gene_sets_prepare(db_path, tissue_type)")
+            return robjects.r["gs_list"]
+
+
+def _convert_custom_markers_to_gs(
+    custom_markers: dict[str, dict[str, list[str]]], ctx: "ToolContext"
+):
+    """Convert custom markers to sc-type gene set format"""
+    if not custom_markers:
+        raise DataError("Custom markers dictionary is empty")
+
+    gs_positive = {}
+    gs_negative = {}
+
+    valid_celltypes = 0
+
+    for cell_type, markers in custom_markers.items():
+        if not isinstance(markers, dict):
+            continue
+
+        positive_genes = []
+        negative_genes = []
+
+        if "positive" in markers and isinstance(markers["positive"], list):
+            positive_genes = [
+                str(g).strip().upper()
+                for g in markers["positive"]
+                if g and str(g).strip()
+            ]
+
+        if "negative" in markers and isinstance(markers["negative"], list):
+            negative_genes = [
+                str(g).strip().upper()
+                for g in markers["negative"]
+                if g and str(g).strip()
+            ]
+
+        # Only include cell types that have at least one positive marker
+        if positive_genes:
+            gs_positive[cell_type] = positive_genes
+            gs_negative[cell_type] = negative_genes  # Can be an empty list
+            valid_celltypes += 1
+
+    if valid_celltypes == 0:
+        raise DataError(
+            "No valid cell types found in custom markers - every cell type needs at least one positive marker"
+        )
+
+    # Get robjects and converters from validation
+    robjects, pandas2ri, _, _, localconverter, default_converter, openrlib, _ = (
+        validate_r_environment(ctx)
+    )
+
+    # Wrap R calls in an explicit conversion context
+    with openrlib.rlock:
+        with localconverter(robjects.default_converter + pandas2ri.converter):
+            # Convert Python dictionaries to R named lists, handling empty lists
+            r_gs_positive = robjects.r["list"](
+                **{
+                    k: robjects.StrVector(v) if v else robjects.StrVector([])
+                    for k, v in gs_positive.items()
+                }
+            )
+            r_gs_negative = robjects.r["list"](
+                **{
+                    k: robjects.StrVector(v) if v else robjects.StrVector([])
+                    for k, v in gs_negative.items()
+                }
+            )
+
+            # Create the final gs_list structure
+            gs_list = robjects.r["list"](
+                gs_positive=r_gs_positive, gs_negative=r_gs_negative
+            )
+
+    return gs_list
+
+
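# Illustrative sketch (not from the package): the expected shape of
# sctype_custom_markers - each cell type maps to a required "positive" and an
# optional "negative" marker list. Gene symbols are upper-cased, and cell
# types without any positive marker are dropped. Names below are made up.
custom_markers = {
    "T cell": {"positive": ["CD3D", "CD3E"], "negative": ["CD19"]},
    "B cell": {"positive": ["CD19", "MS4A1"]},  # negative markers optional
}
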
+def _run_sctype_scoring(
+    adata, gs_list, params: AnnotationParameters, ctx: "ToolContext"
+) -> pd.DataFrame:
+    """Run sc-type scoring algorithm."""
+    robjects, pandas2ri, numpy2ri, _, _, default_converter, openrlib, anndata2ri = (
+        validate_r_environment(ctx)
+    )
+    from rpy2.robjects import conversion
+
+    # Prepare expression data
+    expr_data = (
+        adata.layers["scaled"]
+        if params.sctype_scaled and "scaled" in adata.layers
+        else adata.X
+    )
+
+    with openrlib.rlock:
+        with conversion.localconverter(
+            default_converter
+            + anndata2ri.converter
+            + pandas2ri.converter
+            + numpy2ri.converter
+        ):
+            # Transfer data to R (genes × cells for scType)
+            robjects.r.assign("scdata", expr_data.T)
+            robjects.r.assign("gene_names", list(adata.var_names))
+            robjects.r.assign("cell_names", list(adata.obs_names))
+            robjects.r.assign("gs_list", gs_list)
+
+            # Run scoring using pre-defined R code
+            robjects.r(_R_SCTYPE_SCORING)
+
+            # Get results
+            row_names = list(robjects.r("rownames(es_max)"))
+            col_names = list(robjects.r("colnames(es_max)"))
+            scores_matrix = robjects.r["es_max"]
+
+            # Convert to DataFrame
+            if isinstance(scores_matrix, pd.DataFrame):
+                scores_df = scores_matrix
+                scores_df.index = row_names if row_names else scores_df.index
+                scores_df.columns = col_names if col_names else scores_df.columns
+            else:
+                scores_df = pd.DataFrame(scores_matrix, index=row_names, columns=col_names)
+
+    return scores_df
+
+
+def _softmax(scores_array: np.ndarray) -> np.ndarray:
+    """Compute softmax probabilities from raw scores (numerically stable)."""
+    shifted = scores_array - np.max(scores_array)
+    exp_scores = np.exp(shifted)
+    return exp_scores / np.sum(exp_scores)
+
+
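# Illustrative sketch (not from the package): subtracting the max before
# exponentiating avoids overflow without changing the resulting probabilities.
import numpy as np

scores = np.array([1000.0, 1001.0, 1002.0])
naive = np.exp(scores)                   # overflows to inf
stable = np.exp(scores - scores.max())   # [e^-2, e^-1, e^0]
probs = stable / stable.sum()            # ~[0.090, 0.245, 0.665]
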
+def _assign_sctype_celltypes(
+    scores_df: pd.DataFrame, ctx: "ToolContext"
+) -> tuple[list[str], list[float]]:
+    """Assign cell types based on sc-type scores using softmax confidence."""
+    if scores_df is None or scores_df.empty:
+        raise DataError("Scores DataFrame is empty or None")
+
+    cell_types = []
+    confidence_scores = []
+
+    for col_name in scores_df.columns:
+        cell_scores = scores_df[col_name]
+        max_idx = cell_scores.idxmax()
+        max_score = cell_scores.loc[max_idx]
+
+        if max_score > 0:
+            cell_types.append(str(max_idx))
+            # Softmax over this cell's scores yields a probability-like confidence
+            softmax_probs = _softmax(cell_scores.values)
+            confidence_scores.append(
+                float(softmax_probs[cell_scores.index.get_loc(max_idx)])
+            )
+        else:
+            # A non-positive best score means no type is supported by the markers
+            cell_types.append("Unknown")
+            confidence_scores.append(0.0)
+
+    return cell_types, confidence_scores
+
+
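# Illustrative sketch (not from the package): sc-type scores arrive as a
# cell-types × cells DataFrame; each cell takes its top-scoring type, with
# softmax over its column as the confidence. Values here are made up.
import pandas as pd

scores_df = pd.DataFrame(
    {"cell_1": [2.0, 0.5], "cell_2": [-1.0, -0.2]},
    index=["Neuron", "Astrocyte"],
)
scores_df["cell_1"].idxmax()   # -> 'Neuron'
# cell_2's best score is negative, so it would be labelled 'Unknown'
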
+def _calculate_sctype_stats(cell_types: list[str]) -> dict[str, int]:
+    """Calculate cell type counts."""
+    from collections import Counter
+
+    return dict(Counter(cell_types))
+
+
+async def _cache_sctype_results(
+    cache_key: str, results: tuple, ctx: "ToolContext"
+) -> None:
+    """Cache sc-type results to disk as JSON (secure, no pickle)."""
+    try:
+        _SCTYPE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
+        cache_file = _SCTYPE_CACHE_DIR / f"{cache_key}.json"
+
+        # Convert the tuple to a serializable dict; cell_types holds the
+        # per-cell label list so the cached path can restore adata.obs
+        cell_types, counts, confidence_by_celltype, mapping_score = results
+        cache_data = {
+            "cell_types": cell_types,
+            "counts": counts,
+            "confidence_by_celltype": confidence_by_celltype,
+            "mapping_score": mapping_score,
+        }
+
+        with open(cache_file, "w", encoding="utf-8") as f:
+            json.dump(cache_data, f)
+
+        _SCTYPE_CACHE[cache_key] = results
+    except Exception as e:
+        await ctx.warning(f"Failed to cache results: {e}")
+
+
+def _load_cached_sctype_results(cache_key: str, ctx: "ToolContext") -> Optional[tuple]:
+    """Load cached sc-type results from memory or JSON file."""
+    # Check memory cache first
+    if cache_key in _SCTYPE_CACHE:
+        return _SCTYPE_CACHE[cache_key]
+
+    # Check disk cache (JSON)
+    cache_file = _SCTYPE_CACHE_DIR / f"{cache_key}.json"
+    if cache_file.exists():
+        try:
+            with open(cache_file, "r", encoding="utf-8") as f:
+                cache_data = json.load(f)
+
+            results = (
+                cache_data["cell_types"],
+                cache_data["counts"],
+                cache_data["confidence_by_celltype"],
+                cache_data.get("mapping_score"),
+            )
+            _SCTYPE_CACHE[cache_key] = results
+            return results
+        except Exception:
+            # Cache corrupted or incompatible, will recompute
+            pass
+
+    return None
+
+
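# Illustrative sketch (not from the package): the cached payload is plain
# JSON, so a round trip recovers the same structure without pickle's code
# execution risk. Field names mirror the cache schema above; values are toys.
import json

payload = {
    "cell_types": ["Neuron", "Neuron", "Astrocyte"],
    "counts": {"Neuron": 2, "Astrocyte": 1},
    "confidence_by_celltype": {"Neuron": 0.91, "Astrocyte": 0.84},
    "mapping_score": None,
}
restored = json.loads(json.dumps(payload))
assert restored == payload
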
+
1832
+ async def _annotate_with_sctype(
1833
+ adata: sc.AnnData,
1834
+ params: AnnotationParameters,
1835
+ ctx: "ToolContext",
1836
+ output_key: str,
1837
+ confidence_key: str,
1838
+ ) -> AnnotationMethodOutput:
1839
+ """Annotate cell types using sc-type method."""
1840
+ # Validate R environment
1841
+ validate_r_environment(ctx)
1842
+
1843
+ # Validate parameters
1844
+ if not params.sctype_tissue and not params.sctype_custom_markers:
1845
+ raise ParameterError(
1846
+ "Either sctype_tissue or sctype_custom_markers must be specified"
1847
+ )
1848
+
1849
+ if params.sctype_tissue and params.sctype_tissue not in SCTYPE_VALID_TISSUES:
1850
+ raise ParameterError(
1851
+ f"Tissue '{params.sctype_tissue}' not supported. "
1852
+ f"Valid: {', '.join(sorted(SCTYPE_VALID_TISSUES))}"
1853
+ )
1854
+
1855
+ # Check cache
1856
+ cache_key = None
1857
+ if params.sctype_use_cache:
1858
+ cache_key = _get_sctype_cache_key(adata, params)
1859
+ cached = _load_cached_sctype_results(cache_key, ctx)
1860
+ if cached:
1861
+ # Convert cached tuple to AnnotationMethodOutput
1862
+ cell_types, counts, confidence, _ = cached
1863
+ # Still need to store in adata.obs when using cache
1864
+ adata.obs[output_key] = pd.Categorical(cell_types)
1865
+ return AnnotationMethodOutput(
1866
+ cell_types=cell_types,
1867
+ counts=counts,
1868
+ confidence=confidence,
1869
+ )
1870
+
1871
+ # Run sc-type pipeline
1872
+ _load_sctype_functions(ctx)
1873
+ gs_list = _prepare_sctype_genesets(params, ctx)
1874
+ scores_df = _run_sctype_scoring(adata, gs_list, params, ctx)
1875
+ per_cell_types, per_cell_confidence = _assign_sctype_celltypes(scores_df, ctx)
1876
+
1877
+ # Calculate statistics
1878
+ counts = _calculate_sctype_stats(per_cell_types)
1879
+
1880
+ # Average confidence per cell type (for return value)
1881
+ confidence_by_celltype = {}
1882
+ for ct in set(per_cell_types):
1883
+ ct_confs = [
1884
+ c for i, c in enumerate(per_cell_confidence) if per_cell_types[i] == ct
1885
+ ]
1886
+ confidence_by_celltype[ct] = sum(ct_confs) / len(ct_confs) if ct_confs else 0.0
1887
+
1888
+ # Store in adata.obs (keys provided by caller)
1889
+ adata.obs[output_key] = pd.Categorical(per_cell_types)
1890
+ adata.obs[confidence_key] = per_cell_confidence
1891
+
1892
+ unique_cell_types = list(set(per_cell_types))
1893
+
1894
+ # Cache results (as tuple for compatibility)
1895
+ if params.sctype_use_cache and cache_key:
1896
+ cache_tuple = (unique_cell_types, counts, confidence_by_celltype, None)
1897
+ await _cache_sctype_results(cache_key, cache_tuple, ctx)
1898
+
1899
+ return AnnotationMethodOutput(
1900
+ cell_types=unique_cell_types,
1901
+ counts=counts,
1902
+ confidence=confidence_by_celltype,
1903
+ )
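
# Illustrative sketch (not from the package): the per-cell-type mean
# confidence computed in the loop above could equivalently be written with a
# pandas groupby, which avoids rescanning the full list for every type. Toy
# data shown; column names are assumptions for the demonstration.
import pandas as pd

df = pd.DataFrame(
    {"cell_type": ["Neuron", "Neuron", "Astrocyte"], "conf": [0.5, 0.7, 0.8]}
)
confidence_by_celltype = df.groupby("cell_type")["conf"].mean().to_dict()
# -> {'Astrocyte': 0.8, 'Neuron': 0.6}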