chatspatial 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. chatspatial/__init__.py +11 -0
  2. chatspatial/__main__.py +141 -0
  3. chatspatial/cli/__init__.py +7 -0
  4. chatspatial/config.py +53 -0
  5. chatspatial/models/__init__.py +85 -0
  6. chatspatial/models/analysis.py +513 -0
  7. chatspatial/models/data.py +2462 -0
  8. chatspatial/server.py +1763 -0
  9. chatspatial/spatial_mcp_adapter.py +720 -0
  10. chatspatial/tools/__init__.py +3 -0
  11. chatspatial/tools/annotation.py +1903 -0
  12. chatspatial/tools/cell_communication.py +1603 -0
  13. chatspatial/tools/cnv_analysis.py +605 -0
  14. chatspatial/tools/condition_comparison.py +595 -0
  15. chatspatial/tools/deconvolution/__init__.py +402 -0
  16. chatspatial/tools/deconvolution/base.py +318 -0
  17. chatspatial/tools/deconvolution/card.py +244 -0
  18. chatspatial/tools/deconvolution/cell2location.py +326 -0
  19. chatspatial/tools/deconvolution/destvi.py +144 -0
  20. chatspatial/tools/deconvolution/flashdeconv.py +101 -0
  21. chatspatial/tools/deconvolution/rctd.py +317 -0
  22. chatspatial/tools/deconvolution/spotlight.py +216 -0
  23. chatspatial/tools/deconvolution/stereoscope.py +109 -0
  24. chatspatial/tools/deconvolution/tangram.py +135 -0
  25. chatspatial/tools/differential.py +625 -0
  26. chatspatial/tools/embeddings.py +298 -0
  27. chatspatial/tools/enrichment.py +1863 -0
  28. chatspatial/tools/integration.py +807 -0
  29. chatspatial/tools/preprocessing.py +723 -0
  30. chatspatial/tools/spatial_domains.py +808 -0
  31. chatspatial/tools/spatial_genes.py +836 -0
  32. chatspatial/tools/spatial_registration.py +441 -0
  33. chatspatial/tools/spatial_statistics.py +1476 -0
  34. chatspatial/tools/trajectory.py +495 -0
  35. chatspatial/tools/velocity.py +405 -0
  36. chatspatial/tools/visualization/__init__.py +155 -0
  37. chatspatial/tools/visualization/basic.py +393 -0
  38. chatspatial/tools/visualization/cell_comm.py +699 -0
  39. chatspatial/tools/visualization/cnv.py +320 -0
  40. chatspatial/tools/visualization/core.py +684 -0
  41. chatspatial/tools/visualization/deconvolution.py +852 -0
  42. chatspatial/tools/visualization/enrichment.py +660 -0
  43. chatspatial/tools/visualization/integration.py +205 -0
  44. chatspatial/tools/visualization/main.py +164 -0
  45. chatspatial/tools/visualization/multi_gene.py +739 -0
  46. chatspatial/tools/visualization/persistence.py +335 -0
  47. chatspatial/tools/visualization/spatial_stats.py +469 -0
  48. chatspatial/tools/visualization/trajectory.py +639 -0
  49. chatspatial/tools/visualization/velocity.py +411 -0
  50. chatspatial/utils/__init__.py +115 -0
  51. chatspatial/utils/adata_utils.py +1372 -0
  52. chatspatial/utils/compute.py +327 -0
  53. chatspatial/utils/data_loader.py +499 -0
  54. chatspatial/utils/dependency_manager.py +462 -0
  55. chatspatial/utils/device_utils.py +165 -0
  56. chatspatial/utils/exceptions.py +185 -0
  57. chatspatial/utils/image_utils.py +267 -0
  58. chatspatial/utils/mcp_utils.py +137 -0
  59. chatspatial/utils/path_utils.py +243 -0
  60. chatspatial/utils/persistence.py +78 -0
  61. chatspatial/utils/scipy_compat.py +143 -0
  62. chatspatial-1.1.0.dist-info/METADATA +242 -0
  63. chatspatial-1.1.0.dist-info/RECORD +67 -0
  64. chatspatial-1.1.0.dist-info/WHEEL +5 -0
  65. chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
  66. chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
  67. chatspatial-1.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,495 @@
