chatspatial 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- chatspatial/__init__.py +11 -0
- chatspatial/__main__.py +141 -0
- chatspatial/cli/__init__.py +7 -0
- chatspatial/config.py +53 -0
- chatspatial/models/__init__.py +85 -0
- chatspatial/models/analysis.py +513 -0
- chatspatial/models/data.py +2462 -0
- chatspatial/server.py +1763 -0
- chatspatial/spatial_mcp_adapter.py +720 -0
- chatspatial/tools/__init__.py +3 -0
- chatspatial/tools/annotation.py +1903 -0
- chatspatial/tools/cell_communication.py +1603 -0
- chatspatial/tools/cnv_analysis.py +605 -0
- chatspatial/tools/condition_comparison.py +595 -0
- chatspatial/tools/deconvolution/__init__.py +402 -0
- chatspatial/tools/deconvolution/base.py +318 -0
- chatspatial/tools/deconvolution/card.py +244 -0
- chatspatial/tools/deconvolution/cell2location.py +326 -0
- chatspatial/tools/deconvolution/destvi.py +144 -0
- chatspatial/tools/deconvolution/flashdeconv.py +101 -0
- chatspatial/tools/deconvolution/rctd.py +317 -0
- chatspatial/tools/deconvolution/spotlight.py +216 -0
- chatspatial/tools/deconvolution/stereoscope.py +109 -0
- chatspatial/tools/deconvolution/tangram.py +135 -0
- chatspatial/tools/differential.py +625 -0
- chatspatial/tools/embeddings.py +298 -0
- chatspatial/tools/enrichment.py +1863 -0
- chatspatial/tools/integration.py +807 -0
- chatspatial/tools/preprocessing.py +723 -0
- chatspatial/tools/spatial_domains.py +808 -0
- chatspatial/tools/spatial_genes.py +836 -0
- chatspatial/tools/spatial_registration.py +441 -0
- chatspatial/tools/spatial_statistics.py +1476 -0
- chatspatial/tools/trajectory.py +495 -0
- chatspatial/tools/velocity.py +405 -0
- chatspatial/tools/visualization/__init__.py +155 -0
- chatspatial/tools/visualization/basic.py +393 -0
- chatspatial/tools/visualization/cell_comm.py +699 -0
- chatspatial/tools/visualization/cnv.py +320 -0
- chatspatial/tools/visualization/core.py +684 -0
- chatspatial/tools/visualization/deconvolution.py +852 -0
- chatspatial/tools/visualization/enrichment.py +660 -0
- chatspatial/tools/visualization/integration.py +205 -0
- chatspatial/tools/visualization/main.py +164 -0
- chatspatial/tools/visualization/multi_gene.py +739 -0
- chatspatial/tools/visualization/persistence.py +335 -0
- chatspatial/tools/visualization/spatial_stats.py +469 -0
- chatspatial/tools/visualization/trajectory.py +639 -0
- chatspatial/tools/visualization/velocity.py +411 -0
- chatspatial/utils/__init__.py +115 -0
- chatspatial/utils/adata_utils.py +1372 -0
- chatspatial/utils/compute.py +327 -0
- chatspatial/utils/data_loader.py +499 -0
- chatspatial/utils/dependency_manager.py +462 -0
- chatspatial/utils/device_utils.py +165 -0
- chatspatial/utils/exceptions.py +185 -0
- chatspatial/utils/image_utils.py +267 -0
- chatspatial/utils/mcp_utils.py +137 -0
- chatspatial/utils/path_utils.py +243 -0
- chatspatial/utils/persistence.py +78 -0
- chatspatial/utils/scipy_compat.py +143 -0
- chatspatial-1.1.0.dist-info/METADATA +242 -0
- chatspatial-1.1.0.dist-info/RECORD +67 -0
- chatspatial-1.1.0.dist-info/WHEEL +5 -0
- chatspatial-1.1.0.dist-info/entry_points.txt +2 -0
- chatspatial-1.1.0.dist-info/licenses/LICENSE +21 -0
- chatspatial-1.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,335 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Visualization persistence functions.
|
|
3
|
+
|
|
4
|
+
This module contains:
|
|
5
|
+
- save_visualization: Save single visualization to disk
|
|
6
|
+
- export_all_visualizations: Export all cached visualizations
|
|
7
|
+
- clear_visualization_cache: Clear visualization cache
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
from datetime import datetime
|
|
11
|
+
from typing import TYPE_CHECKING, Any, Optional
|
|
12
|
+
|
|
13
|
+
import matplotlib
|
|
14
|
+
import matplotlib.pyplot as plt
|
|
15
|
+
|
|
16
|
+
from ...models.data import VisualizationParameters
|
|
17
|
+
from ...utils.exceptions import DataNotFoundError, ParameterError, ProcessingError
|
|
18
|
+
from ...utils.path_utils import get_output_dir_from_config, get_safe_output_path
|
|
19
|
+
|
|
20
|
+
if TYPE_CHECKING:
|
|
21
|
+
import anndata as ad
|
|
22
|
+
|
|
23
|
+
from ...spatial_mcp_adapter import ToolContext
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# =============================================================================
|
|
27
|
+
# Internal Helper Functions
|
|
28
|
+
# =============================================================================
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
async def _regenerate_figure_for_export(
    adata: "ad.AnnData",
    params: VisualizationParameters,
    context: Optional["ToolContext"] = None,
) -> plt.Figure:
    """Regenerate a matplotlib figure from saved parameters for high-quality export.

    This is an internal helper used by save_visualization to recreate figures
    from stored (JSON-serializable) parameters. It returns the matplotlib Figure
    object itself (instead of ImageContent) so the caller can export it at an
    arbitrary DPI/format.

    Args:
        adata: AnnData object containing the data
        params: VisualizationParameters reconstructed from saved metadata;
            ``params.plot_type`` selects the handler in PLOT_HANDLERS
        context: Context object forwarded unchanged to the plot handler

    Returns:
        Matplotlib Figure object produced by the plot handler, ready for export

    Raises:
        ParameterError: If ``params.plot_type`` has no registered handler
    """
    # Import here to avoid circular imports (the package __init__ imports this module)
    from . import PLOT_HANDLERS

    plot_type = params.plot_type

    if plot_type not in PLOT_HANDLERS:
        raise ParameterError(f"Unknown plot type: {plot_type}")

    # Dispatch to the registered visualization function for this plot type
    viz_func = PLOT_HANDLERS[plot_type]

    # Handlers are coroutines that build and return the figure
    fig = await viz_func(adata, params, context)

    return fig
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
# =============================================================================
|
|
71
|
+
# Public API Functions
|
|
72
|
+
# =============================================================================
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
# Formats accepted by save_visualization (compared case-insensitively).
_VALID_EXPORT_FORMATS = ("png", "jpg", "jpeg", "pdf", "svg", "eps", "ps", "tiff")
# Raster formats, for which the requested DPI is passed to savefig.
_RASTER_EXPORT_FORMATS = ("png", "jpg", "jpeg", "tiff")


