chatspatial 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatspatial/__init__.py +11 -0
- chatspatial/__main__.py +141 -0
- chatspatial/cli/__init__.py +7 -0
- chatspatial/config.py +53 -0
- chatspatial/models/__init__.py +85 -0
- chatspatial/models/analysis.py +513 -0
- chatspatial/models/data.py +2462 -0
- chatspatial/server.py +1763 -0
- chatspatial/spatial_mcp_adapter.py +720 -0
- chatspatial/tools/__init__.py +3 -0
- chatspatial/tools/annotation.py +1903 -0
- chatspatial/tools/cell_communication.py +1603 -0
- chatspatial/tools/cnv_analysis.py +605 -0
- chatspatial/tools/condition_comparison.py +595 -0
- chatspatial/tools/deconvolution/__init__.py +402 -0
- chatspatial/tools/deconvolution/base.py +318 -0
- chatspatial/tools/deconvolution/card.py +244 -0
- chatspatial/tools/deconvolution/cell2location.py +326 -0
- chatspatial/tools/deconvolution/destvi.py +144 -0
- chatspatial/tools/deconvolution/flashdeconv.py +101 -0
- chatspatial/tools/deconvolution/rctd.py +317 -0
- chatspatial/tools/deconvolution/spotlight.py +216 -0
- chatspatial/tools/deconvolution/stereoscope.py +109 -0
- chatspatial/tools/deconvolution/tangram.py +135 -0
- chatspatial/tools/differential.py +625 -0
- chatspatial/tools/embeddings.py +298 -0
- chatspatial/tools/enrichment.py +1863 -0
- chatspatial/tools/integration.py +807 -0
- chatspatial/tools/preprocessing.py +723 -0
- chatspatial/tools/spatial_domains.py +808 -0
- chatspatial/tools/spatial_genes.py +836 -0
- chatspatial/tools/spatial_registration.py +441 -0
- chatspatial/tools/spatial_statistics.py +1476 -0
- chatspatial/tools/trajectory.py +495 -0
- chatspatial/tools/velocity.py +405 -0
- chatspatial/tools/visualization/__init__.py +155 -0
- chatspatial/tools/visualization/basic.py +393 -0
- chatspatial/tools/visualization/cell_comm.py +699 -0
- chatspatial/tools/visualization/cnv.py +320 -0
- chatspatial/tools/visualization/core.py +684 -0
- chatspatial/tools/visualization/deconvolution.py +852 -0
- chatspatial/tools/visualization/enrichment.py +660 -0
- chatspatial/tools/visualization/integration.py +205 -0
- chatspatial/tools/visualization/main.py +164 -0
- chatspatial/tools/visualization/multi_gene.py +739 -0
- chatspatial/tools/visualization/persistence.py +335 -0
- chatspatial/tools/visualization/spatial_stats.py +469 -0
- chatspatial/tools/visualization/trajectory.py +639 -0
- chatspatial/tools/visualization/velocity.py +411 -0
- chatspatial/utils/__init__.py +115 -0
- chatspatial/utils/adata_utils.py +1372 -0
- chatspatial/utils/compute.py +327 -0
- chatspatial/utils/data_loader.py +499 -0
- chatspatial/utils/dependency_manager.py +462 -0
- chatspatial/utils/device_utils.py +165 -0
- chatspatial/utils/exceptions.py +185 -0
- chatspatial/utils/image_utils.py +267 -0
- chatspatial/utils/mcp_utils.py +137 -0
- chatspatial/utils/path_utils.py +243 -0
- chatspatial/utils/persistence.py +78 -0
- chatspatial/utils/scipy_compat.py +143 -0
- chatspatial-1.1.0.dist-info/METADATA +242 -0
- chatspatial-1.1.0.dist-info/RECORD +67 -0
- chatspatial-1.1.0.dist-info/WHEEL +5 -0
- chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
- chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
- chatspatial-1.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,462 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Unified dependency management for ChatSpatial MCP.
|
|
3
|
+
|
|
4
|
+
Provides a consistent API for managing optional dependencies, replacing
|
|
5
|
+
scattered try/except ImportError patterns with centralized handling.
|
|
6
|
+
|
|
7
|
+
Usage:
|
|
8
|
+
# Require a dependency (raises if missing)
|
|
9
|
+
scvi = require("scvi-tools", feature="cell type annotation")
|
|
10
|
+
|
|
11
|
+
# Get optional dependency (returns None if missing)
|
|
12
|
+
torch = get("torch")
|
|
13
|
+
|
|
14
|
+
# Check availability
|
|
15
|
+
if is_available("rpy2"):
|
|
16
|
+
import rpy2
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
import importlib
|
|
20
|
+
import importlib.util
|
|
21
|
+
import warnings
|
|
22
|
+
from dataclasses import dataclass
|
|
23
|
+
from functools import lru_cache
|
|
24
|
+
from typing import TYPE_CHECKING, Any, Optional
|
|
25
|
+
|
|
26
|
+
if TYPE_CHECKING:
|
|
27
|
+
from ..spatial_mcp_adapter import ToolContext
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
@dataclass(frozen=True)
class DependencyInfo:
    """Immutable record describing one optional third-party dependency.

    Fields are positional in the order module_name, install_cmd,
    description; only description has a default (empty string).
    """

    # Importable module name, as passed to import_module / find_spec.
    module_name: str
    # Installation command shown to users in error and warning messages.
    install_cmd: str
    # Short human-readable summary used in error messages.
    description: str = ""
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# Optional dependencies keyed by the public name used with require()/get()/
# is_available(). Each value records the importable module, how to install
# it, and a one-line description used in error messages.
DEPENDENCY_REGISTRY: dict[str, DependencyInfo] = {
    # --- Deep learning -------------------------------------------------------
    "scvi-tools": DependencyInfo(
        module_name="scvi",
        install_cmd="pip install scvi-tools",
        description="Single-cell variational inference tools",
    ),
    "torch": DependencyInfo(
        module_name="torch",
        install_cmd="pip install torch",
        description="PyTorch deep learning framework",
    ),
    "cell2location": DependencyInfo(
        module_name="cell2location",
        install_cmd="pip install cell2location",
        description="Probabilistic cell type deconvolution",
    ),
    "flashdeconv": DependencyInfo(
        module_name="flashdeconv",
        install_cmd="pip install flashdeconv",
        description="Ultra-fast spatial deconvolution",
    ),
    # --- Spatial analysis ----------------------------------------------------
    "tangram": DependencyInfo(
        module_name="tangram",
        install_cmd="pip install tangram-sc",
        description="Spatial mapping of single-cell transcriptomics",
    ),
    "squidpy": DependencyInfo(
        module_name="squidpy",
        install_cmd="pip install squidpy",
        description="Spatial single-cell analysis",
    ),
    "SpaGCN": DependencyInfo(
        module_name="SpaGCN",
        install_cmd="pip install SpaGCN",
        description="Spatial domain identification using graph convolutional networks",
    ),
    "STAGATE": DependencyInfo(
        module_name="STAGATE_pyG",
        install_cmd="pip install STAGATE-pyG",
        description="Spatial domain identification using graph attention",
    ),
    "GraphST": DependencyInfo(
        module_name="GraphST",
        install_cmd="pip install GraphST",
        description="Graph self-supervised contrastive learning for spatial domains",
    ),
    "paste": DependencyInfo(
        module_name="paste",
        install_cmd="pip install paste-bio",
        description="Probabilistic alignment of spatial transcriptomics",
    ),
    "stalign": DependencyInfo(
        module_name="STalign",
        install_cmd="pip install STalign",
        description="Spatial transcriptomics alignment",
    ),
    # --- R interface ---------------------------------------------------------
    "rpy2": DependencyInfo(
        module_name="rpy2",
        install_cmd="pip install rpy2",
        description="R-Python interface (requires R installation)",
    ),
    "anndata2ri": DependencyInfo(
        module_name="anndata2ri",
        install_cmd="pip install anndata2ri",
        description="AnnData to R SingleCellExperiment conversion",
    ),
    # --- Cell communication --------------------------------------------------
    "liana": DependencyInfo(
        module_name="liana",
        install_cmd="pip install liana",
        description="Ligand-receptor analysis framework",
    ),
    "cellphonedb": DependencyInfo(
        module_name="cellphonedb",
        install_cmd="pip install cellphonedb",
        description="Statistical method for cell-cell communication",
    ),
    "ktplotspy": DependencyInfo(
        module_name="ktplotspy",
        install_cmd="pip install ktplotspy",
        description="CellPhoneDB visualization toolkit",
    ),
    # --- RNA velocity --------------------------------------------------------
    "scvelo": DependencyInfo(
        module_name="scvelo",
        install_cmd="pip install scvelo",
        description="RNA velocity analysis",
    ),
    "velovi": DependencyInfo(
        module_name="velovi",
        install_cmd="pip install velovi",
        description="Variational inference for RNA velocity",
    ),
    "cellrank": DependencyInfo(
        module_name="cellrank",
        install_cmd="pip install cellrank",
        description="Trajectory inference using RNA velocity",
    ),
    "palantir": DependencyInfo(
        module_name="palantir",
        install_cmd="pip install palantir",
        description="Trajectory inference for cell fate",
    ),
    # --- Annotation ----------------------------------------------------------
    "singler": DependencyInfo(
        module_name="singler",
        install_cmd="pip install singler singlecellexperiment",
        description="Reference-based cell type annotation",
    ),
    "mllmcelltype": DependencyInfo(
        module_name="mllmcelltype",
        install_cmd="pip install mllmcelltype",
        description="LLM-based cell type annotation",
    ),
    "celldex": DependencyInfo(
        module_name="celldex",
        install_cmd="pip install celldex",
        description="Cell type reference datasets for SingleR",
    ),
    # --- Enrichment ----------------------------------------------------------
    "gseapy": DependencyInfo(
        module_name="gseapy",
        install_cmd="pip install gseapy",
        description="Gene set enrichment analysis",
    ),
    "decoupler": DependencyInfo(
        module_name="decoupler",
        install_cmd="pip install decoupler",
        description="Functional analysis of omics data",
    ),
    # --- Spatial statistics --------------------------------------------------
    "sparkx": DependencyInfo(
        module_name="sparkx",
        install_cmd="pip install SPARK-X",
        description="SPARK-X non-parametric spatial gene detection",
    ),
    # NOTE(review): availability is probed via the NaiveDE module rather than
    # SpatialDE itself, so require("spatialde") returns NaiveDE — confirm
    # this is intended.
    "spatialde": DependencyInfo(
        module_name="NaiveDE",
        install_cmd="pip install SpatialDE",
        description="SpatialDE Gaussian process spatial gene detection",
    ),
    # --- CNV -----------------------------------------------------------------
    "infercnvpy": DependencyInfo(
        module_name="infercnvpy",
        install_cmd="pip install infercnvpy",
        description="Copy number variation inference",
    ),
    # --- Visualization -------------------------------------------------------
    "plotly": DependencyInfo(
        module_name="plotly",
        install_cmd="pip install plotly",
        description="Interactive visualization",
    ),
    "adjustText": DependencyInfo(
        module_name="adjustText",
        install_cmd="pip install adjustText",
        description="Text label placement for matplotlib",
    ),
    "splot": DependencyInfo(
        module_name="splot",
        install_cmd="pip install splot",
        description="Spatial plotting for PySAL",
    ),
    # --- Data handling -------------------------------------------------------
    "mudata": DependencyInfo(
        module_name="mudata",
        install_cmd="pip install mudata",
        description="Multimodal data handling",
    ),
    # --- Integration ---------------------------------------------------------
    "harmonypy": DependencyInfo(
        module_name="harmonypy",
        install_cmd="pip install harmonypy",
        description="Harmony batch integration",
    ),
    "scanorama": DependencyInfo(
        module_name="scanorama",
        install_cmd="pip install scanorama",
        description="Scanorama batch integration",
    ),
    "bbknn": DependencyInfo(
        module_name="bbknn",
        install_cmd="pip install bbknn",
        description="Batch balanced k-nearest neighbors",
    ),
    # --- Spatial weights -----------------------------------------------------
    "esda": DependencyInfo(
        module_name="esda",
        install_cmd="pip install esda",
        description="Exploratory spatial data analysis",
    ),
    "libpysal": DependencyInfo(
        module_name="libpysal",
        install_cmd="pip install libpysal",
        description="Python spatial analysis library",
    ),
    # --- Other ---------------------------------------------------------------
    "dask": DependencyInfo(
        module_name="dask",
        install_cmd="pip install dask",
        description="Parallel computing library",
    ),
    "ot": DependencyInfo(
        module_name="ot",
        install_cmd="pip install POT",
        description="Python Optimal Transport library",
    ),
    "louvain": DependencyInfo(
        module_name="louvain",
        install_cmd="pip install louvain",
        description="Louvain community detection algorithm",
    ),
    "pydeseq2": DependencyInfo(
        module_name="pydeseq2",
        install_cmd="pip install pydeseq2",
        description="Python implementation of DESeq2",
    ),
    "enrichmap": DependencyInfo(
        module_name="enrichmap",
        install_cmd="pip install enrichmap",
        description="Spatial enrichment mapping",
    ),
    "pygam": DependencyInfo(
        module_name="pygam",
        install_cmd="pip install pygam",
        description="Generalized additive models",
    ),
    "skgstat": DependencyInfo(
        module_name="skgstat",
        install_cmd="pip install scikit-gstat",
        description="Geostatistical analysis toolkit",
    ),
    "sklearn": DependencyInfo(
        module_name="sklearn",
        install_cmd="pip install scikit-learn",
        description="Machine learning library",
    ),
    "statsmodels": DependencyInfo(
        module_name="statsmodels",
        install_cmd="pip install statsmodels",
        description="Statistical models and tests",
    ),
    "scipy": DependencyInfo(
        module_name="scipy",
        install_cmd="pip install scipy",
        description="Scientific computing library",
    ),
    "scanpy": DependencyInfo(
        module_name="scanpy",
        install_cmd="pip install scanpy",
        description="Single-cell analysis in Python",
    ),
    "Pillow": DependencyInfo(
        module_name="PIL",
        install_cmd="pip install Pillow",
        description="Python Imaging Library",
    ),
}
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
# =============================================================================
|
|
216
|
+
# Core Functions (using @lru_cache for thread-safe caching)
|
|
217
|
+
# =============================================================================
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def _get_info(name: str) -> DependencyInfo:
    """Resolve ``name`` to a DependencyInfo record.

    Lookup order:
      1. ``name`` as a DEPENDENCY_REGISTRY key;
      2. a registry entry whose ``module_name`` equals ``name``;
      3. a synthetic record that assumes the pip package and the importable
         module share the same name.
    """
    by_key = DEPENDENCY_REGISTRY.get(name)
    if by_key is not None:
        return by_key

    by_module = next(
        (entry for entry in DEPENDENCY_REGISTRY.values() if entry.module_name == name),
        None,
    )
    if by_module is not None:
        return by_module

    # Unknown dependency: fall back to a best-guess install command.
    return DependencyInfo(name, f"pip install {name}", f"Optional: {name}")
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
@lru_cache(maxsize=256)
|
|
233
|
+
def _try_import(module_name: str) -> Optional[Any]:
|
|
234
|
+
"""Import module with caching. Returns None if unavailable."""
|
|
235
|
+
try:
|
|
236
|
+
return importlib.import_module(module_name)
|
|
237
|
+
except ImportError:
|
|
238
|
+
return None
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
@lru_cache(maxsize=256)
|
|
242
|
+
def _check_spec(module_name: str) -> bool:
|
|
243
|
+
"""Fast availability check without importing."""
|
|
244
|
+
return importlib.util.find_spec(module_name) is not None
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
# =============================================================================
|
|
248
|
+
# Public API
|
|
249
|
+
# =============================================================================
|
|
250
|
+
|
|
251
|
+
|
|
252
|
+
def is_available(name: str) -> bool:
    """Return True if dependency ``name`` can be found (no import performed).

    ``name`` may be a registry key or a module name; it is resolved through
    the registry before the spec lookup.
    """
    module_name = _get_info(name).module_name
    return _check_spec(module_name)
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
def get(
    name: str,
    ctx: Optional["ToolContext"] = None,
    warn_if_missing: bool = False,
) -> Optional[Any]:
    """Import an optional dependency, returning None when it is unavailable.

    Args:
        name: Registry key or module name of the dependency.
        ctx: Accepted for API symmetry with the other helpers; not used.
        warn_if_missing: If True and the import fails, emit a Python warning
            (``warnings.warn``) that includes the install command.

    Returns:
        The imported module, or None.
    """
    dep = _get_info(name)
    mod = _try_import(dep.module_name)
    if mod is None and warn_if_missing:
        warnings.warn(
            f"{name} not available. Install: {dep.install_cmd}", stacklevel=2
        )
    return mod
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
def require(
    name: str,
    ctx: Optional["ToolContext"] = None,
    feature: Optional[str] = None,
) -> Any:
    """Import a dependency that the caller cannot work without.

    Args:
        name: Registry key or module name of the dependency.
        ctx: Accepted for API symmetry with the other helpers; not used.
        feature: Optional description of what needs the dependency; it is
            included in the error message.

    Returns:
        The imported module.

    Raises:
        ImportError: If the module cannot be imported. The message contains
            the install command and the registry description.
    """
    dep = _get_info(name)
    loaded = _try_import(dep.module_name)
    if loaded is None:
        purpose = "" if not feature else f" for {feature}"
        raise ImportError(
            f"{name} is required{purpose}.\n\n"
            f"Install: {dep.install_cmd}\n"
            f"Description: {dep.description}"
        )
    return loaded
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
# =============================================================================
|
|
297
|
+
# R Environment Validation
|
|
298
|
+
# =============================================================================
|
|
299
|
+
|
|
300
|
+
|
|
301
|
+
def validate_r_environment(
    ctx: Optional["ToolContext"] = None,
    required_packages: Optional[list[str]] = None,
) -> tuple[Any, ...]:
    """Confirm rpy2, anndata2ri and a working R runtime, then hand back handles.

    Each R call is made while holding ``openrlib.rlock`` and inside a
    ``localconverter(default_converter)`` context.

    Args:
        ctx: Accepted for API symmetry with the other helpers; not used.
        required_packages: R package names that must load via ``importr``.

    Returns:
        Tuple of (robjects, pandas2ri, numpy2ri, importr, localconverter,
        default_converter, openrlib, anndata2ri)

    Raises:
        ImportError: If rpy2 or anndata2ri is missing, if any required R
            package fails to load, or if setting up R fails for any other
            reason. In that last case the original error is chained as
            ``__cause__``.
    """
    if not is_available("rpy2"):
        raise ImportError(
            "rpy2 is required for R-based methods. "
            "Install: pip install rpy2 (requires R installation)"
        )
    if not is_available("anndata2ri"):
        raise ImportError(
            "anndata2ri is required for R-based methods. "
            "Install: pip install anndata2ri"
        )

    try:
        import anndata2ri
        import rpy2.robjects as robjects
        from rpy2.rinterface_lib import openrlib
        from rpy2.robjects import conversion, default_converter, numpy2ri, pandas2ri
        from rpy2.robjects.conversion import localconverter
        from rpy2.robjects.packages import importr

        # Smoke-test the R runtime itself before probing any packages.
        with openrlib.rlock, conversion.localconverter(default_converter):
            robjects.r("R.version")

        unresolved = []
        for pkg in required_packages or ():
            try:
                with openrlib.rlock, conversion.localconverter(default_converter):
                    importr(pkg)
            except Exception:
                unresolved.append(pkg)

        if unresolved:
            quoted = ", ".join(f"'{p}'" for p in unresolved)
            raise ImportError(
                f"Missing R packages: {quoted}\n"
                f"Install in R: install.packages(c({quoted}))"
            )

        return (
            robjects,
            pandas2ri,
            numpy2ri,
            importr,
            localconverter,
            default_converter,
            openrlib,
            anndata2ri,
        )
    except Exception as e:
        # ImportErrors (missing modules or R packages) already carry a
        # specific message; anything else is wrapped with setup hints.
        if isinstance(e, ImportError):
            raise
        raise ImportError(
            f"R environment setup failed: {e}\n\n"
            "Solutions:\n"
            "  - Install R: https://www.r-project.org/\n"
            "  - Set R_HOME environment variable\n"
            "  - macOS: brew install r\n"
            "  - Ubuntu: sudo apt install r-base"
        ) from e
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
def validate_r_package(
    package_name: str,
    ctx: Optional["ToolContext"] = None,
    install_cmd: Optional[str] = None,
) -> bool:
    """Ensure a single R package can be loaded through rpy2's ``importr``.

    Success returns True. This function never returns False: every failure
    is reported by raising.

    Args:
        package_name: Name of the R package to load.
        ctx: Accepted for API symmetry with the other helpers; not used.
        install_cmd: Custom install hint for the error message. When omitted,
            ``install.packages('<package_name>')`` is suggested.

    Raises:
        ImportError: If rpy2 is not installed, or if importing rpy2's R
            bindings or loading the package fails for any reason. The
            original error is chained as ``__cause__``.
    """
    if not is_available("rpy2"):
        raise ImportError(
            "rpy2 is required for R-based methods.\n"
            "Install: pip install rpy2 (requires R)"
        )

    try:
        from rpy2.rinterface_lib import openrlib
        from rpy2.robjects import conversion, default_converter
        from rpy2.robjects.packages import importr

        with openrlib.rlock, conversion.localconverter(default_converter):
            importr(package_name)
    except Exception as err:
        hint = install_cmd if install_cmd else f"install.packages('{package_name}')"
        raise ImportError(
            f"R package '{package_name}' not installed.\n"
            f"Install in R: {hint}"
        ) from err

    return True
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
def check_r_packages(
    packages: list[str],
    ctx: Optional["ToolContext"] = None,
) -> list[str]:
    """Return the subset of ``packages`` that cannot be loaded in R.

    Unlike ``validate_r_package``, this never raises for a missing package.
    Failures are collected and returned in input order. When rpy2 itself is
    unavailable, every requested package is reported as missing.

    Args:
        packages: R package names to probe.
        ctx: Accepted for API symmetry with the other helpers; not used.

    Returns:
        A new list of the package names that failed to load (empty if all
        loaded). The caller's input list is never returned directly.
    """
    if not is_available("rpy2"):
        # Copy, so callers that mutate the result do not alter their input.
        # This also matches the fresh list returned on the path below.
        return list(packages)

    missing = []
    for pkg in packages:
        try:
            validate_r_package(pkg)
        except ImportError:
            missing.append(pkg)

    return missing
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
# ChatSpatial component names that do not map 1:1 onto a single scvi-tools
# export. scvi-tools exposes Stereoscope as two models (RNAStereoscope and
# SpatialStereoscope), so both must resolve for the component to count.
_SCVI_COMPONENT_ALIASES: dict[str, tuple[str, ...]] = {
    "Stereoscope": ("RNAStereoscope", "SpatialStereoscope"),
}