1
+ """
2
+ Trajectory inference for spatial transcriptomics.
3
+
4
+ This module infers cellular trajectories and pseudotime by combining
5
+ expression patterns with optional velocity and spatial information.
6
+
7
+ Key functionality:
8
+ - `analyze_trajectory`: Main MCP entry point for trajectory analysis
9
+ - Supports CellRank (velocity-based), Palantir (expression-based), and DPT (diffusion-based)
10
+ """
11
+
12
+ from typing import TYPE_CHECKING, Any, Optional
13
+
14
+ import pandas as pd
15
+
16
+ if TYPE_CHECKING:
17
+ from ..spatial_mcp_adapter import ToolContext
18
+
19
+ from ..models.analysis import TrajectoryResult
20
+ from ..models.data import TrajectoryParameters
21
+ from ..utils.adata_utils import (
22
+ get_spatial_key,
23
+ require_spatial_coords,
24
+ validate_obs_column,
25
+ )
26
+ from ..utils.compute import ensure_diffmap, ensure_neighbors, ensure_pca
27
+ from ..utils.dependency_manager import require
28
+ from ..utils.exceptions import (
29
+ DataError,
30
+ DataNotFoundError,
31
+ ParameterError,
32
+ ProcessingError,
33
+ )
34
+ from ..utils.mcp_utils import suppress_output
35
+
36
+
37
def prepare_gam_model_for_visualization(
    adata,
    genes: list,
    time_key: str = "latent_time",
    fate_key: str = "lineages_fwd",
):
    """
    Build a CellRank GAM model for gene-trend and fate-heatmap plots.

    Checks that every input the CellRank 2.0 gene trends / fate heatmap
    visualizations rely on is present, then constructs the model. The
    expected upstream pipeline is analyze_rna_velocity (dynamical mode)
    followed by analyze_trajectory (cellrank method).

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix carrying CellRank results.
    genes : list
        Gene names the model is prepared for; every entry must appear in
        ``adata.var_names``.
    time_key : str, default 'latent_time'
        Column of ``adata.obs`` holding pseudotime / latent time values.
    fate_key : str, default 'lineages_fwd'
        Key of ``adata.obsm`` holding fate probabilities. The stored object
        must expose a non-None ``names`` attribute (a CellRank Lineage).

    Returns
    -------
    tuple
        ``(model, lineage_names)`` - the GAM model and the list of lineage names.

    Raises
    ------
    DataNotFoundError
        If ``fate_key`` is missing from ``adata.obsm`` or any gene is absent.
    DataError
        If the stored fate probabilities carry no lineage names.
    """
    require("cellrank")
    from cellrank.models import GAM

    # Raises if the time column is missing.
    validate_obs_column(adata, time_key, "Time")

    if fate_key not in adata.obsm:
        raise DataNotFoundError(
            f"Fate probabilities '{fate_key}' not found. Run analyze_trajectory first."
        )

    # Lineage names only exist on an in-memory CellRank Lineage object; a plain
    # array stored under the same key has no ``names`` attribute.
    lineage = adata.obsm[fate_key]
    names = getattr(lineage, "names", None)
    if names is None:
        raise DataError(
            "Fate probabilities must be a CellRank Lineage object with names. "
            "This requires running the full analysis pipeline in memory:\n"
            "1. analyze_rna_velocity(data_id, params={'scvelo_mode': 'dynamical'})\n"
            "2. analyze_trajectory(data_id, params={'method': 'cellrank'})\n"
            "3. Then visualize with plot_type='trajectory', subtype='gene_trends'"
        )
    lineage_names = list(names)

    var_names = adata.var_names
    absent = [gene for gene in genes if gene not in var_names]
    if absent:
        raise DataNotFoundError(
            f"Genes not found in data: {absent}. "
            f"Available genes: {list(var_names[:10])}..."
        )

    return GAM(adata), lineage_names
