chatspatial 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67):
  1. chatspatial/__init__.py +11 -0
  2. chatspatial/__main__.py +141 -0
  3. chatspatial/cli/__init__.py +7 -0
  4. chatspatial/config.py +53 -0
  5. chatspatial/models/__init__.py +85 -0
  6. chatspatial/models/analysis.py +513 -0
  7. chatspatial/models/data.py +2462 -0
  8. chatspatial/server.py +1763 -0
  9. chatspatial/spatial_mcp_adapter.py +720 -0
  10. chatspatial/tools/__init__.py +3 -0
  11. chatspatial/tools/annotation.py +1903 -0
  12. chatspatial/tools/cell_communication.py +1603 -0
  13. chatspatial/tools/cnv_analysis.py +605 -0
  14. chatspatial/tools/condition_comparison.py +595 -0
  15. chatspatial/tools/deconvolution/__init__.py +402 -0
  16. chatspatial/tools/deconvolution/base.py +318 -0
  17. chatspatial/tools/deconvolution/card.py +244 -0
  18. chatspatial/tools/deconvolution/cell2location.py +326 -0
  19. chatspatial/tools/deconvolution/destvi.py +144 -0
  20. chatspatial/tools/deconvolution/flashdeconv.py +101 -0
  21. chatspatial/tools/deconvolution/rctd.py +317 -0
  22. chatspatial/tools/deconvolution/spotlight.py +216 -0
  23. chatspatial/tools/deconvolution/stereoscope.py +109 -0
  24. chatspatial/tools/deconvolution/tangram.py +135 -0
  25. chatspatial/tools/differential.py +625 -0
  26. chatspatial/tools/embeddings.py +298 -0
  27. chatspatial/tools/enrichment.py +1863 -0
  28. chatspatial/tools/integration.py +807 -0
  29. chatspatial/tools/preprocessing.py +723 -0
  30. chatspatial/tools/spatial_domains.py +808 -0
  31. chatspatial/tools/spatial_genes.py +836 -0
  32. chatspatial/tools/spatial_registration.py +441 -0
  33. chatspatial/tools/spatial_statistics.py +1476 -0
  34. chatspatial/tools/trajectory.py +495 -0
  35. chatspatial/tools/velocity.py +405 -0
  36. chatspatial/tools/visualization/__init__.py +155 -0
  37. chatspatial/tools/visualization/basic.py +393 -0
  38. chatspatial/tools/visualization/cell_comm.py +699 -0
  39. chatspatial/tools/visualization/cnv.py +320 -0
  40. chatspatial/tools/visualization/core.py +684 -0
  41. chatspatial/tools/visualization/deconvolution.py +852 -0
  42. chatspatial/tools/visualization/enrichment.py +660 -0
  43. chatspatial/tools/visualization/integration.py +205 -0
  44. chatspatial/tools/visualization/main.py +164 -0
  45. chatspatial/tools/visualization/multi_gene.py +739 -0
  46. chatspatial/tools/visualization/persistence.py +335 -0
  47. chatspatial/tools/visualization/spatial_stats.py +469 -0
  48. chatspatial/tools/visualization/trajectory.py +639 -0
  49. chatspatial/tools/visualization/velocity.py +411 -0
  50. chatspatial/utils/__init__.py +115 -0
  51. chatspatial/utils/adata_utils.py +1372 -0
  52. chatspatial/utils/compute.py +327 -0
  53. chatspatial/utils/data_loader.py +499 -0
  54. chatspatial/utils/dependency_manager.py +462 -0
  55. chatspatial/utils/device_utils.py +165 -0
  56. chatspatial/utils/exceptions.py +185 -0
  57. chatspatial/utils/image_utils.py +267 -0
  58. chatspatial/utils/mcp_utils.py +137 -0
  59. chatspatial/utils/path_utils.py +243 -0
  60. chatspatial/utils/persistence.py +78 -0
  61. chatspatial/utils/scipy_compat.py +143 -0
  62. chatspatial-1.1.0.dist-info/METADATA +242 -0
  63. chatspatial-1.1.0.dist-info/RECORD +67 -0
  64. chatspatial-1.1.0.dist-info/WHEEL +5 -0
  65. chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
  66. chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
  67. chatspatial-1.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,335 @@