def _scvi_attr_exists(scvi: Any, attr: str) -> bool:
    """Return True if ``attr`` is exposed by scvi, scvi.model or scvi.external."""
    for suffix in ("", ".model", ".external"):
        try:
            namespace = (
                importlib.import_module(scvi.__name__ + suffix) if suffix else scvi
            )
            getattr(namespace, attr)
        except (ImportError, AttributeError):
            continue
        return True
    return False


def _scvi_component_available(scvi: Any, comp: str) -> bool:
    """Return True if the named scvi-tools component can be resolved."""
    if comp == "Cell2location":
        # cell2location ships as its own package rather than inside scvi.
        try:
            import cell2location  # noqa: F401
        except (ImportError, AttributeError):
            return False
        return True

    # The literal name is searched in scvi, scvi.model and scvi.external.
    # This covers CellAssign (external), SCANVI and DestVI (model), and any
    # other class name a caller passes.
    if _scvi_attr_exists(scvi, comp):
        return True

    aliases = _SCVI_COMPONENT_ALIASES.get(comp, ())
    return bool(aliases) and all(_scvi_attr_exists(scvi, a) for a in aliases)


def validate_scvi_tools(
    ctx: Optional["ToolContext"] = None,
    components: Optional[list[str]] = None,
) -> Any:
    """Validate scvi-tools availability and return the ``scvi`` module.

    Args:
        ctx: Passed through to ``require``; otherwise unused.
        components: Optional component names to verify, e.g. "CellAssign",
            "SCANVI", "DestVI", "Stereoscope" or "Cell2location". Any other
            name is looked up as an attribute of scvi, scvi.model or
            scvi.external.

    Returns:
        The imported ``scvi`` module.

    Raises:
        ImportError: If scvi-tools is not installed, or if any requested
            component cannot be resolved. All unresolved names are listed in
            the message.
    """
    scvi = require("scvi-tools", ctx, "scvi-tools methods")

    if components:
        missing = [
            comp for comp in components if not _scvi_component_available(scvi, comp)
        ]
        if missing:
            raise ImportError(
                f"scvi-tools components not available: {', '.join(missing)}\n"
                "Try: pip install --upgrade scvi-tools"
            )

    return scvi