99
+
100
+
101
def _spatial_transition_matrix(spatial_coords):
    """
    Return a row-stochastic CSR matrix of spatial proximity between cells.

    Similarity is ``exp(-d / mean(d))`` over all pairwise Euclidean distances
    ``d`` (diagonal included). Each row is then divided by its sum, so the
    result is a valid transition matrix for CellRank's ``PrecomputedKernel``.

    NOTE(review): this materializes a dense n_obs x n_obs matrix (O(n^2)
    memory); a sparse kNN-based kernel would scale better on large datasets.
    """
    import numpy as np
    from scipy.sparse import csr_matrix
    from scipy.spatial.distance import pdist, squareform

    spatial_dist = squareform(pdist(spatial_coords))
    scale = spatial_dist.mean()
    # If every spot shares one coordinate the mean is 0 and d / 0 gives NaN.
    if not np.isfinite(scale) or scale <= 0:
        scale = 1.0
    spatial_sim = np.exp(-spatial_dist / scale)
    # Diagonal entries are exp(0) = 1, so every row sum is >= 1 (never zero).
    spatial_sim /= spatial_sim.sum(axis=1, keepdims=True)
    return csr_matrix(spatial_sim)


def infer_spatial_trajectory_cellrank(
    adata, spatial_weight=0.5, kernel_weights=(0.8, 0.2), n_states=5
):
    """
    Infer cellular trajectories by combining RNA velocity with CellRank.

    A transition matrix is built from several CellRank kernels:
    1. A velocity kernel from RNA velocity.
    2. A connectivity kernel based on transcriptomic similarity.
    3. (Optional) A spatial kernel based on physical proximity.

    A GPCCA estimator on the combined kernel then yields macrostates and,
    when possible, terminal states and fate probabilities.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix with velocity results. When
        ``adata.uns["velocity_method"] == "velovi"``, the AnnData stored in
        ``adata.uns["velovi_adata"]`` is analyzed instead.
    spatial_weight : float, default 0.5
        Weight of the spatial kernel in the combination. It is forced to 0
        when no spatial key is found.
    kernel_weights : tuple of (float, float), default (0.8, 0.2)
        Weights of the velocity and connectivity kernels.
    n_states : int, default 5
        Number of macrostates passed to ``compute_macrostates``.

    Returns
    -------
    AnnData
        The same ``adata`` object. ``obs["pseudotime"]`` is always written.
        ``obs["terminal_states"]``, ``obs["macrostates"]`` and
        ``obsm["fate_probabilities"]`` are written when CellRank produced them.
        Pseudotime is 1 minus the fate probability toward the first terminal
        state category. Without terminal states, it is 1 minus membership in
        the first macrostate.

    Raises
    ------
    ProcessingError
        If velovi data is missing, macrostate computation fails, or neither
        terminal states nor macrostates are available.
    """
    import cellrank as cr

    spatial_key = get_spatial_key(adata)
    has_spatial = spatial_key is not None
    if not has_spatial and spatial_weight > 0:
        spatial_weight = 0

    use_velovi = adata.uns.get("velocity_method") == "velovi"
    if use_velovi:
        if "velovi_adata" not in adata.uns:
            raise ProcessingError("VELOVI velocity data not found")
        adata_for_cellrank = adata.uns["velovi_adata"]
        # NOTE(review): assumes velovi_adata rows are in the same order as adata.
        if has_spatial:
            adata_for_cellrank.obsm["spatial"] = adata.obsm[spatial_key]
        # VelocityKernel reads the "velocity" layer; expose veloVI's output there.
        if "velocity_velovi" in adata_for_cellrank.layers:
            adata_for_cellrank.layers["velocity"] = adata_for_cellrank.layers[
                "velocity_velovi"
            ]
    else:
        adata_for_cellrank = adata

    vk = cr.kernels.VelocityKernel(adata_for_cellrank)
    vk.compute_transition_matrix()

    ck = cr.kernels.ConnectivityKernel(adata_for_cellrank)
    ck.compute_transition_matrix()

    vk_weight, ck_weight = kernel_weights
    if has_spatial and spatial_weight > 0:
        sk = cr.kernels.PrecomputedKernel(
            _spatial_transition_matrix(adata.obsm[spatial_key]), adata_for_cellrank
        )
        sk.compute_transition_matrix()
        combined_kernel = (1 - spatial_weight) * (
            vk_weight * vk + ck_weight * ck
        ) + spatial_weight * sk
    else:
        combined_kernel = vk_weight * vk + ck_weight * ck

    # GPCCA analysis
    g = cr.estimators.GPCCA(combined_kernel)
    g.compute_eigendecomposition()

    try:
        g.compute_macrostates(n_states=n_states)
    except Exception as e:
        raise ProcessingError(
            f"CellRank failed with n_states={n_states}: {e}. "
            f"Try reducing n_states or use method='palantir'/'dpt'."
        ) from e

    # Only the "No macrostates have been selected" ValueError is tolerated;
    # the macrostate fallback below handles that case. Anything else propagates.
    try:
        g.predict_terminal_states(method="stability")
    except ValueError as e:
        if "No macrostates have been selected" not in str(e):
            raise

    has_terminal_states = (
        hasattr(g, "terminal_states") and g.terminal_states is not None
    )

    if has_terminal_states and len(g.terminal_states.cat.categories) > 0:
        g.compute_fate_probabilities()
        absorption_probs = g.fate_probabilities
        terminal_states = list(g.terminal_states.cat.categories)
        root_state = terminal_states[0]
        pseudotime = 1 - absorption_probs[root_state].X.flatten()

        adata_for_cellrank.obs["pseudotime"] = pseudotime
        adata_for_cellrank.obsm["fate_probabilities"] = absorption_probs
        adata_for_cellrank.obs["terminal_states"] = g.terminal_states
    elif hasattr(g, "macrostates") and g.macrostates is not None:
        macrostate_probs = g.macrostates_memberships
        pseudotime = 1 - macrostate_probs[:, 0].X.flatten()
        adata_for_cellrank.obs["pseudotime"] = pseudotime
    else:
        raise ProcessingError(
            "CellRank could not compute either terminal states or macrostates"
        )

    if hasattr(g, "macrostates") and g.macrostates is not None:
        adata_for_cellrank.obs["macrostates"] = g.macrostates

    # Transfer results back to the original adata (a no-op copy when
    # adata_for_cellrank is adata itself).
    for key in ("pseudotime", "terminal_states", "macrostates"):
        if key in adata_for_cellrank.obs:
            adata.obs[key] = adata_for_cellrank.obs[key]
    if "fate_probabilities" in adata_for_cellrank.obsm:
        adata.obsm["fate_probabilities"] = adata_for_cellrank.obsm["fate_probabilities"]

    if use_velovi:
        adata.uns["velovi_adata"] = adata_for_cellrank

    return adata