def _build_savefig_kwargs(
    fmt: str, dpi: int, plot_type: str, data_id: str
) -> dict[str, Any]:
    """Return the keyword arguments for ``Figure.savefig`` for one export.

    Args:
        fmt: Lower-cased, already-validated output format
        dpi: Resolution used for PDF and raster formats
        plot_type: Plot type, used in the PDF metadata
        data_id: Dataset ID, used in the PDF metadata

    Returns:
        Dict of savefig keyword arguments (white background, tight bbox,
        explicit ``format``, plus format-specific options).
    """
    save_params: dict[str, Any] = {
        "bbox_inches": "tight",
        "facecolor": "white",
        "edgecolor": "none",
        "transparent": False,
        "pad_inches": 0.1,
        "format": fmt,
    }

    if fmt == "pdf":
        save_params["dpi"] = dpi
        save_params["metadata"] = {
            "Title": f"{plot_type} visualization of {data_id}",
            "Author": "ChatSpatial MCP",
            "Subject": "Spatial Transcriptomics Analysis",
            "Keywords": f"{plot_type}, {data_id}, spatial transcriptomics",
            "Creator": "ChatSpatial with matplotlib",
            "Producer": f"matplotlib {matplotlib.__version__}",
        }
    elif fmt in _RASTER_EXPORT_FORMATS:
        save_params["dpi"] = dpi
        if fmt in ("jpg", "jpeg"):
            save_params["pil_kwargs"] = {"quality": 95}
    # svg / eps / ps: vector output, only the format key is needed

    return save_params


def _resolve_export_filename(
    filename: Optional[str],
    fmt: str,
    data_id: str,
    plot_type: str,
    subtype: Optional[str],
    dpi: int,
) -> str:
    """Return the output filename for an export.

    When ``filename`` is None, a timestamped name is generated from the dataset,
    plot type, subtype and DPI (the DPI is omitted when it is exactly 100).
    Otherwise ``.{fmt}`` is appended unless the name already ends with it;
    that check ignores case, so "plot.PNG" is kept as-is for fmt "png".
    """
    if filename is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        plot_name = f"{plot_type}_{subtype}" if subtype else plot_type
        if dpi != 100:
            return f"{data_id}_{plot_name}_{dpi}dpi_{timestamp}.{fmt}"
        return f"{data_id}_{plot_name}_{timestamp}.{fmt}"

    if not filename.lower().endswith(f".{fmt}"):
        return f"{filename}.{fmt}"
    return filename


