chatspatial-1.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. chatspatial/__init__.py +11 -0
  2. chatspatial/__main__.py +141 -0
  3. chatspatial/cli/__init__.py +7 -0
  4. chatspatial/config.py +53 -0
  5. chatspatial/models/__init__.py +85 -0
  6. chatspatial/models/analysis.py +513 -0
  7. chatspatial/models/data.py +2462 -0
  8. chatspatial/server.py +1763 -0
  9. chatspatial/spatial_mcp_adapter.py +720 -0
  10. chatspatial/tools/__init__.py +3 -0
  11. chatspatial/tools/annotation.py +1903 -0
  12. chatspatial/tools/cell_communication.py +1603 -0
  13. chatspatial/tools/cnv_analysis.py +605 -0
  14. chatspatial/tools/condition_comparison.py +595 -0
  15. chatspatial/tools/deconvolution/__init__.py +402 -0
  16. chatspatial/tools/deconvolution/base.py +318 -0
  17. chatspatial/tools/deconvolution/card.py +244 -0
  18. chatspatial/tools/deconvolution/cell2location.py +326 -0
  19. chatspatial/tools/deconvolution/destvi.py +144 -0
  20. chatspatial/tools/deconvolution/flashdeconv.py +101 -0
  21. chatspatial/tools/deconvolution/rctd.py +317 -0
  22. chatspatial/tools/deconvolution/spotlight.py +216 -0
  23. chatspatial/tools/deconvolution/stereoscope.py +109 -0
  24. chatspatial/tools/deconvolution/tangram.py +135 -0
  25. chatspatial/tools/differential.py +625 -0
  26. chatspatial/tools/embeddings.py +298 -0
  27. chatspatial/tools/enrichment.py +1863 -0
  28. chatspatial/tools/integration.py +807 -0
  29. chatspatial/tools/preprocessing.py +723 -0
  30. chatspatial/tools/spatial_domains.py +808 -0
  31. chatspatial/tools/spatial_genes.py +836 -0
  32. chatspatial/tools/spatial_registration.py +441 -0
  33. chatspatial/tools/spatial_statistics.py +1476 -0
  34. chatspatial/tools/trajectory.py +495 -0
  35. chatspatial/tools/velocity.py +405 -0
  36. chatspatial/tools/visualization/__init__.py +155 -0
  37. chatspatial/tools/visualization/basic.py +393 -0
  38. chatspatial/tools/visualization/cell_comm.py +699 -0
  39. chatspatial/tools/visualization/cnv.py +320 -0
  40. chatspatial/tools/visualization/core.py +684 -0
  41. chatspatial/tools/visualization/deconvolution.py +852 -0
  42. chatspatial/tools/visualization/enrichment.py +660 -0
  43. chatspatial/tools/visualization/integration.py +205 -0
  44. chatspatial/tools/visualization/main.py +164 -0
  45. chatspatial/tools/visualization/multi_gene.py +739 -0
  46. chatspatial/tools/visualization/persistence.py +335 -0
  47. chatspatial/tools/visualization/spatial_stats.py +469 -0
  48. chatspatial/tools/visualization/trajectory.py +639 -0
  49. chatspatial/tools/visualization/velocity.py +411 -0
  50. chatspatial/utils/__init__.py +115 -0
  51. chatspatial/utils/adata_utils.py +1372 -0
  52. chatspatial/utils/compute.py +327 -0
  53. chatspatial/utils/data_loader.py +499 -0
  54. chatspatial/utils/dependency_manager.py +462 -0
  55. chatspatial/utils/device_utils.py +165 -0
  56. chatspatial/utils/exceptions.py +185 -0
  57. chatspatial/utils/image_utils.py +267 -0
  58. chatspatial/utils/mcp_utils.py +137 -0
  59. chatspatial/utils/path_utils.py +243 -0
  60. chatspatial/utils/persistence.py +78 -0
  61. chatspatial/utils/scipy_compat.py +143 -0
  62. chatspatial-1.1.0.dist-info/METADATA +242 -0
  63. chatspatial-1.1.0.dist-info/RECORD +67 -0
  64. chatspatial-1.1.0.dist-info/WHEEL +5 -0
  65. chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
  66. chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
  67. chatspatial-1.1.0.dist-info/top_level.txt +1 -0
chatspatial/utils/dependency_manager.py
@@ -0,0 +1,462 @@
+"""
+Unified dependency management for ChatSpatial MCP.
+
+Provides a consistent API for managing optional dependencies, replacing
+scattered try/except ImportError patterns with centralized handling.
+
+Usage:
+    # Require a dependency (raises if missing)
+    scvi = require("scvi-tools", feature="cell type annotation")
+
+    # Get optional dependency (returns None if missing)
+    torch = get("torch")
+
+    # Check availability
+    if is_available("rpy2"):
+        import rpy2
+"""
+
+import importlib
+import importlib.util
+import warnings
+from dataclasses import dataclass
+from functools import lru_cache
+from typing import TYPE_CHECKING, Any, Optional
+
+if TYPE_CHECKING:
+    from ..spatial_mcp_adapter import ToolContext
+
+
+@dataclass(frozen=True)
+class DependencyInfo:
+    """Metadata for an optional dependency."""
+
+    module_name: str
+    install_cmd: str
+    description: str = ""
+
+
+# Registry of optional dependencies with install instructions
+DEPENDENCY_REGISTRY: dict[str, DependencyInfo] = {
+    # Deep Learning
+    "scvi-tools": DependencyInfo(
+        "scvi", "pip install scvi-tools", "Single-cell variational inference tools"
+    ),
+    "torch": DependencyInfo(
+        "torch", "pip install torch", "PyTorch deep learning framework"
+    ),
+    "cell2location": DependencyInfo(
+        "cell2location",
+        "pip install cell2location",
+        "Probabilistic cell type deconvolution",
+    ),
+    "flashdeconv": DependencyInfo(
+        "flashdeconv", "pip install flashdeconv", "Ultra-fast spatial deconvolution"
+    ),
+    # Spatial Analysis
+    "tangram": DependencyInfo(
+        "tangram",
+        "pip install tangram-sc",
+        "Spatial mapping of single-cell transcriptomics",
+    ),
+    "squidpy": DependencyInfo(
+        "squidpy", "pip install squidpy", "Spatial single-cell analysis"
+    ),
+    "SpaGCN": DependencyInfo(
+        "SpaGCN",
+        "pip install SpaGCN",
+        "Spatial domain identification using graph convolutional networks",
+    ),
+    "STAGATE": DependencyInfo(
+        "STAGATE_pyG",
+        "pip install STAGATE-pyG",
+        "Spatial domain identification using graph attention",
+    ),
+    "GraphST": DependencyInfo(
+        "GraphST",
+        "pip install GraphST",
+        "Graph self-supervised contrastive learning for spatial domains",
+    ),
+    "paste": DependencyInfo(
+        "paste",
+        "pip install paste-bio",
+        "Probabilistic alignment of spatial transcriptomics",
+    ),
+    "stalign": DependencyInfo(
+        "STalign", "pip install STalign", "Spatial transcriptomics alignment"
+    ),
+    # R Interface
+    "rpy2": DependencyInfo(
+        "rpy2", "pip install rpy2", "R-Python interface (requires R installation)"
+    ),
+    "anndata2ri": DependencyInfo(
+        "anndata2ri",
+        "pip install anndata2ri",
+        "AnnData to R SingleCellExperiment conversion",
+    ),
+    # Cell Communication
+    "liana": DependencyInfo(
+        "liana", "pip install liana", "Ligand-receptor analysis framework"
+    ),
+    "cellphonedb": DependencyInfo(
+        "cellphonedb",
+        "pip install cellphonedb",
+        "Statistical method for cell-cell communication",
+    ),
+    "ktplotspy": DependencyInfo(
+        "ktplotspy", "pip install ktplotspy", "CellPhoneDB visualization toolkit"
+    ),
+    # RNA Velocity
+    "scvelo": DependencyInfo("scvelo", "pip install scvelo", "RNA velocity analysis"),
+    "velovi": DependencyInfo(
+        "velovi", "pip install velovi", "Variational inference for RNA velocity"
+    ),
+    "cellrank": DependencyInfo(
+        "cellrank", "pip install cellrank", "Trajectory inference using RNA velocity"
+    ),
+    "palantir": DependencyInfo(
+        "palantir", "pip install palantir", "Trajectory inference for cell fate"
+    ),
+    # Annotation
+    "singler": DependencyInfo(
+        "singler",
+        "pip install singler singlecellexperiment",
+        "Reference-based cell type annotation",
+    ),
+    "mllmcelltype": DependencyInfo(
+        "mllmcelltype", "pip install mllmcelltype", "LLM-based cell type annotation"
+    ),
+    "celldex": DependencyInfo(
+        "celldex", "pip install celldex", "Cell type reference datasets for SingleR"
+    ),
+    # Enrichment
+    "gseapy": DependencyInfo(
+        "gseapy", "pip install gseapy", "Gene set enrichment analysis"
+    ),
+    "decoupler": DependencyInfo(
+        "decoupler", "pip install decoupler", "Functional analysis of omics data"
+    ),
+    # Spatial Statistics
+    "sparkx": DependencyInfo(
+        "sparkx", "pip install SPARK-X", "SPARK-X non-parametric spatial gene detection"
+    ),
+    "spatialde": DependencyInfo(
+        "NaiveDE",
+        "pip install SpatialDE",
+        "SpatialDE Gaussian process spatial gene detection",
+    ),
+    # CNV
+    "infercnvpy": DependencyInfo(
+        "infercnvpy", "pip install infercnvpy", "Copy number variation inference"
+    ),
+    # Visualization
+    "plotly": DependencyInfo(
+        "plotly", "pip install plotly", "Interactive visualization"
+    ),
+    "adjustText": DependencyInfo(
+        "adjustText", "pip install adjustText", "Text label placement for matplotlib"
+    ),
+    "splot": DependencyInfo("splot", "pip install splot", "Spatial plotting for PySAL"),
+    # Data handling
+    "mudata": DependencyInfo(
+        "mudata", "pip install mudata", "Multimodal data handling"
+    ),
+    # Integration
+    "harmonypy": DependencyInfo(
+        "harmonypy", "pip install harmonypy", "Harmony batch integration"
+    ),
+    "scanorama": DependencyInfo(
+        "scanorama", "pip install scanorama", "Scanorama batch integration"
+    ),
+    "bbknn": DependencyInfo(
+        "bbknn", "pip install bbknn", "Batch balanced k-nearest neighbors"
+    ),
+    # Spatial weights
+    "esda": DependencyInfo(
+        "esda", "pip install esda", "Exploratory spatial data analysis"
+    ),
+    "libpysal": DependencyInfo(
+        "libpysal", "pip install libpysal", "Python spatial analysis library"
+    ),
+    # Other
+    "dask": DependencyInfo("dask", "pip install dask", "Parallel computing library"),
+    "ot": DependencyInfo("ot", "pip install POT", "Python Optimal Transport library"),
+    "louvain": DependencyInfo(
+        "louvain", "pip install louvain", "Louvain community detection algorithm"
+    ),
+    "pydeseq2": DependencyInfo(
+        "pydeseq2", "pip install pydeseq2", "Python implementation of DESeq2"
+    ),
+    "enrichmap": DependencyInfo(
+        "enrichmap", "pip install enrichmap", "Spatial enrichment mapping"
+    ),
+    "pygam": DependencyInfo(
+        "pygam", "pip install pygam", "Generalized additive models"
+    ),
+    "skgstat": DependencyInfo(
+        "skgstat", "pip install scikit-gstat", "Geostatistical analysis toolkit"
+    ),
+    "sklearn": DependencyInfo(
+        "sklearn", "pip install scikit-learn", "Machine learning library"
+    ),
+    "statsmodels": DependencyInfo(
+        "statsmodels", "pip install statsmodels", "Statistical models and tests"
+    ),
+    "scipy": DependencyInfo(
+        "scipy", "pip install scipy", "Scientific computing library"
+    ),
+    "scanpy": DependencyInfo(
+        "scanpy", "pip install scanpy", "Single-cell analysis in Python"
+    ),
+    "Pillow": DependencyInfo("PIL", "pip install Pillow", "Python Imaging Library"),
+}
+
+
+# =============================================================================
+# Core Functions (using @lru_cache for thread-safe caching)
+# =============================================================================
+
+
+def _get_info(name: str) -> DependencyInfo:
+    """Get dependency info, creating a default if not in the registry."""
+    if name in DEPENDENCY_REGISTRY:
+        return DEPENDENCY_REGISTRY[name]
+    # Check by module name
+    for info in DEPENDENCY_REGISTRY.values():
+        if info.module_name == name:
+            return info
+    # Default for unknown dependencies
+    return DependencyInfo(name, f"pip install {name}", f"Optional: {name}")
+
+
+@lru_cache(maxsize=256)
+def _try_import(module_name: str) -> Optional[Any]:
+    """Import module with caching. Returns None if unavailable."""
+    try:
+        return importlib.import_module(module_name)
+    except ImportError:
+        return None
+
+
+@lru_cache(maxsize=256)
+def _check_spec(module_name: str) -> bool:
+    """Fast availability check without importing."""
+    return importlib.util.find_spec(module_name) is not None
+
+
+# =============================================================================
+# Public API
+# =============================================================================
+
+
+def is_available(name: str) -> bool:
+    """Check if a dependency is available (fast, no import)."""
+    return _check_spec(_get_info(name).module_name)
+
+
+def get(
+    name: str,
+    ctx: Optional["ToolContext"] = None,
+    warn_if_missing: bool = False,
+) -> Optional[Any]:
+    """Get an optional dependency, returning None if unavailable."""
+    info = _get_info(name)
+    module = _try_import(info.module_name)
+
+    if module is not None:
+        return module
+
+    if warn_if_missing:
+        msg = f"{name} not available. Install: {info.install_cmd}"
+        warnings.warn(msg, stacklevel=2)
+
+    return None
+
+
+def require(
+    name: str,
+    ctx: Optional["ToolContext"] = None,
+    feature: Optional[str] = None,
+) -> Any:
+    """Require a dependency, raising ImportError if unavailable."""
+    info = _get_info(name)
+    module = _try_import(info.module_name)
+
+    if module is not None:
+        return module
+
+    feature_msg = f" for {feature}" if feature else ""
+    raise ImportError(
+        f"{name} is required{feature_msg}.\n\n"
+        f"Install: {info.install_cmd}\n"
+        f"Description: {info.description}"
+    )
+
+
+# =============================================================================
+# R Environment Validation
+# =============================================================================
+
+
+def validate_r_environment(
+    ctx: Optional["ToolContext"] = None,
+    required_packages: Optional[list[str]] = None,
+) -> tuple[Any, ...]:
+    """Validate the R environment and return required modules.
+
+    Returns:
+        Tuple of (robjects, pandas2ri, numpy2ri, importr, localconverter,
+        default_converter, openrlib, anndata2ri)
+
+    Raises:
+        ImportError: If rpy2 or required R packages are not available
+    """
+    if not is_available("rpy2"):
+        raise ImportError(
+            "rpy2 is required for R-based methods. "
+            "Install: pip install rpy2 (requires R installation)"
+        )
+    if not is_available("anndata2ri"):
+        raise ImportError(
+            "anndata2ri is required for R-based methods. "
+            "Install: pip install anndata2ri"
+        )
+
+    try:
+        import anndata2ri
+        import rpy2.robjects as robjects
+        from rpy2.rinterface_lib import openrlib
+        from rpy2.robjects import conversion, default_converter, numpy2ri, pandas2ri
+        from rpy2.robjects.conversion import localconverter
+        from rpy2.robjects.packages import importr
+
+        # Test R availability
+        with openrlib.rlock:
+            with conversion.localconverter(default_converter):
+                robjects.r("R.version")
+
+        # Check required R packages
+        if required_packages:
+            missing = []
+            for pkg in required_packages:
+                try:
+                    with openrlib.rlock:
+                        with conversion.localconverter(default_converter):
+                            importr(pkg)
+                except Exception:
+                    missing.append(pkg)
+
+            if missing:
+                pkg_list = ", ".join(f"'{p}'" for p in missing)
+                raise ImportError(
+                    f"Missing R packages: {pkg_list}\n"
+                    f"Install in R: install.packages(c({pkg_list}))"
+                )
+
+        return (
+            robjects,
+            pandas2ri,
+            numpy2ri,
+            importr,
+            localconverter,
+            default_converter,
+            openrlib,
+            anndata2ri,
+        )
+
+    except ImportError:
+        raise
+    except Exception as e:
+        raise ImportError(
+            f"R environment setup failed: {e}\n\n"
+            "Solutions:\n"
+            " - Install R: https://www.r-project.org/\n"
+            " - Set R_HOME environment variable\n"
+            " - macOS: brew install r\n"
+            " - Ubuntu: sudo apt install r-base"
+        ) from e
+
+
+def validate_r_package(
+    package_name: str,
+    ctx: Optional["ToolContext"] = None,
+    install_cmd: Optional[str] = None,
+) -> bool:
+    """Check if an R package is available."""
+    if not is_available("rpy2"):
+        raise ImportError(
+            "rpy2 is required for R-based methods.\n"
+            "Install: pip install rpy2 (requires R)"
+        )
+
+    try:
+        from rpy2.rinterface_lib import openrlib
+        from rpy2.robjects import conversion, default_converter
+        from rpy2.robjects.packages import importr
+
+        with openrlib.rlock:
+            with conversion.localconverter(default_converter):
+                importr(package_name)
+
+        return True
+
+    except Exception as e:
+        install = install_cmd or f"install.packages('{package_name}')"
+        raise ImportError(
+            f"R package '{package_name}' not installed.\n"
+            f"Install in R: {install}"
+        ) from e
+
+
+def check_r_packages(
+    packages: list[str],
+    ctx: Optional["ToolContext"] = None,
+) -> list[str]:
+    """Check availability of multiple R packages. Returns the missing ones."""
+    if not is_available("rpy2"):
+        return packages
+
+    missing = []
+    for pkg in packages:
+        try:
+            validate_r_package(pkg)
+        except ImportError:
+            missing.append(pkg)
+
+    return missing
+
+
+def validate_scvi_tools(
+    ctx: Optional["ToolContext"] = None,
+    components: Optional[list[str]] = None,
+) -> Any:
+    """Validate scvi-tools availability and return the module."""
+    scvi = require("scvi-tools", ctx, "scvi-tools methods")
+
+    if components:
+        missing = []
+        for comp in components:
+            try:
+                if comp == "CellAssign":
+                    from scvi.external import CellAssign  # noqa: F401
+                elif comp == "Cell2location":
+                    import cell2location  # noqa: F401
+                elif comp == "SCANVI":
+                    from scvi.model import SCANVI  # noqa: F401
+                elif comp == "DestVI":
+                    from scvi.external import DestVI  # noqa: F401
+                elif comp == "Stereoscope":
+                    from scvi.external import Stereoscope  # noqa: F401
+                else:
+                    # Fall back to attribute lookup on the scvi namespaces
+                    if not (
+                        getattr(scvi, comp, None)
+                        or getattr(scvi.model, comp, None)
+                        or getattr(scvi.external, comp, None)
+                    ):
+                        missing.append(comp)
+            except (ImportError, AttributeError):
+                missing.append(comp)
+
+        if missing:
+            raise ImportError(
+                f"scvi-tools components not available: {', '.join(missing)}\n"
+                "Try: pip install --upgrade scvi-tools"
+            )
+
+    return scvi
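
For orientation, a minimal usage sketch of the dependency-manager API shown above (import path per the file list: chatspatial/utils/dependency_manager.py); the package names passed in are illustrative, not a fixed set:

    from chatspatial.utils.dependency_manager import (
        check_r_packages,
        get,
        is_available,
        require,
    )

    # Fast availability check (find_spec only; nothing is imported)
    if is_available("squidpy"):
        sq = get("squidpy")

    # Hard requirement; a missing package raises ImportError with the install command
    sc = require("scanpy", feature="preprocessing")

    # For R-backed tools, report which of the needed R packages are missing
    missing = check_r_packages(["CARD", "spacexr"])
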
chatspatial/utils/device_utils.py
@@ -0,0 +1,165 @@
+"""
+Device utilities for compute backend selection.
+
+This module provides lazy-loaded device detection and selection functions
+for GPU/CPU computation. Follows the same design principles as compute.py.
+
+Design Principles:
+    1. Lazy Loading: torch is only imported when needed
+    2. Pure Functions: No side effects, callers decide how to handle results
+    3. String Returns: Callers convert to torch.device if needed
+    4. Composable: Basic building blocks for various use cases
+
+Usage:
+    # Simple device selection
+    device = get_device(prefer_gpu=True)
+
+    # With warning when GPU unavailable
+    device = get_device(prefer_gpu=params.use_gpu)
+    if params.use_gpu and device == "cpu":
+        await ctx.warning("GPU requested but not available")
+
+    # Convert to torch.device when needed
+    import torch
+    device = torch.device(get_device(prefer_gpu=True, allow_mps=True))
+"""
+
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from ..spatial_mcp_adapter import ToolContext
+
+
+# =============================================================================
+# Availability Checks (*_available pattern)
+# =============================================================================
+
+
+def cuda_available() -> bool:
+    """Check if a CUDA GPU is available.
+
+    Lazily imports torch to avoid loading it when not needed.
+
+    Returns:
+        True if CUDA is available, False otherwise
+    """
+    try:
+        import torch
+
+        return torch.cuda.is_available()
+    except ImportError:
+        return False
+
+
+def mps_available() -> bool:
+    """Check if Apple Silicon MPS is available.
+
+    Lazily imports torch to avoid loading it when not needed.
+
+    Returns:
+        True if MPS is available, False otherwise
+    """
+    try:
+        import torch
+
+        return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+    except ImportError:
+        return False
+
+
+# =============================================================================
+# Device Selection (core function)
+# =============================================================================
+
+
+def get_device(
+    prefer_gpu: bool = False,
+    allow_mps: bool = False,
+) -> str:
+    """Select the compute device based on preference and availability.
+
+    This is THE single source of truth for device selection across ChatSpatial.
+    Returns a device string that can be used directly or converted to torch.device.
+
+    Args:
+        prefer_gpu: If True, try to use GPU (CUDA first, then MPS if allowed)
+        allow_mps: If True, allow MPS as fallback when CUDA is unavailable
+
+    Returns:
+        Device string: "cuda:0", "mps", or "cpu"
+
+    Examples:
+        # Basic usage
+        device = get_device(prefer_gpu=True)  # "cuda:0" or "cpu"
+
+        # With MPS support (Apple Silicon)
+        device = get_device(prefer_gpu=True, allow_mps=True)  # "cuda:0", "mps", or "cpu"
+
+        # Convert to torch.device
+        import torch
+        device = torch.device(get_device(prefer_gpu=True))
+
+        # With warning when requested but unavailable
+        device = get_device(params.use_gpu)
+        if params.use_gpu and device == "cpu":
+            await ctx.warning("GPU requested but not available - using CPU")
+    """
+    if prefer_gpu:
+        if cuda_available():
+            return "cuda:0"
+        if allow_mps and mps_available():
+            return "mps"
+    return "cpu"
+
+
+# =============================================================================
+# Async Helper with Context Warning
+# =============================================================================
+
+
+async def resolve_device_async(
+    prefer_gpu: bool,
+    ctx: "ToolContext",
+    allow_mps: bool = False,
+    warn_on_fallback: bool = True,
+) -> str:
+    """Select a device with an optional warning when the GPU is unavailable.
+
+    Convenience function for async tools that want the warning emitted automatically.
+
+    Args:
+        prefer_gpu: If True, try to use GPU
+        ctx: ToolContext for logging warnings
+        allow_mps: If True, allow MPS as fallback
+        warn_on_fallback: If True, warn when the requested GPU is unavailable
+
+    Returns:
+        Device string: "cuda:0", "mps", or "cpu"
+    """
+    device = get_device(prefer_gpu=prefer_gpu, allow_mps=allow_mps)
+
+    if warn_on_fallback and prefer_gpu and device == "cpu":
+        await ctx.warning("GPU requested but not available - using CPU")
+
+    return device
+
+
+# =============================================================================
+# Specialized Backend Functions
+# =============================================================================
+
+
+def get_ot_backend(use_gpu: bool = False) -> Any:
+    """Get the optimal transport backend for PASTE alignment.
+
+    Args:
+        use_gpu: If True, try to use TorchBackend with CUDA
+
+    Returns:
+        POT backend (TorchBackend if CUDA is available and requested, else NumpyBackend)
+    """
+    import ot
+
+    if use_gpu and cuda_available():
+        return ot.backend.TorchBackend()
+    return ot.backend.NumpyBackend()
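
Similarly, a minimal sketch of how the device helpers above might be called from a tool; the boolean passed to prefer_gpu would normally come from tool parameters:

    from chatspatial.utils.device_utils import get_device, get_ot_backend

    # Device string for torch-based tools: "cuda:0", "mps", or "cpu"
    device = get_device(prefer_gpu=True, allow_mps=True)

    # POT backend for PASTE alignment, GPU-backed only when CUDA is present
    backend = get_ot_backend(use_gpu=(device == "cuda:0"))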