235
+
236
+
237
def spatial_aware_embedding(adata, spatial_weight=0.3):
    """
    Generate a spatially-aware low-dimensional embedding.

    Pairwise Euclidean distances in PCA space and in physical space are each
    divided by their mean, which makes them unitless. They are then blended as
    ``(1 - spatial_weight) * expr + spatial_weight * spatial``. The blended
    matrix is embedded with UMAP (``metric="precomputed"``) and stored in
    ``adata.obsm["X_spatial_umap"]``.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix with spatial coordinates. PCA is computed via
        ``ensure_pca`` if needed.
    spatial_weight : float, default 0.3
        Share of the blended distance taken from physical proximity.

    Returns
    -------
    AnnData
        The same ``adata`` object, modified in place.

    Notes
    -----
    Builds dense n_obs x n_obs distance matrices (O(n^2) memory).
    """
    from sklearn.metrics.pairwise import euclidean_distances
    from umap import UMAP

    spatial_coords = require_spatial_coords(adata)
    ensure_pca(adata)

    def _mean_scaled(dist):
        # Without scaling, spatial coordinates (which may be in pixel units)
        # would swamp PCA distances and make spatial_weight meaningless.
        # All-zero matrices (e.g. a single cell) are left as-is to avoid 0/0.
        scale = dist.mean()
        return dist / scale if scale > 0 else dist

    expr_dist = _mean_scaled(euclidean_distances(adata.obsm["X_pca"]))
    spatial_dist = _mean_scaled(euclidean_distances(spatial_coords))
    combined_dist = (1 - spatial_weight) * expr_dist + spatial_weight * spatial_dist

    umap_op = UMAP(metric="precomputed")
    adata.obsm["X_spatial_umap"] = umap_op.fit_transform(combined_dist)

    return adata