|
|
@@ -0,0 +1,165 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Device utilities for compute backend selection.
|
|
3
|
+
|
|
4
|
+
This module provides lazy-loaded device detection and selection functions
|
|
5
|
+
for GPU/CPU computation. Follows the same design principles as compute.py.
|
|
6
|
+
|
|
7
|
+
Design Principles:
|
|
8
|
+
1. Lazy Loading: torch is only imported when needed
|
|
9
|
+
2. Pure Functions: No side effects, callers decide how to handle results
|
|
10
|
+
3. String Returns: Callers convert to torch.device if needed
|
|
11
|
+
4. Composable: Basic building blocks for various use cases
|
|
12
|
+
|
|
13
|
+
Usage:
|
|
14
|
+
# Simple device selection
|
|
15
|
+
device = get_device(prefer_gpu=True)
|
|
16
|
+
|
|
17
|
+
# With warning when GPU unavailable
|
|
18
|
+
device = get_device(prefer_gpu=params.use_gpu)
|
|
19
|
+
if params.use_gpu and device == "cpu":
|
|
20
|
+
await ctx.warning("GPU requested but not available")
|
|
21
|
+
|
|
22
|
+
# Convert to torch.device when needed
|
|
23
|
+
import torch
|
|
24
|
+
device = torch.device(get_device(prefer_gpu=True, allow_mps=True))
|
|
25
|
+
"""
|
|
26
|
+
|
|
27
|
+
from typing import TYPE_CHECKING, Any
|
|
28
|
+
|
|
29
|
+
if TYPE_CHECKING:
|
|
30
|
+
from ..spatial_mcp_adapter import ToolContext
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# =============================================================================
|
|
34
|
+
# Availability Checks (has_* pattern)
|
|
35
|
+
# =============================================================================
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def cuda_available() -> bool:
    """Report whether PyTorch can see a CUDA device.

    torch is imported inside the function, so it is loaded only when a
    caller actually asks about devices.

    Returns:
        The result of ``torch.cuda.is_available()``, or False when torch
        cannot be imported.
    """
    try:
        import torch

        has_cuda = torch.cuda.is_available()
    except ImportError:
        has_cuda = False
    return has_cuda
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def mps_available() -> bool:
    """Report whether PyTorch's Apple-Silicon MPS backend is usable.

    torch is imported inside the function, so it is loaded only when a
    caller actually asks about devices. Older torch builds without a
    ``torch.backends.mps`` attribute are treated as "not available".

    Returns:
        True if MPS is available, False otherwise (including when torch is
        not installed).
    """
    try:
        import torch

        mps_backend = getattr(torch.backends, "mps", None)
        usable = mps_backend is not None and mps_backend.is_available()
    except ImportError:
        usable = False
    return usable
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
# =============================================================================
|
|
71
|
+
# Device Selection (core function)
|
|
72
|
+
# =============================================================================
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def get_device(
    prefer_gpu: bool = False,
    allow_mps: bool = False,
) -> str:
    """Select compute device based on preference and availability.

    This is THE single source of truth for device selection across ChatSpatial.
    Returns a device string that can be used directly or converted to torch.device.

    Args:
        prefer_gpu: If True, try to use GPU (CUDA first, then MPS if allowed).
            If False, "cpu" is returned without probing any GPU backend.
        allow_mps: If True, allow MPS as fallback when CUDA unavailable.
            Has no effect when ``prefer_gpu`` is False.

    Returns:
        Device string: "cuda:0", "mps", or "cpu"

    Examples:
        # Basic usage
        device = get_device(prefer_gpu=True)  # "cuda:0" or "cpu"

        # With MPS support (Apple Silicon)
        device = get_device(prefer_gpu=True, allow_mps=True)  # "cuda:0", "mps", or "cpu"

        # Convert to torch.device
        import torch
        device = torch.device(get_device(prefer_gpu=True))

        # With warning when requested but unavailable
        device = get_device(params.use_gpu)
        if params.use_gpu and device == "cpu":
            await ctx.warning("GPU requested but not available - using CPU")
    """
    if not prefer_gpu:
        return "cpu"
    if cuda_available():
        return "cuda:0"
    if allow_mps and mps_available():
        return "mps"
    # GPU was requested but no usable backend was found.
    return "cpu"
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
# =============================================================================
|
|
116
|
+
# Async Helper with Context Warning
|
|
117
|
+
# =============================================================================
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
async def resolve_device_async(
    prefer_gpu: bool,
    ctx: "ToolContext",
    allow_mps: bool = False,
    warn_on_fallback: bool = True,
) -> str:
    """Pick a device via ``get_device`` and optionally warn on CPU fallback.

    Intended for async tools that want the "GPU requested but unavailable"
    warning emitted automatically through their ToolContext.

    Args:
        prefer_gpu: If True, try to use GPU.
        ctx: ToolContext whose ``warning`` coroutine is awaited on fallback.
        allow_mps: If True, allow MPS as fallback.
        warn_on_fallback: If True, warn when a requested GPU is unavailable.

    Returns:
        Device string: "cuda:0", "mps", or "cpu"
    """
    selected = get_device(prefer_gpu=prefer_gpu, allow_mps=allow_mps)

    fell_back_to_cpu = prefer_gpu and selected == "cpu"
    if fell_back_to_cpu and warn_on_fallback:
        await ctx.warning("GPU requested but not available - using CPU")

    return selected
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
# =============================================================================
|
|
148
|
+
# Specialized Backend Functions
|
|
149
|
+
# =============================================================================
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def get_ot_backend(use_gpu: bool = False) -> Any:
    """Return a fresh POT (Python Optimal Transport) backend for PASTE alignment.

    Args:
        use_gpu: If True and CUDA is available, use ``ot.backend.TorchBackend``.

    Returns:
        A new ``TorchBackend`` instance when GPU use is requested and CUDA is
        available, otherwise a new ``NumpyBackend`` instance.
    """
    import ot

    # cuda_available() is consulted only when the caller asked for a GPU.
    want_torch = use_gpu and cuda_available()
    backend_cls = ot.backend.TorchBackend if want_torch else ot.backend.NumpyBackend
    return backend_cls()
|