async def save_visualization(
    data_id: str,
    ctx: "ToolContext",
    plot_type: str,
    subtype: Optional[str] = None,
    output_dir: str = "./outputs",
    filename: Optional[str] = None,
    format: str = "png",
    dpi: Optional[int] = None,
) -> str:
    """Save a visualization to disk at publication quality by regenerating from metadata.

    This function regenerates visualizations from stored metadata (JSON) and the original
    data, then exports at the requested quality. This approach is more secure than
    loading serialized figure objects (pickle) because:
    1. JSON metadata cannot contain executable code
    2. Regeneration uses the trusted visualization codebase
    3. All parameters are human-readable and auditable

    Supports multiple formats including vector (PDF, SVG, EPS) and raster (PNG, JPEG, TIFF)
    with publication-ready metadata.

    Args:
        data_id: Dataset ID
        ctx: ToolContext for unified data access and logging
        plot_type: Type of plot to save (e.g., 'spatial', 'deconvolution', 'spatial_statistics')
        subtype: Optional subtype for plot types with variants
            - For deconvolution: 'spatial_multi', 'dominant_type', 'diversity', etc.
            - For spatial_statistics: 'neighborhood', 'co_occurrence', 'ripley', etc.
        output_dir: Directory to save the file (default: ./outputs; when left at the
            default, the configured output directory is used instead)
        filename: Custom filename (optional, auto-generated if not provided). It must
            resolve to a location inside the output directory.
        format: Image format (png, jpg, jpeg, pdf, svg, eps, ps, tiff), case-insensitive
        dpi: DPI for raster formats and PDF (default: 300 for publication quality).
            SVG/EPS/PS exports do not pass a DPI to savefig.

    Returns:
        Path to the saved file

    Raises:
        ParameterError: If the format, dpi or filename is invalid
        DataNotFoundError: If visualization metadata not found
        ProcessingError: If regeneration or saving fails
    """
    try:
        # Use the configured output directory if the default value was passed
        if output_dir == "./outputs":
            output_dir = get_output_dir_from_config(default="./outputs")

        # Validate and normalize the format once; everything below uses `fmt`
        fmt = format.lower()
        if fmt not in _VALID_EXPORT_FORMATS:
            raise ParameterError(
                f"Invalid format: {format}. Must be one of {list(_VALID_EXPORT_FORMATS)}"
            )

        if dpi is None:
            dpi = 300  # High quality for all formats (publication-ready)
        if dpi <= 0:
            raise ParameterError(f"dpi must be a positive integer, got {dpi}")

        # Registry keys are "<data_id>_<plot_type>[_<subtype>]"
        cache_key = (
            f"{data_id}_{plot_type}_{subtype}" if subtype else f"{data_id}_{plot_type}"
        )

        # Fail fast before touching the filesystem if nothing was visualized yet
        viz_entry = ctx.get_visualization(cache_key)
        if viz_entry is None:
            plot_name = f"{plot_type}_{subtype}" if subtype else plot_type
            raise DataNotFoundError(
                f"Visualization '{plot_name}' not found. Use visualize_data() first."
            )

        # Create output directory using safe path handling
        try:
            output_path = get_safe_output_path(
                output_dir, fallback_to_tmp=True, create_if_missing=True
            )
        except PermissionError as e:
            raise ProcessingError(
                f"Cannot save to {output_dir}: {e}. Check permissions."
            ) from e

        filename = _resolve_export_filename(
            filename, fmt, data_id, plot_type, subtype, dpi
        )
        file_path = output_path / filename
        # A caller-supplied filename must not escape the output directory
        # (e.g. "../x.png" or an absolute path).
        if not file_path.resolve().is_relative_to(output_path.resolve()):
            raise ParameterError(
                f"Filename '{filename}' resolves outside the output directory"
            )

        # Regenerate figure from stored params (params + data = output)
        try:
            viz_params_dict = viz_entry.params.model_dump()
            viz_params_dict["dpi"] = dpi
            viz_params = VisualizationParameters(**viz_params_dict)

            adata = await ctx.get_adata(data_id)
            fig = await _regenerate_figure_for_export(
                adata, viz_params, ctx._mcp_context
            )
        except Exception as e:
            raise ProcessingError(f"Failed to regenerate '{cache_key}': {e}") from e

        try:
            fig.savefig(
                str(file_path),
                **_build_savefig_kwargs(fmt, dpi, plot_type, data_id),
            )
        except Exception as e:
            raise ProcessingError(f"Failed to export visualization: {e}") from e
        finally:
            # If the handler created the figure through pyplot, pyplot keeps it
            # alive until closed; close it so repeated exports don't leak figures.
            plt.close(fig)

        return str(file_path)

    except (DataNotFoundError, ParameterError, ProcessingError):
        # Already descriptive; don't wrap a second time
        raise
    except Exception as e:
        raise ProcessingError(f"Failed to save visualization: {e}") from e
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
# Plot types whose registry keys may carry a "_<subtype>" suffix.
_PLOT_TYPES_WITH_SUBTYPE = ("deconvolution", "spatial_statistics")