254
+
255
+
256
def infer_pseudotime_palantir(
    adata, root_cells=None, n_diffusion_components=10, num_waypoints=500
):
    """
    Infer cellular trajectories and pseudotime using Palantir.

    Palantir models differentiation as a stochastic process on a graph. It
    uses diffusion maps to capture data geometry and computes fate
    probabilities via random walks from a root cell.

    Parameters
    ----------
    adata : AnnData
        The annotated data matrix. PCA is computed via ``ensure_pca`` if needed.
    root_cells : list of str, optional
        Cell identifiers used as starting points. Only the first entry is
        used. If omitted, the cell at the maximum of the leading multiscale
        diffusion component is chosen.
    n_diffusion_components : int, default 10
        Number of diffusion components.
    num_waypoints : int, default 500
        Number of waypoints for trajectory granularity.

    Returns
    -------
    AnnData
        The same ``adata`` object with ``obs["palantir_pseudotime"]`` and
        ``obsm["palantir_branch_probs"]`` written.

    Raises
    ------
    ParameterError
        If the first entry of ``root_cells`` is not a cell identifier.
    """
    import palantir

    ensure_pca(adata)

    pca_df = pd.DataFrame(adata.obsm["X_pca"], index=adata.obs_names)
    dm_res = palantir.utils.run_diffusion_maps(
        pca_df, n_components=n_diffusion_components
    )

    # Palantir's documented workflow feeds run_palantir the multiscale space
    # (eigenvalue-rescaled components without the trivial first eigenvector)
    # rather than the raw eigenvectors.
    multiscale = palantir.utils.determine_multiscale_space(dm_res)
    if isinstance(multiscale, pd.DataFrame):
        multiscale = multiscale.to_numpy()
    # Attach cell names positionally so root-cell lookup and obs alignment work.
    ms_data = pd.DataFrame(multiscale, index=pca_df.index)

    if root_cells is not None and len(root_cells) > 0:
        start_cell = root_cells[0]
        if start_cell not in ms_data.index:
            raise ParameterError(f"Root cell '{start_cell}' not found in data")
    else:
        # Heuristic root: the extreme cell along the leading diffusion component.
        start_cell = ms_data.iloc[:, 0].idxmax()

    pr_res = palantir.core.run_palantir(
        ms_data, start_cell, num_waypoints=num_waypoints
    )

    adata.obs["palantir_pseudotime"] = pr_res.pseudotime
    adata.obsm["palantir_branch_probs"] = pr_res.branch_probs

    return adata