1
+ """
2
+ Visualization persistence functions.
3
+
4
+ This module contains:
5
+ - save_visualization: Save single visualization to disk
6
+ - export_all_visualizations: Export all cached visualizations
7
+ - clear_visualization_cache: Clear visualization cache
8
+ """
9
+
10
+ from datetime import datetime
11
+ from typing import TYPE_CHECKING, Any, Optional
12
+
13
+ import matplotlib
14
+ import matplotlib.pyplot as plt
15
+
16
+ from ...models.data import VisualizationParameters
17
+ from ...utils.exceptions import DataNotFoundError, ParameterError, ProcessingError
18
+ from ...utils.path_utils import get_output_dir_from_config, get_safe_output_path
19
+
20
+ if TYPE_CHECKING:
21
+ import anndata as ad
22
+
23
+ from ...spatial_mcp_adapter import ToolContext
24
+
25
+
26
+ # =============================================================================
27
+ # Internal Helper Functions
28
+ # =============================================================================
29
+
30
+
31
async def _regenerate_figure_for_export(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Regenerate a matplotlib figure from saved parameters for high-quality export.

    Internal helper used by ``save_visualization`` to recreate a figure from
    its stored (JSON-serializable) parameters. It returns the Figure object
    produced by the registered plot handler, rather than an encoded image,
    so the caller can export it at an arbitrary DPI/format.

    Args:
        adata: AnnData object containing the data to plot.
        params: VisualizationParameters reconstructed from saved metadata;
            ``params.plot_type`` selects the handler in ``PLOT_HANDLERS``.
        context: Optional context object forwarded unchanged to the handler.

    Returns:
        Whatever the selected plot handler returns (expected to be a
        matplotlib Figure ready for ``savefig``).

    Raises:
        ParameterError: If ``params.plot_type`` has no registered handler.
        Any exception raised by the plot handler itself propagates unchanged.
    """
    # Imported lazily to avoid a circular import with the package __init__.
    from . import PLOT_HANDLERS

    plot_type = params.plot_type

    if plot_type not in PLOT_HANDLERS:
        raise ParameterError(f"Unknown plot type: {plot_type}")

    # Dispatch to the handler registered for this plot type.
    viz_func = PLOT_HANDLERS[plot_type]

    fig = await viz_func(adata, params, context)

    return fig
68
+
69
+
70
+ # =============================================================================
71
+ # Public API Functions
72
+ # =============================================================================
73
+
74
+
75
def _export_filename(
    data_id: str,
    plot_type: str,
    subtype: Optional[str],
    filename: Optional[str],
    fmt: str,
    dpi: int,
) -> str:
    """Return the bare file name to write (auto-generated or user-supplied).

    Args:
        data_id: Dataset ID (used in auto-generated names).
        plot_type: Plot type (used in auto-generated names).
        subtype: Optional plot subtype (used in auto-generated names).
        filename: User-supplied name, or None to auto-generate one.
        fmt: Lower-cased output format, used as the file extension.
        dpi: Export DPI; encoded in auto-generated names unless it is 100.

    Raises:
        ParameterError: If ``filename`` contains directory components.
    """
    if filename is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        plot_name = f"{plot_type}_{subtype}" if subtype else plot_type
        # The DPI tag is omitted only when dpi is exactly 100.
        dpi_tag = f"_{dpi}dpi" if dpi != 100 else ""
        return f"{data_id}_{plot_name}{dpi_tag}_{timestamp}.{fmt}"

    # The file is joined onto the sanitized output directory; separators would
    # let a caller escape it (e.g. "../x.png" or "/etc/x.png").
    if "/" in filename or "\\" in filename:
        raise ParameterError(
            f"Invalid filename: {filename!r}. Provide a bare file name without "
            "directory components; use output_dir to choose the location."
        )

    # Case-insensitive so "plot.PNG" with format "png" is not renamed.
    if not filename.lower().endswith(f".{fmt}"):
        filename = f"{filename}.{fmt}"
    return filename


def _savefig_kwargs(fmt: str, dpi: int, plot_type: str, data_id: str) -> dict[str, Any]:
    """Build ``Figure.savefig`` keyword arguments for a lower-cased format."""
    kwargs: dict[str, Any] = {
        "bbox_inches": "tight",
        "facecolor": "white",
        "edgecolor": "none",
        "transparent": False,
        "pad_inches": 0.1,
        "format": fmt,
    }
    # SVG/EPS/PS are written without an explicit DPI; PDF and raster use it.
    if fmt in ("pdf", "png", "jpg", "jpeg", "tiff"):
        kwargs["dpi"] = dpi
    if fmt == "pdf":
        kwargs["metadata"] = {
            "Title": f"{plot_type} visualization of {data_id}",
            "Author": "ChatSpatial MCP",
            "Subject": "Spatial Transcriptomics Analysis",
            "Keywords": f"{plot_type}, {data_id}, spatial transcriptomics",
            "Creator": "ChatSpatial with matplotlib",
            "Producer": f"matplotlib {matplotlib.__version__}",
        }
    elif fmt in ("jpg", "jpeg"):
        kwargs["pil_kwargs"] = {"quality": 95}
    return kwargs


async def save_visualization(
    data_id: str,
    ctx: "ToolContext",
    plot_type: str,
    subtype: Optional[str] = None,
    output_dir: str = "./outputs",
    filename: Optional[str] = None,
    format: str = "png",
    dpi: Optional[int] = None,
) -> str:
    """Save a visualization to disk at publication quality by regenerating from metadata.

    The figure is regenerated from the parameters stored in the visualization
    registry plus the original data, then exported at the requested quality.
    This avoids loading serialized figure objects (pickle): the stored
    parameters cannot contain executable code, regeneration uses the trusted
    visualization codebase, and all parameters are human-readable.

    Supports vector (PDF, SVG, EPS, PS) and raster (PNG, JPEG, TIFF) formats.

    Args:
        data_id: Dataset ID.
        ctx: ToolContext for unified data access and logging.
        plot_type: Type of plot to save (e.g. 'spatial', 'deconvolution',
            'spatial_statistics').
        subtype: Optional subtype for plot types with variants
            - For deconvolution: 'spatial_multi', 'dominant_type', 'diversity', etc.
            - For spatial_statistics: 'neighborhood', 'co_occurrence', 'ripley', etc.
        output_dir: Directory to save the file. The literal default
            "./outputs" is replaced by the configured output directory.
        filename: Bare file name (no directories). Auto-generated if omitted;
            the format extension is appended when missing.
        format: Image format (png, jpg, jpeg, pdf, svg, eps, ps, tiff),
            case-insensitive.
        dpi: DPI for raster formats and PDF (default: 300). SVG, EPS and PS
            are written without an explicit DPI.

    Returns:
        Path to the saved file.

    Raises:
        ParameterError: If the format is unsupported or ``filename`` contains
            directory components.
        DataNotFoundError: If no visualization is registered for the key.
        ProcessingError: If regeneration or saving fails.
    """
    try:
        # Use the configured output directory if the default value was passed.
        if output_dir == "./outputs":
            output_dir = get_output_dir_from_config(default="./outputs")

        # Normalize once so validation, extension and savefig all agree.
        fmt = format.lower()
        valid_formats = ["png", "jpg", "jpeg", "pdf", "svg", "eps", "ps", "tiff"]
        if fmt not in valid_formats:
            raise ParameterError(
                f"Invalid format: {format}. Must be one of {valid_formats}"
            )

        cache_key = (
            f"{data_id}_{plot_type}_{subtype}" if subtype else f"{data_id}_{plot_type}"
        )

        # Fail fast, before touching the filesystem.
        viz_entry = ctx.get_visualization(cache_key)
        if viz_entry is None:
            raise DataNotFoundError(
                f"Visualization '{plot_type}' not found. Use visualize_data() first."
            )

        if dpi is None:
            dpi = 300  # Publication quality.

        export_name = _export_filename(data_id, plot_type, subtype, filename, fmt, dpi)

        # Create the output directory using safe path handling.
        try:
            output_path = get_safe_output_path(
                output_dir, fallback_to_tmp=True, create_if_missing=True
            )
        except PermissionError as e:
            raise ProcessingError(
                f"Cannot save to {output_dir}: {e}. Check permissions."
            ) from e

        file_path = output_path / export_name

        # Regenerate the figure from stored params, overriding only the DPI.
        try:
            viz_params_dict = viz_entry.params.model_dump()
            viz_params_dict["dpi"] = dpi
            viz_params = VisualizationParameters(**viz_params_dict)

            adata = await ctx.get_adata(data_id)
            fig = await _regenerate_figure_for_export(
                adata, viz_params, ctx._mcp_context
            )
        except Exception as e:
            raise ProcessingError(f"Failed to regenerate '{cache_key}': {e}") from e

        try:
            fig.savefig(str(file_path), **_savefig_kwargs(fmt, dpi, plot_type, data_id))
        except Exception as e:
            raise ProcessingError(f"Failed to export visualization: {e}") from e
        finally:
            # The regenerated figure is only needed for this export. Close it
            # so pyplot-managed figures do not accumulate in a long-running
            # server.
            if isinstance(fig, plt.Figure):
                plt.close(fig)

        return str(file_path)

    except (DataNotFoundError, ParameterError, ProcessingError):
        # Already domain errors with a precise message; do not re-wrap.
        raise
    except Exception as e:
        raise ProcessingError(f"Failed to save visualization: {e}") from e
236
+
237
+
238
async def export_all_visualizations(
    data_id: str,
    ctx: "ToolContext",
    output_dir: str = "./exports",
    format: str = "png",
    dpi: Optional[int] = None,
) -> list[str]:
    """Export all cached visualizations for a dataset to disk.

    Each registry key is split back into ``plot_type``/``subtype`` and passed
    to ``save_visualization``. Per-visualization failures are reported with
    ``ctx.warning`` and skipped (best effort), so the returned list may be
    shorter than the number of cached visualizations.

    Args:
        data_id: Dataset ID to export visualizations for.
        ctx: ToolContext for unified data access and logging.
        output_dir: Directory to save files.
        format: Image format (png, jpg, jpeg, pdf, svg, eps, ps, tiff).
        dpi: DPI for saved images (default: 300 for publication quality).

    Returns:
        List of paths to the files that were saved successfully.

    Raises:
        ProcessingError: If listing the registry or the export loop fails
            outside of an individual visualization's save.
    """
    # Plot types whose cache keys may carry a "_<subtype>" suffix.
    known_plot_types_with_subtype = ("deconvolution", "spatial_statistics")

    try:
        # Get visualization keys from the registry (single source of truth).
        relevant_keys = ctx.list_visualizations(data_id)

        if not relevant_keys:
            await ctx.warning(f"No visualizations found for dataset '{data_id}'")
            return []

        saved_files: list[str] = []
        key_prefix = f"{data_id}_"

        for cache_key in relevant_keys:
            # Strip only the leading dataset prefix. str.replace would also
            # delete occurrences inside the plot type/subtype, producing a key
            # that save_visualization can no longer find.
            remainder = cache_key.removeprefix(key_prefix)

            # Default: the whole remainder is the plot type.
            plot_type: str = remainder
            subtype: Optional[str] = None

            for known_type in known_plot_types_with_subtype:
                type_prefix = f"{known_type}_"
                # Only split when a non-empty subtype follows. Otherwise keep
                # the remainder whole so the original key is rebuilt exactly.
                if remainder.startswith(type_prefix) and len(remainder) > len(
                    type_prefix
                ):
                    plot_type = known_type
                    subtype = remainder[len(type_prefix) :]
                    break

            try:
                saved_path = await save_visualization(
                    data_id=data_id,
                    ctx=ctx,
                    plot_type=plot_type,
                    subtype=subtype,
                    output_dir=output_dir,
                    format=format,
                    dpi=dpi,
                )
                saved_files.append(saved_path)
            except Exception as e:
                # Best effort: report and continue with the next visualization.
                await ctx.warning(f"Failed to export {cache_key}: {e}")

        return saved_files

    except ProcessingError:
        raise
    except Exception as e:
        raise ProcessingError(f"Failed to export visualizations: {e}") from e
309
+
310
+
311
async def clear_visualization_cache(
    ctx: "ToolContext",
    data_id: Optional[str] = None,
) -> int:
    """Drop cached visualizations held by the tool context.

    When ``data_id`` is truthy, only entries whose key starts with
    ``"<data_id>_"`` are removed; otherwise the whole cache is cleared.

    Args:
        ctx: ToolContext that owns the visualization registry.
        data_id: Dataset whose visualizations should be dropped. ``None`` (or
            any falsy value, such as an empty string) clears everything.

    Returns:
        The count returned by ``ctx.clear_visualizations``.

    Raises:
        ProcessingError: If clearing fails for any reason.
    """
    try:
        if not data_id:
            # No dataset given: wipe every cached visualization.
            return ctx.clear_visualizations()
        # Restrict to this dataset's keys via their shared prefix.
        return ctx.clear_visualizations(prefix=f"{data_id}_")
    except Exception as exc:
        raise ProcessingError(f"Failed to clear cache: {exc}") from exc