scmcp-shared 0.4.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scmcp_shared/__init__.py +1 -3
- scmcp_shared/agent.py +38 -21
- scmcp_shared/backend.py +44 -0
- scmcp_shared/cli.py +75 -46
- scmcp_shared/kb.py +139 -0
- scmcp_shared/logging_config.py +6 -8
- scmcp_shared/mcp_base.py +184 -0
- scmcp_shared/schema/io.py +101 -59
- scmcp_shared/schema/pl.py +386 -490
- scmcp_shared/schema/pp.py +514 -265
- scmcp_shared/schema/preset/__init__.py +15 -0
- scmcp_shared/schema/preset/io.py +103 -0
- scmcp_shared/schema/preset/pl.py +843 -0
- scmcp_shared/schema/preset/pp.py +616 -0
- scmcp_shared/schema/preset/tl.py +917 -0
- scmcp_shared/schema/preset/util.py +123 -0
- scmcp_shared/schema/tl.py +355 -407
- scmcp_shared/schema/util.py +57 -72
- scmcp_shared/server/__init__.py +5 -10
- scmcp_shared/server/auto.py +15 -11
- scmcp_shared/server/code.py +3 -0
- scmcp_shared/server/preset/__init__.py +14 -0
- scmcp_shared/server/{io.py → preset/io.py} +26 -22
- scmcp_shared/server/{pl.py → preset/pl.py} +162 -78
- scmcp_shared/server/{pp.py → preset/pp.py} +123 -65
- scmcp_shared/server/{tl.py → preset/tl.py} +142 -79
- scmcp_shared/server/{util.py → preset/util.py} +123 -66
- scmcp_shared/server/rag.py +13 -0
- scmcp_shared/util.py +109 -38
- {scmcp_shared-0.4.0.dist-info → scmcp_shared-0.6.0.dist-info}/METADATA +6 -2
- scmcp_shared-0.6.0.dist-info/RECORD +35 -0
- scmcp_shared/server/base.py +0 -148
- scmcp_shared-0.4.0.dist-info/RECORD +0 -24
- {scmcp_shared-0.4.0.dist-info → scmcp_shared-0.6.0.dist-info}/WHEEL +0 -0
- {scmcp_shared-0.4.0.dist-info → scmcp_shared-0.6.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,843 @@
|
|
1
|
+
from typing import Optional, Union, List, Literal, Tuple, Mapping
|
2
|
+
from pydantic import Field, field_validator, BaseModel
|
3
|
+
|
4
|
+
|
5
|
+
# Mixins below each contribute one group of related plotting fields. They are
# plain classes; their annotated fields become model fields once a mixin is
# combined with pydantic's BaseModel (as BaseVisualizationParam does).
class LegendMixin:
    """Fields that control the plot legend."""

    legend_fontsize: Optional[Union[int, float, str]] = Field(
        None,
        description="Numeric size in pt or string describing the size.",
    )
    legend_fontweight: Union[int, str] = Field(
        "bold",
        description="Legend font weight. A numeric value in range 0-1000 or a string.",
    )
    legend_loc: str = Field(
        "right margin",
        description="Location of legend, either 'on data', 'right margin' or a valid keyword for the loc parameter.",
    )
    legend_fontoutline: Optional[int] = Field(
        None,
        description="Line width of the legend font outline in pt.",
    )


class ColorMappingMixin:
    """Fields that control colormaps, palettes and color-scale limits."""

    color_map: Optional[str] = Field(
        None,
        description="Color map to use for continuous variables.",
    )
    palette: Optional[Union[str, List[str], Mapping[str, str]]] = Field(
        None,
        description="Colors to use for plotting categorical annotation groups.",
    )
    # Color-scale limits: each accepts a single value or one value per panel.
    vmax: Optional[Union[str, float, List[Union[str, float]]]] = Field(
        None,
        description="The value representing the upper limit of the color scale.",
    )
    vmin: Optional[Union[str, float, List[Union[str, float]]]] = Field(
        None,
        description="The value representing the lower limit of the color scale.",
    )
    vcenter: Optional[Union[str, float, List[Union[str, float]]]] = Field(
        None,
        description="The value representing the center of the color scale.",
    )


class FigureSizeMixin:
    """Field that controls the overall figure dimensions."""

    figsize: Optional[Tuple[float, float]] = Field(
        None,
        description="Figure size. Format is (width, height).",
    )


# Root of the plotting-parameter hierarchy: pydantic BaseModel combined with
# the legend, color-mapping and figure-size mixins.
class BaseVisualizationParam(
    BaseModel,
    LegendMixin,
    ColorMappingMixin,
    FigureSizeMixin,
):
    """Base visualization model containing the fields shared by all plotting tools."""


# Base model for embedding scatter plots (UMAP, t-SNE, PCA, diffusion map, ...).
class BaseEmbeddingParam(BaseVisualizationParam):
    """Base embedding visualization model containing the fields shared by all embedding plotting tools."""

    # --- what to color and where the values come from ---
    color: Optional[Union[str, List[str]]] = Field(
        None,
        description="Keys for annotations of observations/cells or variables/genes.",
    )
    gene_symbols: Optional[str] = Field(
        None,
        description="Column name in .var DataFrame that stores gene symbols.",
    )
    use_raw: Optional[bool] = Field(
        None,
        description="Use .raw attribute of adata for coloring with gene expression.",
    )
    sort_order: bool = Field(
        True,
        description="For continuous annotations used as color parameter, plot data points with higher values on top of others.",
    )

    # --- neighbor-graph overlay ---
    edges: bool = Field(False, description="Show edges between nodes.")
    edges_width: float = Field(0.1, description="Width of edges.")
    edges_color: Union[str, List[float], List[str]] = Field(
        "grey",
        description="Color of edges.",
    )
    neighbors_key: Optional[str] = Field(
        None,
        description="Where to look for neighbors connectivities.",
    )
    arrows: bool = Field(False, description="Show arrows.")

    # --- which subset / dimensions to draw ---
    groups: Optional[Union[str, List[str]]] = Field(
        None,
        description="Restrict to a few categories in categorical observation annotation.",
    )
    components: Optional[Union[str, List[str]]] = Field(
        None,
        description="For instance, ['1,2', '2,3']. To plot all available components use components='all'.",
    )
    dimensions: Optional[Union[Tuple[int, int], List[Tuple[int, int]]]] = Field(
        None,
        description="0-indexed dimensions of the embedding to plot as integers. E.g. [(0, 1), (1, 2)].",
    )
    layer: Optional[str] = Field(
        None,
        description="Name of the AnnData object layer that wants to be plotted.",
    )
    projection: Literal["2d", "3d"] = Field(
        "2d",
        description="Projection of plot.",
    )

    # --- point and panel appearance ---
    size: Optional[Union[float, List[float]]] = Field(
        None,
        description="Point size. If None, is automatically computed.",
    )
    frameon: Optional[bool] = Field(
        None,
        description="Draw a frame around the scatter plot.",
    )
    add_outline: Optional[bool] = Field(
        False,
        description="Add outline to scatter plot points.",
    )
    ncols: int = Field(4, description="Number of columns for multiple plots.")
    marker: Union[str, List[str]] = Field(
        ".",
        description="Matplotlib marker style for points.",
    )


# Base model for generic x/y scatter plots.
class BaseScatterParam(BaseVisualizationParam):
    """Base scatter plot model."""

    x: Optional[str] = Field(None, description="x coordinate.")
    y: Optional[str] = Field(None, description="y coordinate.")
    color: Optional[Union[str, List[str]]] = Field(
        None,
        description="Keys for annotations of observations/cells or variables/genes, or a hex color specification.",
    )
    use_raw: Optional[bool] = Field(
        None,
        description="Whether to use raw attribute of adata. Defaults to True if .raw is present.",
    )
    layers: Optional[Union[str, List[str]]] = Field(
        None,
        description="Use the layers attribute of adata if present: specify the layer for x, y and color.",
    )
    basis: Optional[str] = Field(
        None,
        description="Basis to use for embedding.",
    )


# Scatter plot with extra presentation options, built on BaseScatterParam.
class EnhancedScatterParam(BaseScatterParam):
    """Input schema for the enhanced scatter plotting tool."""

    sort_order: bool = Field(
        True,
        description="For continuous annotations used as color parameter, plot data points with higher values on top of others.",
    )
    alpha: Optional[float] = Field(
        None,
        description="Alpha value for the plot.",
        ge=0,
        le=1,
    )
    groups: Optional[Union[str, List[str]]] = Field(
        None,
        description="Restrict to a few categories in categorical observation annotation.",
    )
    components: Optional[Union[str, List[str]]] = Field(
        None,
        description="For instance, ['1,2', '2,3']. To plot all available components use components='all'.",
    )
    projection: Literal["2d", "3d"] = Field(
        "2d",
        description="Projection of plot.",
    )
    right_margin: Optional[float] = Field(
        None,
        description="Adjust the width of the right margin.",
    )
    left_margin: Optional[float] = Field(
        None,
        description="Adjust the width of the left margin.",
    )

    @field_validator("alpha")
    def validate_alpha(cls, v: Optional[float]) -> Optional[float]:
        """Reject an alpha outside the closed interval [0, 1]; None passes through."""
        if v is None:
            return v
        if v < 0 or 1 < v:
            raise ValueError("alpha must be between 0 and 1")
        return v


# Base model for statistical plots (violin-style plots of values per group).
class BaseStatPlotParam(BaseVisualizationParam):
    """Base statistical visualization model containing the fields shared by statistical plots."""

    groupby: Optional[str] = Field(
        None,
        description="The key of the observation grouping to consider.",
    )
    log: bool = Field(False, description="Plot on logarithmic axis.")
    use_raw: Optional[bool] = Field(
        None,
        description="Use raw attribute of adata if present.",
    )
    var_names: Optional[Union[str, List[str]]] = Field(
        None,
        description="var_names should be a valid subset of adata.var_names.",
    )
    layer: Optional[str] = Field(
        None,
        description="Name of the AnnData object layer that wants to be plotted.",
    )
    gene_symbols: Optional[str] = Field(
        None,
        description="Column name in .var DataFrame that stores gene symbols.",
    )

    # Violin / strip-plot options shared by the violin-style subclasses.
    stripplot: bool = Field(
        True,
        description="Add a stripplot on top of the violin plot.",
    )
    jitter: Union[float, bool] = Field(
        True,
        description="Add jitter to the stripplot (only when stripplot is True).",
    )
    size: int = Field(1, description="Size of the jitter points.", gt=0)
    order: Optional[List[str]] = Field(
        None,
        description="Order in which to show the categories.",
    )
    scale: Literal["area", "count", "width"] = Field(
        "width",
        description="The method used to scale the width of each violin.",
    )

    @field_validator("size")
    def validate_size(cls, v: int) -> int:
        """Require the jitter point size to be strictly positive."""
        if v > 0:
            return v
        raise ValueError("size must be a positive integer")


# Base model for matrix-style plots of var_names grouped by an obs key
# (subclassed by the heatmap, tracksplot, matrixplot and dotplot schemas).
class BaseMatrixParam(BaseVisualizationParam):
    """Base matrix visualization model containing the fields shared by all matrix plotting tools."""

    # The default is None, so the annotation must admit None. With a bare
    # Union[List, Mapping] an explicit ``var_names=None`` fails validation,
    # and the generated JSON schema advertises a non-nullable field.
    var_names: Optional[Union[List[str], Mapping[str, List[str]]]] = Field(
        default=None,
        description="var_names should be a valid subset of adata.var_names or a mapping where the key is used as label to group the values.",
    )
    groupby: Union[str, List[str]] = Field(
        ...,  # Required field
        description="The key of the observation grouping to consider.",
    )
    use_raw: Optional[bool] = Field(
        default=None, description="Use raw attribute of adata if present."
    )
    log: bool = Field(default=False, description="Plot on logarithmic axis.")
    dendrogram: Union[bool, str] = Field(
        default=False,
        description="If True or a valid dendrogram key, a dendrogram based on the hierarchical clustering between the groupby categories is added.",
    )
    gene_symbols: Optional[str] = Field(
        default=None,
        description="Column name in .var DataFrame that stores gene symbols.",
    )
    var_group_positions: Optional[List[Tuple[int, int]]] = Field(
        default=None,
        description="Use this parameter to highlight groups of var_names with brackets or color blocks between the given start and end positions.",
    )
    var_group_labels: Optional[List[str]] = Field(
        default=None,
        description="Labels for each of the var_group_positions that want to be highlighted.",
    )
    layer: Optional[str] = Field(
        default=None,
        description="Name of the AnnData object layer that wants to be plotted.",
    )


# Heatmap parameters on top of the shared matrix-plot fields.
class HeatmapParam(BaseMatrixParam):
    """Input schema for the heatmap plotting tool."""

    num_categories: int = Field(
        7,
        description="Only used if groupby observation is not categorical. This value determines the number of groups into which the groupby observation should be subdivided.",
        gt=0,
    )
    var_group_rotation: Optional[float] = Field(
        None,
        description="Label rotation degrees. By default, labels larger than 4 characters are rotated 90 degrees.",
    )
    standard_scale: Optional[Literal["var", "obs"]] = Field(
        None,
        description="Whether or not to standardize that dimension between 0 and 1.",
    )
    swap_axes: bool = Field(
        False,
        description="By default, the x axis contains var_names and the y axis the groupby categories. By setting swap_axes then x are the groupby categories and y the var_names.",
    )
    show_gene_labels: Optional[bool] = Field(
        None,
        description="By default gene labels are shown when there are 50 or less genes. Otherwise the labels are removed.",
    )

    @field_validator("num_categories")
    def validate_num_categories(cls, v: int) -> int:
        """Reject a non-positive number of categories."""
        if v > 0:
            return v
        raise ValueError("num_categories must be a positive integer")


# Tracksplot parameters.
class TracksplotParam(BaseMatrixParam):
    """Input schema for the tracksplot plotting tool."""

    # Adds no fields of its own: everything it needs is inherited from BaseMatrixParam.


# Violin plot parameters, extending the shared statistical-plot fields.
class ViolinParam(BaseStatPlotParam):
    """Input schema for the violin plotting tool."""

    keys: Union[str, List[str]] = Field(
        ...,  # Required field
        # The obsm selector on this model is ``use_obsm``; the previous text
        # pointed at a non-existent ``obsm_key`` parameter.
        description="Keys for accessing variables of adata.var or adata.obs, or variables of adata.obsm when use_obsm is not None.",
    )
    # Optional because the default is None: a bare ``str`` annotation rejects
    # an explicit ``use_obsm=None`` and publishes a non-nullable schema.
    use_obsm: Optional[str] = Field(
        default=None, description="using data of adata.obsm instead of adata.X"
    )
    # The next five fields redeclare those inherited from BaseStatPlotParam
    # with identical types, defaults and descriptions.
    stripplot: bool = Field(
        default=True, description="Add a stripplot on top of the violin plot."
    )
    jitter: Union[float, bool] = Field(
        default=True,
        description="Add jitter to the stripplot (only when stripplot is True).",
    )
    size: int = Field(default=1, description="Size of the jitter points.", gt=0)
    scale: Literal["area", "count", "width"] = Field(
        default="width",
        description="The method used to scale the width of each violin.",
    )
    order: Optional[List[str]] = Field(
        default=None, description="Order in which to show the categories."
    )
    multi_panel: Optional[bool] = Field(
        default=None,
        description="Display keys in multiple panels also when groupby is not None.",
    )
    xlabel: str = Field(
        default="",
        description="Label of the x axis. Defaults to groupby if rotation is None, otherwise, no label is shown.",
    )
    ylabel: Optional[Union[str, List[str]]] = Field(
        default=None, description="Label of the y axis."
    )
    rotation: Optional[float] = Field(
        default=None, description="Rotation of xtick labels."
    )

    @field_validator("size")
    def validate_size(cls, v: int) -> int:
        """Validate that the jitter point size is positive."""
        if v <= 0:
            raise ValueError("size must be a positive integer")
        return v


# Matrixplot parameters on top of the shared matrix-plot fields.
class MatrixplotParam(BaseMatrixParam):
    """Input schema for the matrixplot plotting tool."""

    num_categories: int = Field(
        default=7,
        description="Only used if groupby observation is not categorical. This value determines the number of groups into which the groupby observation should be subdivided.",
        gt=0,
    )
    cmap: Optional[str] = Field(
        default="viridis", description="String denoting matplotlib color map."
    )
    colorbar_title: Optional[str] = Field(
        default="Mean expression\nin group",
        description="Title for the color bar. New line character (\\n) can be used.",
    )
    var_group_rotation: Optional[float] = Field(
        default=None,
        description="Label rotation degrees. By default, labels larger than 4 characters are rotated 90 degrees.",
    )
    standard_scale: Optional[Literal["var", "group"]] = Field(
        default=None,
        description="Whether or not to standardize the given dimension between 0 and 1.",
    )
    swap_axes: bool = Field(
        default=False,
        description="By default, the x axis contains var_names and the y axis the groupby categories. By setting swap_axes then x are the groupby categories and y the var_names.",
    )
    # Optional because the default is None: a bare ``str`` annotation rejects
    # an explicit ``use_obsm=None`` and publishes a non-nullable schema.
    use_obsm: Optional[str] = Field(
        default=None, description="using data of adata.obsm instead of adata.X"
    )

    @field_validator("num_categories")
    def validate_num_categories(cls, v: int) -> int:
        """Validate that num_categories is positive."""
        if v <= 0:
            raise ValueError("num_categories must be a positive integer")
        return v


# Dotplot parameters on top of the shared matrix-plot fields.
class DotplotParam(BaseMatrixParam):
    """Input schema for the dotplot plotting tool."""

    # --- how expression values are summarized ---
    expression_cutoff: float = Field(
        0.0,
        description="Expression cutoff that is used for binarizing the gene expression.",
    )
    mean_only_expressed: bool = Field(
        False,
        description="If True, gene expression is averaged only over the cells expressing the given genes.",
    )
    standard_scale: Optional[Literal["var", "group"]] = Field(
        None,
        description="Whether or not to standardize that dimension between 0 and 1.",
    )
    swap_axes: bool = Field(
        False,
        description="By default, the x axis contains var_names and the y axis the groupby categories. By setting swap_axes then x are the groupby categories and y the var_names.",
    )

    # --- dot sizing ---
    dot_max: Optional[float] = Field(None, description="The maximum size of the dots.")
    dot_min: Optional[float] = Field(None, description="The minimum size of the dots.")
    smallest_dot: Optional[float] = Field(None, description="The smallest dot size.")

    # --- labels and titles ---
    var_group_rotation: Optional[float] = Field(
        None,
        description="Label rotation degrees. By default, labels larger than 4 characters are rotated 90 degrees.",
    )
    colorbar_title: Optional[str] = Field(
        "Mean expression\nin group",
        description="Title for the color bar. New line character (\\n) can be used.",
    )
    size_title: Optional[str] = Field(
        "Fraction of cells\nin group (%)",
        description="Title for the size legend. New line character (\\n) can be used.",
    )


# Parameters for plotting stored rank_genes_groups results.
class RankGenesGroupsParam(BaseVisualizationParam):
    """Input schema for the rank_genes_groups plotting tool."""

    n_genes: int = Field(20, description="Number of genes to show.", gt=0)
    gene_symbols: Optional[str] = Field(
        None,
        description="Column name in `.var` DataFrame that stores gene symbols.",
    )
    groupby: Optional[str] = Field(
        None,
        description="The key of the observation grouping to consider.",
    )
    groups: Optional[Union[str, List[str]]] = Field(
        None,
        description="Subset of groups, e.g. ['g1', 'g2', 'g3'].",
    )
    key: Optional[str] = Field(
        "rank_genes_groups",
        description="Key used to store the rank_genes_groups parameters.",
    )
    fontsize: int = Field(8, description="Fontsize for gene names.")
    ncols: int = Field(4, description="Number of columns.")
    sharey: bool = Field(
        True,
        description="Controls if the y-axis of each panels should be shared.",
    )

    @field_validator("n_genes", "fontsize")
    def validate_positive_int(cls, v: int) -> int:
        """Require n_genes and fontsize to be strictly positive."""
        if not v > 0:
            raise ValueError("Value must be a positive integer")
        return v


# Clustermap parameters; a standalone BaseModel outside the visualization hierarchy.
class ClusterMapParam(BaseModel):
    """Input schema for the clustermap plotting tool."""

    obs_keys: Optional[str] = Field(
        None,
        description="key column in adata.obs, categorical annotation to plot with a different color map.",
    )
    use_raw: Optional[bool] = Field(
        None,
        description="Whether to use `raw` attribute of `adata`. Defaults to `True` if `.raw` is present.",
    )


# Stacked violin parameters, extending the shared statistical-plot fields.
class StackedViolinParam(BaseStatPlotParam):
    """Input schema for the stacked_violin plotting tool."""

    # These five fields redeclare the BaseStatPlotParam ones with identical
    # types, defaults and descriptions.
    stripplot: bool = Field(
        True,
        description="Add a stripplot on top of the violin plot.",
    )
    jitter: Union[float, bool] = Field(
        True,
        description="Add jitter to the stripplot (only when stripplot is True).",
    )
    size: int = Field(1, description="Size of the jitter points.", gt=0)
    order: Optional[List[str]] = Field(
        None,
        description="Order in which to show the categories.",
    )
    scale: Literal["area", "count", "width"] = Field(
        "width",
        description="The method used to scale the width of each violin.",
    )

    swap_axes: bool = Field(
        False,
        description="Swap axes such that observations are on the x-axis.",
    )

    @field_validator("size")
    def validate_size(cls, v: int) -> int:
        """Require the jitter point size to be strictly positive."""
        if v > 0:
            return v
        raise ValueError("size must be a positive integer")


# Tracking plot parameters.
class TrackingParam(BaseVisualizationParam):
    """Input schema for the tracking plotting tool."""

    groupby: str = Field(
        ...,  # required
        description="The key of the observation grouping to consider.",
    )
    min_group_size: int = Field(
        1,
        description="Minimal number of cells in a group for the group to be considered.",
        gt=0,
    )
    min_split_size: int = Field(
        1,
        description="Minimal number of cells in a split for the split to be shown.",
        gt=0,
    )

    @field_validator("min_group_size", "min_split_size")
    def validate_positive_int(cls, v: int) -> int:
        """Require both size thresholds to be strictly positive."""
        if v > 0:
            return v
        raise ValueError("Value must be a positive integer")


# Embedding-density plot parameters.
class EmbeddingDensityParam(BaseEmbeddingParam):
    """Input schema for the embedding_density plotting tool."""

    basis: str = Field(
        ...,  # required
        description="Basis to use for embedding.",
    )
    key: Optional[str] = Field(
        None,
        description="Key for annotation of observations/cells or variables/genes.",
    )
    convolve: Optional[float] = Field(
        None,
        description="Sigma for Gaussian kernel used for convolution.",
    )
    alpha: float = Field(
        0.5,
        description="Alpha value for the plot.",
        ge=0,
        le=1,
    )

    @field_validator("alpha")
    def validate_alpha(cls, v: float) -> float:
        """Reject an alpha outside the closed interval [0, 1]."""
        if v < 0 or 1 < v:
            raise ValueError("alpha must be between 0 and 1")
        return v


class PCAParam(BaseEmbeddingParam):
    """Input schema for the PCA plotting tool."""

    # The only PCA-specific option; everything else comes from BaseEmbeddingParam.
    annotate_var_explained: bool = Field(
        False,
        description="Annotate the explained variance.",
    )


# UMAP plot parameters.
class UMAPParam(BaseEmbeddingParam):
    """Input schema for the UMAP plotting tool."""

    # No extra fields: BaseEmbeddingParam already provides everything needed.


# t-SNE plot parameters.
class TSNEParam(BaseEmbeddingParam):
    """Input schema for the TSNE plotting tool."""

    # No extra fields: BaseEmbeddingParam already provides everything needed.


# Diffusion-map plot parameters.
class DiffusionMapParam(BaseEmbeddingParam):
    """Input schema for the diffusion map plotting tool."""

    # No extra fields: BaseEmbeddingParam already provides everything needed.


class HighestExprGenesParam(BaseVisualizationParam):
    """Input schema for the highest_expr_genes plotting tool."""

    n_top: int = Field(30, description="Number of top genes to plot.", gt=0)
    gene_symbols: Optional[str] = Field(
        None,
        description="Key for field in .var that stores gene symbols if you do not want to use .var_names.",
    )
    log: bool = Field(False, description="Plot x-axis in log scale.")

    @field_validator("n_top")
    def validate_n_top(cls, v: int) -> int:
        """Require n_top to be strictly positive."""
        if v > 0:
            return v
        raise ValueError("n_top must be a positive integer")


class HighlyVariableGenesParam(BaseVisualizationParam):
    """Input schema for the highly_variable_genes plotting tool."""

    log: bool = Field(False, description="Plot on logarithmic axes.")
    highly_variable_genes: bool = Field(
        True,
        description="Whether to plot highly variable genes or all genes.",
    )


class PCAVarianceRatioParam(BaseVisualizationParam):
    """Input schema for the pca_variance_ratio plotting tool."""

    n_pcs: int = Field(30, description="Number of PCs to show.", gt=0)
    log: bool = Field(False, description="Plot on logarithmic scale.")

    @field_validator("n_pcs")
    def validate_n_pcs(cls, v: int) -> int:
        """Require n_pcs to be strictly positive."""
        if v > 0:
            return v
        raise ValueError("n_pcs must be a positive integer")


class RankGenesGroupsDotplotParam(BaseMatrixParam):
    """Input schema for the rank_genes_groups_dotplot plotting tool."""

    groups: Optional[Union[str, List[str]]] = Field(
        default=None, description="The groups for which to show the gene ranking."
    )
    n_genes: Optional[int] = Field(
        default=None,
        description="Number of genes to show. This can be a negative number to show down regulated genes. Ignored if var_names is passed.",
    )
    values_to_plot: Optional[
        Literal[
            "scores",
            "logfoldchanges",
            "pvals",
            "pvals_adj",
            "log10_pvals",
            "log10_pvals_adj",
        ]
    ] = Field(
        default=None,
        description="Instead of the mean gene value, plot the values computed by sc.rank_genes_groups.",
    )
    min_logfoldchange: Optional[float] = Field(
        default=None,
        description="Value to filter genes in groups if their logfoldchange is less than the min_logfoldchange.",
    )
    key: Optional[str] = Field(
        default=None, description="Key used to store the ranking results in adata.uns."
    )
    # The default is None, so the annotation must admit None. Without Optional
    # an explicit ``var_names=None`` fails validation and the JSON schema marks
    # the field as non-nullable.
    var_names: Optional[Union[List[str], Mapping[str, List[str]]]] = Field(
        default=None,
        description="Genes to plot. Sometimes is useful to pass a specific list of var names (e.g. genes) to check their fold changes or p-values",
    )

    @field_validator("n_genes")
    def validate_n_genes(cls, v: Optional[int]) -> Optional[int]:
        """Accept any n_genes value unchanged.

        Negative values are meaningful (down-regulated genes), so nothing is checked.
        """
        return v


class EmbeddingParam(BaseEmbeddingParam):
    """Input schema for the embedding plotting tool."""

    basis: str = Field(
        ...,  # Required field
        description="Name of the obsm basis to use.",
    )
    # Optional because the default is None: a bare ``str`` annotation rejects
    # an explicit ``use_obsm=None`` and publishes a non-nullable schema.
    use_obsm: Optional[str] = Field(
        default=None, description="using data of adata.obsm instead of adata.X"
    )
    mask_obs: Optional[str] = Field(
        default=None,
        description="A boolean array or a string mask expression to subset observations.",
    )
    arrows_kwds: Optional[dict] = Field(
        default=None,
        description="Passed to matplotlib's quiver function for drawing arrows.",
    )
    scale_factor: Optional[float] = Field(
        default=None, description="Scale factor for the plot."
    )
    cmap: Optional[str] = Field(
        default=None,
        description="Color map to use for continuous variables. Overrides color_map.",
    )
    na_color: str = Field(
        default="lightgray", description="Color to use for null or masked values."
    )
    na_in_legend: bool = Field(
        default=True, description="Whether to include null values in the legend."
    )
    # Tuples are immutable, so sharing these defaults across instances is safe.
    outline_width: Tuple[float, float] = Field(
        default=(0.3, 0.05), description="Width of the outline for highlighted points."
    )
    outline_color: Tuple[str, str] = Field(
        default=("black", "white"),
        description="Color of the outline for highlighted points.",
    )
    colorbar_loc: Optional[str] = Field(
        default="right", description="Location of the colorbar."
    )
    hspace: float = Field(default=0.25, description="Height space between panels.")
    wspace: Optional[float] = Field(
        default=None, description="Width space between panels."
    )
    title: Optional[Union[str, List[str]]] = Field(
        default=None, description="Title for the plot."
    )