302
+
303
+
304
def compute_dpt_trajectory(adata, root_cells=None, ctx: Optional["ToolContext"] = None):
    """
    Compute Diffusion Pseudotime (DPT) with scanpy.

    PCA, neighbors and the diffusion map are computed first if missing. The
    root cell index is written to ``adata.uns["iroot"]`` and ``sc.tl.dpt`` is
    run.

    Parameters
    ----------
    adata : AnnData
        Annotated data matrix, modified in place.
    root_cells : list of str, optional
        Cell identifiers; only the first entry is used as the root. If
        omitted, the first cell (index 0) is used.
    ctx : ToolContext, optional
        Accepted for interface compatibility; not used here.

    Returns
    -------
    AnnData
        The same ``adata`` object with ``obs["dpt_pseudotime"]`` written.
        NaN values are filled with 0.

    Raises
    ------
    ParameterError
        If the requested root cell is not in ``adata.obs_names``.
    ProcessingError
        If ``sc.tl.dpt`` fails or does not create ``dpt_pseudotime``.
    """
    import numpy as np
    import scanpy as sc

    ensure_pca(adata)
    ensure_neighbors(adata)
    ensure_diffmap(adata)

    if root_cells is not None and len(root_cells) > 0:
        root = root_cells[0]
        matches = np.flatnonzero(adata.obs_names == root)
        if matches.size == 0:
            raise ParameterError(
                f"Root cell '{root}' not found. "
                f"Use valid cell ID from adata.obs_names or omit to auto-select."
            )
        adata.uns["iroot"] = matches[0]
    else:
        adata.uns["iroot"] = 0

    # Always recompute. A dpt_pseudotime column from an earlier run may have
    # been computed from a different root and would otherwise be returned stale.
    try:
        sc.tl.dpt(adata)
    except Exception as e:
        raise ProcessingError(f"DPT computation failed: {e}") from e

    if "dpt_pseudotime" not in adata.obs.columns:
        raise ProcessingError("DPT computation did not create 'dpt_pseudotime' column")

    adata.obs["dpt_pseudotime"] = adata.obs["dpt_pseudotime"].fillna(0)

    return adata
336
+
337
+
338
def has_velocity_data(adata) -> bool:
    """
    Report whether any RNA-velocity result is recorded in ``adata.uns``.

    Returns True as soon as one of ``velocity_graph``, ``velovi_adata`` or
    ``velocity_method`` is a key of ``adata.uns``; otherwise False.
    """
    velocity_markers = ("velocity_graph", "velovi_adata", "velocity_method")
    return any(marker in adata.uns for marker in velocity_markers)
345
+
346
+
347
def _trajectory_result_keys(adata, method: str, pseudotime_key: str) -> dict[str, Any]:
    """Collect the adata keys a trajectory run produced, for provenance metadata."""
    keys: dict[str, Any] = {"obs": [pseudotime_key], "obsm": [], "uns": []}
    if method == "cellrank":
        # The macrostate fallback of the CellRank path writes no terminal states
        # or fate probabilities, so record only keys that actually exist.
        keys["obs"].extend(
            key for key in ("terminal_states", "macrostates") if key in adata.obs.columns
        )
        if "fate_probabilities" in adata.obsm:
            keys["obsm"].append("fate_probabilities")
        if "velocity_method" in adata.uns:
            keys["uns"].append("velocity_method")
    elif method == "palantir":
        keys["obsm"].append("palantir_branch_probs")
    elif method == "dpt":
        keys["uns"].append("iroot")
    return keys


def _trajectory_parameters(params: TrajectoryParameters, method: str) -> dict[str, Any]:
    """Collect the parameters relevant to ``method`` for provenance metadata."""
    parameters: dict[str, Any] = {"spatial_weight": params.spatial_weight}
    if method == "cellrank":
        parameters["kernel_weights"] = params.cellrank_kernel_weights
        parameters["n_states"] = params.cellrank_n_states
    elif method == "palantir":
        parameters["n_diffusion_components"] = params.palantir_n_diffusion_components
        parameters["num_waypoints"] = params.palantir_num_waypoints
    if params.root_cells:
        parameters["root_cells"] = params.root_cells
    return parameters