def _split_cache_key(cache_key: str, data_id: str) -> tuple[str, Optional[str]]:
    """Split a registry key "<data_id>_<plot_type>[_<subtype>]" into its parts.

    Only the leading "<data_id>_" is stripped, so a data_id that reappears later
    in the key (e.g. inside a subtype) is preserved. Keys whose plot type is not
    in _PLOT_TYPES_WITH_SUBTYPE are returned whole as the plot type.

    Returns:
        (plot_type, subtype) where subtype is None when absent/unrecognized
    """
    remainder = cache_key.removeprefix(f"{data_id}_")
    for known_type in _PLOT_TYPES_WITH_SUBTYPE:
        prefix = f"{known_type}_"
        if remainder.startswith(prefix):
            return known_type, remainder[len(prefix):]
    return remainder, None


async def export_all_visualizations(
    data_id: str,
    ctx: "ToolContext",
    output_dir: str = "./exports",
    format: str = "png",
    dpi: Optional[int] = None,
) -> list[str]:
    """Export all cached visualizations for a dataset to disk.

    Each registered visualization is exported via save_visualization. A failure
    on an individual visualization is reported as a warning and skipped; a
    ParameterError (an invalid argument that would affect every export) is
    propagated instead.

    Args:
        data_id: Dataset ID to export visualizations for
        ctx: ToolContext for unified data access and logging
        output_dir: Directory to save files
        format: Image format (png, jpg, jpeg, pdf, svg, eps, ps, tiff)
        dpi: DPI for saved images (default: 300 for publication quality)

    Returns:
        List of paths to saved files (empty if the dataset has no visualizations)

    Raises:
        ParameterError: If save_visualization rejects the shared arguments
        ProcessingError: If listing or exporting fails unexpectedly
    """
    try:
        # Get visualization keys from registry (single source of truth)
        relevant_keys = ctx.list_visualizations(data_id)

        if not relevant_keys:
            await ctx.warning(f"No visualizations found for dataset '{data_id}'")
            return []

        saved_files: list[str] = []

        for cache_key in relevant_keys:
            plot_type, subtype = _split_cache_key(cache_key, data_id)

            try:
                saved_path = await save_visualization(
                    data_id=data_id,
                    ctx=ctx,
                    plot_type=plot_type,
                    subtype=subtype,
                    output_dir=output_dir,
                    format=format,
                    dpi=dpi,
                )
            except ParameterError:
                # save_visualization raises ParameterError for bad shared
                # arguments (e.g. an unsupported format); every key would fail
                # the same way, so surface it instead of warning per key.
                raise
            except Exception as e:
                # Best effort: keep exporting the remaining visualizations
                await ctx.warning(f"Failed to export {cache_key}: {e}")
            else:
                saved_files.append(saved_path)

        return saved_files

    except (ParameterError, ProcessingError):
        raise
    except Exception as e:
        raise ProcessingError(f"Failed to export visualizations: {e}") from e
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
async def clear_visualization_cache(
    ctx: "ToolContext",
    data_id: Optional[str] = None,
) -> int:
    """Drop cached visualizations from the context's registry to free memory.

    Args:
        ctx: ToolContext that owns the visualization registry
        data_id: When truthy, only this dataset's visualizations are removed;
            when None (or otherwise falsy), every cached visualization is removed

    Returns:
        The count reported by ``ctx.clear_visualizations``

    Raises:
        ProcessingError: If the registry fails to clear
    """
    try:
        if not data_id:
            # No dataset selected: wipe the whole registry.
            return ctx.clear_visualizations()

        # Keys are built as "<data_id>_<plot...>", so a prefix match scopes the purge.
        return ctx.clear_visualizations(prefix=f"{data_id}_")
    except Exception as e:
        raise ProcessingError(f"Failed to clear cache: {e}") from e
|