async def analyze_trajectory(
    data_id: str,
    ctx: "ToolContext",
    params: TrajectoryParameters = TrajectoryParameters(),
) -> TrajectoryResult:
    """
    Analyze trajectory and cell state transitions in spatial transcriptomics data.

    This is the main MCP entry point for trajectory inference. It supports:
    - CellRank: requires pre-computed velocity data
    - Palantir: expression-based, no velocity required
    - DPT: diffusion-based, no velocity required

    Args:
        data_id: Dataset identifier.
        ctx: ToolContext for data access and logging.
        params: Trajectory analysis parameters.

    Returns:
        TrajectoryResult with pseudotime and method metadata.

    Raises:
        ProcessingError: If the chosen method fails, if CellRank is requested
            without velocity data, or if no pseudotime column is produced.
        ParameterError: If ``params.method`` is not recognized.
    """
    adata = await ctx.get_adata(data_id)

    velocity_available = has_velocity_data(adata)
    pseudotime_key = None
    method_used = params.method

    if params.method == "cellrank":
        if not velocity_available:
            raise ProcessingError(
                "CellRank requires velocity data. Run velocity analysis first or use palantir/dpt."
            )

        require("cellrank")

        try:
            with suppress_output():
                adata = infer_spatial_trajectory_cellrank(
                    adata,
                    spatial_weight=params.spatial_weight,
                    kernel_weights=params.cellrank_kernel_weights,
                    n_states=params.cellrank_n_states,
                )
            pseudotime_key = "pseudotime"
            method_used = "cellrank"
        except Exception as e:
            raise ProcessingError(f"CellRank trajectory inference failed: {e}") from e

    elif params.method == "palantir":
        try:
            with suppress_output():
                has_spatial = get_spatial_key(adata) is not None
                if has_spatial and params.spatial_weight > 0:
                    # NOTE(review): this writes obsm["X_spatial_umap"], but
                    # infer_pseudotime_palantir reads only obsm["X_pca"], so
                    # spatial_weight does not currently affect Palantir pseudotime.
                    adata = spatial_aware_embedding(
                        adata, spatial_weight=params.spatial_weight
                    )
                elif not has_spatial and params.spatial_weight > 0:
                    await ctx.warning(
                        f"Spatial weight {params.spatial_weight} specified but no spatial "
                        "coordinates found. Using expression-only Palantir."
                    )

                adata = infer_pseudotime_palantir(
                    adata,
                    root_cells=params.root_cells,
                    n_diffusion_components=params.palantir_n_diffusion_components,
                    num_waypoints=params.palantir_num_waypoints,
                )

            pseudotime_key = "palantir_pseudotime"
            method_used = "palantir"

        except Exception as e:
            raise ProcessingError(f"Palantir trajectory inference failed: {e}") from e

    elif params.method == "dpt":
        try:
            with suppress_output():
                adata = compute_dpt_trajectory(
                    adata, root_cells=params.root_cells, ctx=ctx
                )
            pseudotime_key = "dpt_pseudotime"
            method_used = "dpt"
        except Exception as e:
            raise ProcessingError(f"DPT analysis failed: {e}") from e

    else:
        raise ParameterError(f"Unknown trajectory method: {params.method}")

    if pseudotime_key is None or pseudotime_key not in adata.obs.columns:
        # There is no fallback between methods, so name the one that failed.
        raise ProcessingError(
            f"Trajectory method '{method_used}' did not produce pseudotime "
            f"column '{pseudotime_key}'"
        )

    # Store scientific metadata
    from ..utils.adata_utils import store_analysis_metadata

    statistics_dict = {
        "velocity_computed": velocity_available,
        "pseudotime_key": pseudotime_key,
    }

    store_analysis_metadata(
        adata,
        analysis_name=f"trajectory_{method_used}",
        method=method_used,
        parameters=_trajectory_parameters(params, method_used),
        results_keys=_trajectory_result_keys(adata, method_used, pseudotime_key),
        statistics=statistics_dict,
    )

    return TrajectoryResult(
        data_id=data_id,
        pseudotime_computed=True,
        velocity_computed=velocity_available,
        pseudotime_key=pseudotime_key,
        method=method_used,
        spatial_weight=params.spatial_weight,
    )