scmcp-shared 0.3.6__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scmcp_shared/__init__.py +1 -1
- scmcp_shared/agent.py +30 -0
- scmcp_shared/cli.py +10 -0
- scmcp_shared/schema/io.py +2 -2
- scmcp_shared/schema/pl.py +42 -42
- scmcp_shared/schema/pp.py +10 -10
- scmcp_shared/schema/tl.py +18 -18
- scmcp_shared/schema/tool.py +11 -0
- scmcp_shared/schema/util.py +9 -9
- scmcp_shared/server/auto.py +52 -0
- scmcp_shared/server/base.py +7 -12
- scmcp_shared/server/io.py +5 -4
- scmcp_shared/server/pl.py +33 -32
- scmcp_shared/server/pp.py +32 -24
- scmcp_shared/server/tl.py +35 -34
- scmcp_shared/server/util.py +19 -18
- {scmcp_shared-0.3.6.dist-info → scmcp_shared-0.4.0.dist-info}/METADATA +5 -2
- scmcp_shared-0.4.0.dist-info/RECORD +24 -0
- scmcp_shared-0.3.6.dist-info/RECORD +0 -21
- {scmcp_shared-0.3.6.dist-info → scmcp_shared-0.4.0.dist-info}/WHEEL +0 -0
- {scmcp_shared-0.3.6.dist-info → scmcp_shared-0.4.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,52 @@
|
|
1
|
+
from fastmcp import FastMCP
|
2
|
+
from fastmcp.server.dependencies import get_context
|
3
|
+
from ..agent import select_tool
|
4
|
+
from pydantic import Field
|
5
|
+
|
6
|
+
# FastMCP server that exposes the two "meta" tools defined below (search_tool
# and run_tool); they look up and invoke tools through the calling server's
# tool manager obtained from the request context.
auto_mcp = FastMCP("SmartMCP-select-Server")
|
7
|
+
|
8
|
+
|
9
|
+
@auto_mcp.tool()
def search_tool(
    task: str = Field(description="The tasks or questions that needs to be solved using available tools")
):
    """search the tools that can be used to solve the user's tasks or questions"""
    # NOTE: the docstring above is published to MCP clients as this tool's
    # description, so implementation notes are kept in comments instead.
    #
    # Builds a tagged prompt listing every tool in ``_all_tools`` (name plus
    # description), asks ``select_tool`` to choose the relevant ones, and
    # returns each chosen entry (via ``model_dump()``) augmented with the
    # registered tool's ``parameters`` schema.
    ctx = get_context()
    tool_manager = ctx.fastmcp._tool_manager
    all_tools = tool_manager._all_tools
    auto_tools = tool_manager._tools
    # Temporarily expose the full tool set while searching.
    tool_manager._tools = all_tools
    try:
        tool_entries = "".join(
            f"<Tool>\n<name>{name}</name>\n<description>{tool.description}</description>\n</Tool>\n"
            for name, tool in all_tools.items()
        )
        query = f"<task>{task}</task>\n" + tool_entries

        results = select_tool(query)
        tool_list = []
        for selected in results:
            entry = selected.model_dump()
            registered = all_tools.get(entry["name"])
            if registered is None:
                # The selector may propose a name that is not registered;
                # skip it rather than failing the whole search with KeyError.
                continue
            entry["parameters"] = registered.parameters
            tool_list.append(entry)
        return tool_list
    finally:
        # Restore the reduced tool set on every path. Previously an exception
        # here left the server permanently exposing all tools.
        tool_manager._tools = auto_tools
|
32
|
+
|
33
|
+
|
34
|
+
@auto_mcp.tool()
async def run_tool(
    name: str = Field(description="The name of the tool to run"),
    parameter: dict = Field(description="The parameters to pass to the tool")
):
    """run the tool with the given name and parameters. Only start call the tool when last tool is finished."""
    # NOTE: the docstring above is published to MCP clients as this tool's
    # description, so implementation notes are kept in comments instead.
    #
    # Invokes ``name`` through the tool manager with ``parameter`` as its
    # arguments. Any exception is converted into ``{"error": str(e)}``.
    ctx = get_context()
    tool_manager = ctx.fastmcp._tool_manager
    auto_tools = tool_manager._tools
    # Temporarily expose the full tool set so call_tool can resolve ``name``.
    # NOTE(review): this swap is shared, mutable server state; overlapping
    # run_tool calls could interleave save/restore. The description above asks
    # clients to serialize calls — confirm that is sufficient.
    tool_manager._tools = tool_manager._all_tools
    try:
        return await tool_manager.call_tool(name, parameter)
    except Exception as e:
        return {"error": str(e)}
    finally:
        # Restore on every path. Previously the reduced set was only restored
        # on failure, so any successful call left all tools exposed.
        tool_manager._tools = auto_tools
|
scmcp_shared/server/base.py
CHANGED
@@ -6,7 +6,7 @@ from collections.abc import AsyncIterator
|
|
6
6
|
from contextlib import asynccontextmanager
|
7
7
|
import asyncio
|
8
8
|
from typing import Optional, List, Any, Iterable
|
9
|
-
|
9
|
+
from .auto import auto_mcp
|
10
10
|
|
11
11
|
class BaseMCP:
|
12
12
|
"""Base class for all Scanpy MCP classes."""
|
@@ -33,25 +33,20 @@ class BaseMCP:
|
|
33
33
|
methods = inspect.getmembers(self, predicate=inspect.ismethod)
|
34
34
|
|
35
35
|
# Filter methods that start with _tool_
|
36
|
-
tool_methods =
|
37
|
-
name[6:]: method # Remove '_tool_' prefix
|
38
|
-
for name, method in methods
|
39
|
-
if name.startswith('_tool_')
|
40
|
-
}
|
36
|
+
tool_methods = [tl_method() for name, tl_method in methods if name.startswith('_tool_')]
|
41
37
|
|
42
38
|
# Filter tools based on include/exclude lists
|
43
39
|
if self.include_tools is not None:
|
44
|
-
tool_methods =
|
40
|
+
tool_methods = [tl for tl in tool_methods if tl.name in self.include_tools]
|
45
41
|
|
46
42
|
if self.exclude_tools is not None:
|
47
|
-
tool_methods =
|
43
|
+
tool_methods = [tl for tl in tool_methods if tl.name not in self.exclude_tools]
|
48
44
|
|
49
45
|
# Register filtered tools
|
50
|
-
for
|
46
|
+
for tool in tool_methods:
|
51
47
|
# Get the function returned by the tool method
|
52
|
-
|
53
|
-
|
54
|
-
self.mcp.add_tool(tool_func, name=tool_name)
|
48
|
+
if tool is not None:
|
49
|
+
self.mcp.add_tool(tool)
|
55
50
|
|
56
51
|
|
57
52
|
class AdataState:
|
scmcp_shared/server/io.py
CHANGED
@@ -3,6 +3,7 @@ import inspect
|
|
3
3
|
from pathlib import Path
|
4
4
|
import scanpy as sc
|
5
5
|
from fastmcp import FastMCP, Context
|
6
|
+
from fastmcp.tools.tool import Tool
|
6
7
|
from fastmcp.exceptions import ToolError
|
7
8
|
from ..schema import AdataInfo
|
8
9
|
from ..schema.io import *
|
@@ -16,7 +17,7 @@ class ScanpyIOMCP(BaseMCP):
|
|
16
17
|
super().__init__("SCMCP-IO-Server", include_tools, exclude_tools, AdataInfo)
|
17
18
|
|
18
19
|
def _tool_read(self):
|
19
|
-
def _read(request:
|
20
|
+
def _read(request: ReadParams, adinfo: self.AdataInfo=self.AdataInfo()):
|
20
21
|
"""
|
21
22
|
Read data from 10X directory or various file formats (h5ad, 10x, text files, etc.).
|
22
23
|
"""
|
@@ -66,10 +67,10 @@ class ScanpyIOMCP(BaseMCP):
|
|
66
67
|
raise ToolError(e.__context__)
|
67
68
|
else:
|
68
69
|
raise ToolError(e)
|
69
|
-
return _read
|
70
|
+
return Tool.from_function(_read, name="read")
|
70
71
|
|
71
72
|
def _tool_write(self):
|
72
|
-
def _write(request:
|
73
|
+
def _write(request: WriteParams, adinfo: self.AdataInfo=self.AdataInfo()):
|
73
74
|
"""save adata into a file."""
|
74
75
|
try:
|
75
76
|
res = forward_request("io_write", request, adinfo)
|
@@ -87,7 +88,7 @@ class ScanpyIOMCP(BaseMCP):
|
|
87
88
|
raise ToolError(e.__context__)
|
88
89
|
else:
|
89
90
|
raise ToolError(e)
|
90
|
-
return _write
|
91
|
+
return Tool.from_function(_write, name="write")
|
91
92
|
|
92
93
|
|
93
94
|
# Create an instance of the class
|
scmcp_shared/server/pl.py
CHANGED
@@ -3,6 +3,7 @@ import inspect
|
|
3
3
|
from functools import partial
|
4
4
|
import scanpy as sc
|
5
5
|
from fastmcp import FastMCP, Context
|
6
|
+
from fastmcp.tools.tool import Tool
|
6
7
|
from fastmcp.exceptions import ToolError
|
7
8
|
from ..schema.pl import *
|
8
9
|
from ..schema import AdataInfo
|
@@ -24,7 +25,7 @@ class ScanpyPlottingMCP(BaseMCP):
|
|
24
25
|
super().__init__("ScanpyMCP-PL-Server", include_tools, exclude_tools, AdataInfo)
|
25
26
|
|
26
27
|
def _tool_pca(self):
|
27
|
-
def _pca(request:
|
28
|
+
def _pca(request: PCAParams, adinfo: self.AdataInfo=self.AdataInfo()):
|
28
29
|
"""Scatter plot in PCA coordinates. default figure for PCA plot"""
|
29
30
|
try:
|
30
31
|
if (res := forward_request("pl_pca", request, adinfo)) is not None:
|
@@ -39,10 +40,10 @@ class ScanpyPlottingMCP(BaseMCP):
|
|
39
40
|
raise ToolError(e.__context__)
|
40
41
|
else:
|
41
42
|
raise ToolError(e)
|
42
|
-
return _pca
|
43
|
+
return Tool.from_function(_pca, name="pca")
|
43
44
|
|
44
45
|
def _tool_diffmap(self):
|
45
|
-
def _diffmap(request:
|
46
|
+
def _diffmap(request: DiffusionMapParams, adinfo: self.AdataInfo=self.AdataInfo()):
|
46
47
|
"""Plot diffusion map embedding of cells."""
|
47
48
|
try:
|
48
49
|
if (res := forward_request("pl_diffmap", request, adinfo)) is not None:
|
@@ -57,10 +58,10 @@ class ScanpyPlottingMCP(BaseMCP):
|
|
57
58
|
raise ToolError(e.__context__)
|
58
59
|
else:
|
59
60
|
raise ToolError(e)
|
60
|
-
return _diffmap
|
61
|
+
return Tool.from_function(_diffmap, name="diffmap")
|
61
62
|
|
62
63
|
def _tool_violin(self):
|
63
|
-
def _violin(request:
|
64
|
+
def _violin(request: ViolinParams, adinfo: self.AdataInfo=self.AdataInfo()):
|
64
65
|
"""Plot violin plot of one or more variables."""
|
65
66
|
try:
|
66
67
|
if (res := forward_request("pl_violin", request, adinfo)) is not None:
|
@@ -77,10 +78,10 @@ class ScanpyPlottingMCP(BaseMCP):
|
|
77
78
|
raise ToolError(e.__context__)
|
78
79
|
else:
|
79
80
|
raise ToolError(e)
|
80
|
-
return _violin
|
81
|
+
return Tool.from_function(_violin, name="violin")
|
81
82
|
|
82
83
|
def _tool_stacked_violin(self):
|
83
|
-
def _stacked_violin(request:
|
84
|
+
def _stacked_violin(request: StackedViolinParams, adinfo: self.AdataInfo=self.AdataInfo()):
|
84
85
|
"""Plot stacked violin plots. Makes a compact image composed of individual violin plots stacked on top of each other."""
|
85
86
|
try:
|
86
87
|
if (res := forward_request("pl_stacked_violin", request, adinfo)) is not None:
|
@@ -95,10 +96,10 @@ class ScanpyPlottingMCP(BaseMCP):
|
|
95
96
|
raise ToolError(e.__context__)
|
96
97
|
else:
|
97
98
|
raise ToolError(e)
|
98
|
-
return _stacked_violin
|
99
|
+
return Tool.from_function(_stacked_violin, name="stacked_violin")
|
99
100
|
|
100
101
|
def _tool_heatmap(self):
|
101
|
-
async def _heatmap(request:
|
102
|
+
async def _heatmap(request: HeatmapParams, adinfo: self.AdataInfo=self.AdataInfo()):
|
102
103
|
"""Heatmap of the expression values of genes."""
|
103
104
|
try:
|
104
105
|
if (res := forward_request("pl_heatmap", request, adinfo)) is not None:
|
@@ -113,10 +114,10 @@ class ScanpyPlottingMCP(BaseMCP):
|
|
113
114
|
raise ToolError(e.__context__)
|
114
115
|
else:
|
115
116
|
raise ToolError(e)
|
116
|
-
return _heatmap
|
117
|
+
return Tool.from_function(_heatmap, name="heatmap")
|
117
118
|
|
118
119
|
def _tool_dotplot(self):
|
119
|
-
def _dotplot(request:
|
120
|
+
def _dotplot(request: DotplotParams, adinfo: self.AdataInfo=self.AdataInfo()):
|
120
121
|
"""Plot dot plot of expression values per gene for each group."""
|
121
122
|
try:
|
122
123
|
if (res := forward_request("pl_dotplot", request, adinfo)) is not None:
|
@@ -131,10 +132,10 @@ class ScanpyPlottingMCP(BaseMCP):
|
|
131
132
|
raise ToolError(e.__context__)
|
132
133
|
else:
|
133
134
|
raise ToolError(e)
|
134
|
-
return _dotplot
|
135
|
+
return Tool.from_function(_dotplot, name="dotplot")
|
135
136
|
|
136
137
|
def _tool_matrixplot(self):
|
137
|
-
def _matrixplot(request:
|
138
|
+
def _matrixplot(request: MatrixplotParams, adinfo: self.AdataInfo=self.AdataInfo()):
|
138
139
|
"""matrixplot, Create a heatmap of the mean expression values per group of each var_names."""
|
139
140
|
try:
|
140
141
|
if (res := forward_request("pl_matrixplot", request, adinfo)) is not None:
|
@@ -149,10 +150,10 @@ class ScanpyPlottingMCP(BaseMCP):
|
|
149
150
|
raise ToolError(e.__context__)
|
150
151
|
else:
|
151
152
|
raise ToolError(e)
|
152
|
-
return _matrixplot
|
153
|
+
return Tool.from_function(_matrixplot, name="matrixplot")
|
153
154
|
|
154
155
|
def _tool_tracksplot(self):
|
155
|
-
def _tracksplot(request:
|
156
|
+
def _tracksplot(request: TracksplotParams, adinfo: self.AdataInfo=self.AdataInfo()):
|
156
157
|
"""tracksplot, compact plot of expression of a list of genes."""
|
157
158
|
try:
|
158
159
|
if (res := forward_request("pl_tracksplot", request, adinfo)) is not None:
|
@@ -167,10 +168,10 @@ class ScanpyPlottingMCP(BaseMCP):
|
|
167
168
|
raise ToolError(e.__context__)
|
168
169
|
else:
|
169
170
|
raise ToolError(e)
|
170
|
-
return _tracksplot
|
171
|
+
return Tool.from_function(_tracksplot, name="tracksplot")
|
171
172
|
|
172
173
|
def _tool_scatter(self):
|
173
|
-
def _scatter(request:
|
174
|
+
def _scatter(request: EnhancedScatterParams, adinfo: self.AdataInfo=self.AdataInfo()):
|
174
175
|
"""Plot a scatter plot of two variables, Scatter plot along observations or variables axes."""
|
175
176
|
try:
|
176
177
|
if (res := forward_request("pl_scatter", request, adinfo)) is not None:
|
@@ -185,10 +186,10 @@ class ScanpyPlottingMCP(BaseMCP):
|
|
185
186
|
raise ToolError(e.__context__)
|
186
187
|
else:
|
187
188
|
raise ToolError(e)
|
188
|
-
return _scatter
|
189
|
+
return Tool.from_function(_scatter, name="scatter")
|
189
190
|
|
190
191
|
def _tool_embedding(self):
|
191
|
-
def _embedding(request:
|
192
|
+
def _embedding(request: EmbeddingParams, adinfo: self.AdataInfo=self.AdataInfo()):
|
192
193
|
"""Scatter plot for user specified embedding basis (e.g. umap, tsne, etc)."""
|
193
194
|
try:
|
194
195
|
if (res := forward_request("pl_embedding", request, adinfo)) is not None:
|
@@ -203,10 +204,10 @@ class ScanpyPlottingMCP(BaseMCP):
|
|
203
204
|
raise ToolError(e.__context__)
|
204
205
|
else:
|
205
206
|
raise ToolError(e)
|
206
|
-
return _embedding
|
207
|
+
return Tool.from_function(_embedding, name="embedding")
|
207
208
|
|
208
209
|
def _tool_embedding_density(self):
|
209
|
-
def _embedding_density(request:
|
210
|
+
def _embedding_density(request: EmbeddingDensityParams, adinfo: self.AdataInfo=self.AdataInfo()):
|
210
211
|
"""Plot the density of cells in an embedding."""
|
211
212
|
try:
|
212
213
|
if (res := forward_request("pl_embedding_density", request, adinfo)) is not None:
|
@@ -221,10 +222,10 @@ class ScanpyPlottingMCP(BaseMCP):
|
|
221
222
|
raise ToolError(e.__context__)
|
222
223
|
else:
|
223
224
|
raise ToolError(e)
|
224
|
-
return _embedding_density
|
225
|
+
return Tool.from_function(_embedding_density, name="embedding_density")
|
225
226
|
|
226
227
|
def _tool_rank_genes_groups(self):
|
227
|
-
def _rank_genes_groups(request:
|
228
|
+
def _rank_genes_groups(request: RankGenesGroupsParams, adinfo: self.AdataInfo=self.AdataInfo()):
|
228
229
|
"""Plot ranking of genes based on differential expression."""
|
229
230
|
try:
|
230
231
|
if (res := forward_request("pl_rank_genes_groups", request, adinfo)) is not None:
|
@@ -239,10 +240,10 @@ class ScanpyPlottingMCP(BaseMCP):
|
|
239
240
|
raise ToolError(e.__context__)
|
240
241
|
else:
|
241
242
|
raise ToolError(e)
|
242
|
-
return _rank_genes_groups
|
243
|
+
return Tool.from_function(_rank_genes_groups, name="rank_genes_groups")
|
243
244
|
|
244
245
|
def _tool_rank_genes_groups_dotplot(self):
|
245
|
-
def _rank_genes_groups_dotplot(request:
|
246
|
+
def _rank_genes_groups_dotplot(request: RankGenesGroupsDotplotParams, adinfo: self.AdataInfo=self.AdataInfo()):
|
246
247
|
"""Plot ranking of genes(DEGs) using dotplot visualization. Defualt plot DEGs for rank_genes_groups tool"""
|
247
248
|
try:
|
248
249
|
if (res := forward_request("pl_rank_genes_groups_dotplot", request, adinfo)) is not None:
|
@@ -257,10 +258,10 @@ class ScanpyPlottingMCP(BaseMCP):
|
|
257
258
|
raise ToolError(e.__context__)
|
258
259
|
else:
|
259
260
|
raise ToolError(e)
|
260
|
-
return _rank_genes_groups_dotplot
|
261
|
+
return Tool.from_function(_rank_genes_groups_dotplot, name="rank_genes_groups_dotplot")
|
261
262
|
|
262
263
|
def _tool_clustermap(self):
|
263
|
-
def _clustermap(request:
|
264
|
+
def _clustermap(request: ClusterMapParams, adinfo: self.AdataInfo=self.AdataInfo()):
|
264
265
|
"""Plot hierarchical clustering of cells and genes."""
|
265
266
|
try:
|
266
267
|
if (res := forward_request("pl_clustermap", request, adinfo)) is not None:
|
@@ -275,10 +276,10 @@ class ScanpyPlottingMCP(BaseMCP):
|
|
275
276
|
raise ToolError(e.__context__)
|
276
277
|
else:
|
277
278
|
raise ToolError(e)
|
278
|
-
return _clustermap
|
279
|
+
return Tool.from_function(_clustermap, name="clustermap")
|
279
280
|
|
280
281
|
def _tool_highly_variable_genes(self):
|
281
|
-
def _highly_variable_genes(request:
|
282
|
+
def _highly_variable_genes(request: HighlyVariableGenesParams, adinfo: self.AdataInfo=self.AdataInfo()):
|
282
283
|
"""plot highly variable genes; Plot dispersions or normalized variance versus means for genes."""
|
283
284
|
try:
|
284
285
|
if (res := forward_request("pl_highly_variable_genes", request, adinfo)) is not None:
|
@@ -293,10 +294,10 @@ class ScanpyPlottingMCP(BaseMCP):
|
|
293
294
|
raise ToolError(e.__context__)
|
294
295
|
else:
|
295
296
|
raise ToolError(e)
|
296
|
-
return _highly_variable_genes
|
297
|
+
return Tool.from_function(_highly_variable_genes, name="highly_variable_genes")
|
297
298
|
|
298
299
|
def _tool_pca_variance_ratio(self):
|
299
|
-
def _pca_variance_ratio(request:
|
300
|
+
def _pca_variance_ratio(request: PCAVarianceRatioParams, adinfo: self.AdataInfo=self.AdataInfo()):
|
300
301
|
"""Plot the PCA variance ratio to visualize explained variance."""
|
301
302
|
try:
|
302
303
|
if (res := forward_request("pl_pca_variance_ratio", request, adinfo)) is not None:
|
@@ -311,5 +312,5 @@ class ScanpyPlottingMCP(BaseMCP):
|
|
311
312
|
raise ToolError(e.__context__)
|
312
313
|
else:
|
313
314
|
raise ToolError(e)
|
314
|
-
return _pca_variance_ratio
|
315
|
+
return Tool.from_function(_pca_variance_ratio, name="pca_variance_ratio")
|
315
316
|
|
scmcp_shared/server/pp.py
CHANGED
@@ -2,6 +2,7 @@ import os
|
|
2
2
|
import inspect
|
3
3
|
import scanpy as sc
|
4
4
|
from fastmcp import FastMCP, Context
|
5
|
+
from fastmcp.tools.tool import Tool
|
5
6
|
from fastmcp.exceptions import ToolError
|
6
7
|
from ..schema.pp import *
|
7
8
|
from ..schema import AdataInfo
|
@@ -22,7 +23,7 @@ class ScanpyPreprocessingMCP(BaseMCP):
|
|
22
23
|
super().__init__("ScanpyMCP-PP-Server", include_tools, exclude_tools, AdataInfo)
|
23
24
|
|
24
25
|
def _tool_subset_cells(self):
|
25
|
-
def _subset_cells(request:
|
26
|
+
def _subset_cells(request: SubsetCellParams, adinfo: self.AdataInfo=self.AdataInfo()):
|
26
27
|
"""filter or subset cells based on total genes expressed counts and numbers. or values in adata.obs[obs_key]"""
|
27
28
|
try:
|
28
29
|
result = forward_request("subset_cells", request, adinfo)
|
@@ -62,10 +63,10 @@ class ScanpyPreprocessingMCP(BaseMCP):
|
|
62
63
|
raise ToolError(e.__context__)
|
63
64
|
else:
|
64
65
|
raise ToolError(e)
|
65
|
-
return _subset_cells
|
66
|
+
return Tool.from_function(_subset_cells, name="subset_cells")
|
66
67
|
|
67
68
|
def _tool_subset_genes(self):
|
68
|
-
def _subset_genes(request:
|
69
|
+
def _subset_genes(request: SubsetGeneParams, adinfo: self.AdataInfo=self.AdataInfo()):
|
69
70
|
"""filter or subset genes based on number of cells or counts, or values in adata.var[var_key] or subset highly variable genes"""
|
70
71
|
try:
|
71
72
|
result = forward_request("pp_subset_genes", request, adinfo)
|
@@ -74,6 +75,11 @@ class ScanpyPreprocessingMCP(BaseMCP):
|
|
74
75
|
func_kwargs = filter_args(request, sc.pp.filter_genes)
|
75
76
|
ads = get_ads()
|
76
77
|
adata = ads.get_adata(adinfo=adinfo).copy()
|
78
|
+
if request.highly_variable:
|
79
|
+
adata = adata[:, adata.var.highly_variable]
|
80
|
+
add_op_log(adata, "subset_genes",
|
81
|
+
{"hpv": "true"}, adinfo
|
82
|
+
)
|
77
83
|
if func_kwargs:
|
78
84
|
sc.pp.filter_genes(adata, **func_kwargs)
|
79
85
|
add_op_log(adata, sc.pp.filter_genes, func_kwargs, adinfo)
|
@@ -86,8 +92,6 @@ class ScanpyPreprocessingMCP(BaseMCP):
|
|
86
92
|
if request.var_max is not None:
|
87
93
|
mask = mask & (adata.var[request.var_key] <= request.var_max)
|
88
94
|
adata = adata[:, mask]
|
89
|
-
if request.highly_variable:
|
90
|
-
adata = adata[:, mask & adata.var.highly_variable]
|
91
95
|
add_op_log(adata, "subset_genes",
|
92
96
|
{
|
93
97
|
"var_key": request.var_key, "var_min": request.var_min, "var_max": request.var_max,
|
@@ -103,7 +107,7 @@ class ScanpyPreprocessingMCP(BaseMCP):
|
|
103
107
|
raise ToolError(e.__context__)
|
104
108
|
else:
|
105
109
|
raise ToolError(e)
|
106
|
-
return _subset_genes
|
110
|
+
return Tool.from_function(_subset_genes, name="subset_genes")
|
107
111
|
|
108
112
|
def _tool_calculate_qc_metrics(self):
|
109
113
|
def _calculate_qc_metrics(request: CalculateQCMetrics, adinfo: self.AdataInfo=self.AdataInfo()):
|
@@ -116,12 +120,16 @@ class ScanpyPreprocessingMCP(BaseMCP):
|
|
116
120
|
func_kwargs = filter_args(request, sc.pp.calculate_qc_metrics)
|
117
121
|
ads = get_ads()
|
118
122
|
adata = ads.get_adata(adinfo=adinfo)
|
123
|
+
if request.qc_vars:
|
124
|
+
for var in request.qc_vars:
|
125
|
+
if var not in adata.var.columns:
|
126
|
+
return f"Cound find {var} in adata.var, consider to use mark_var tool to mark the variable"
|
119
127
|
func_kwargs["inplace"] = True
|
120
128
|
try:
|
121
129
|
sc.pp.calculate_qc_metrics(adata, **func_kwargs)
|
122
130
|
add_op_log(adata, sc.pp.calculate_qc_metrics, func_kwargs, adinfo)
|
123
131
|
except KeyError as e:
|
124
|
-
raise KeyError(f"Cound find {e} in adata.var")
|
132
|
+
raise KeyError(f"Cound find {e} in adata.var, consider to use mark_var tool to mark the variable")
|
125
133
|
return [generate_msg(adinfo, adata, ads)]
|
126
134
|
except ToolError as e:
|
127
135
|
raise ToolError(e)
|
@@ -130,10 +138,10 @@ class ScanpyPreprocessingMCP(BaseMCP):
|
|
130
138
|
raise ToolError(e.__context__)
|
131
139
|
else:
|
132
140
|
raise ToolError(e)
|
133
|
-
return _calculate_qc_metrics
|
141
|
+
return Tool.from_function(_calculate_qc_metrics, name="calculate_qc_metrics")
|
134
142
|
|
135
143
|
def _tool_log1p(self):
|
136
|
-
def _log1p(request:
|
144
|
+
def _log1p(request: Log1PParams=Log1PParams(), adinfo: self.AdataInfo=self.AdataInfo()):
|
137
145
|
"""Logarithmize the data matrix"""
|
138
146
|
try:
|
139
147
|
result = forward_request("pp_log1p", request, adinfo)
|
@@ -157,10 +165,10 @@ class ScanpyPreprocessingMCP(BaseMCP):
|
|
157
165
|
raise ToolError(e.__context__)
|
158
166
|
else:
|
159
167
|
raise ToolError(e)
|
160
|
-
return _log1p
|
168
|
+
return Tool.from_function(_log1p, name="log1p")
|
161
169
|
|
162
170
|
def _tool_normalize_total(self):
|
163
|
-
def _normalize_total(request:
|
171
|
+
def _normalize_total(request: NormalizeTotalParams, adinfo: self.AdataInfo=self.AdataInfo()):
|
164
172
|
"""Normalize counts per cell to the same total count"""
|
165
173
|
try:
|
166
174
|
result = forward_request("pp_normalize_total", request, adinfo)
|
@@ -180,10 +188,10 @@ class ScanpyPreprocessingMCP(BaseMCP):
|
|
180
188
|
raise ToolError(e.__context__)
|
181
189
|
else:
|
182
190
|
raise ToolError(e)
|
183
|
-
return _normalize_total
|
191
|
+
return Tool.from_function(_normalize_total, name="normalize_total")
|
184
192
|
|
185
193
|
def _tool_highly_variable_genes(self):
|
186
|
-
def _highly_variable_genes(request:
|
194
|
+
def _highly_variable_genes(request: HighlyVariableGenesParams, adinfo: self.AdataInfo=self.AdataInfo()):
|
187
195
|
"""Annotate highly variable genes"""
|
188
196
|
try:
|
189
197
|
result = forward_request("pp_highly_variable_genes", request, adinfo)
|
@@ -205,10 +213,10 @@ class ScanpyPreprocessingMCP(BaseMCP):
|
|
205
213
|
raise ToolError(e.__context__)
|
206
214
|
else:
|
207
215
|
raise ToolError(e)
|
208
|
-
return _highly_variable_genes
|
216
|
+
return Tool.from_function(_highly_variable_genes, name="highly_variable_genes")
|
209
217
|
|
210
218
|
def _tool_regress_out(self):
|
211
|
-
def _regress_out(request:
|
219
|
+
def _regress_out(request: RegressOutParams, adinfo: self.AdataInfo=self.AdataInfo()):
|
212
220
|
"""Regress out (mostly) unwanted sources of variation."""
|
213
221
|
try:
|
214
222
|
result = forward_request("pp_regress_out", request, adinfo)
|
@@ -228,10 +236,10 @@ class ScanpyPreprocessingMCP(BaseMCP):
|
|
228
236
|
raise ToolError(e.__context__)
|
229
237
|
else:
|
230
238
|
raise ToolError(e)
|
231
|
-
return _regress_out
|
239
|
+
return Tool.from_function(_regress_out, name="regress_out")
|
232
240
|
|
233
241
|
def _tool_scale(self):
|
234
|
-
def _scale(request:
|
242
|
+
def _scale(request: ScaleParams, adinfo: self.AdataInfo=self.AdataInfo()):
|
235
243
|
"""Scale data to unit variance and zero mean"""
|
236
244
|
try:
|
237
245
|
result = forward_request("pp_scale", request, adinfo)
|
@@ -253,10 +261,10 @@ class ScanpyPreprocessingMCP(BaseMCP):
|
|
253
261
|
raise ToolError(e.__context__)
|
254
262
|
else:
|
255
263
|
raise ToolError(e)
|
256
|
-
return _scale
|
264
|
+
return Tool.from_function(_scale, name="scale")
|
257
265
|
|
258
266
|
def _tool_combat(self):
|
259
|
-
def _combat(request:
|
267
|
+
def _combat(request: CombatParams, adinfo: self.AdataInfo=self.AdataInfo()):
|
260
268
|
"""ComBat function for batch effect correction"""
|
261
269
|
try:
|
262
270
|
result = forward_request("pp_combat", request, adinfo)
|
@@ -278,10 +286,10 @@ class ScanpyPreprocessingMCP(BaseMCP):
|
|
278
286
|
raise ToolError(e.__context__)
|
279
287
|
else:
|
280
288
|
raise ToolError(e)
|
281
|
-
return _combat
|
289
|
+
return Tool.from_function(_combat, name="combat")
|
282
290
|
|
283
291
|
def _tool_scrublet(self):
|
284
|
-
def _scrublet(request:
|
292
|
+
def _scrublet(request: ScrubletParams, adinfo: self.AdataInfo=self.AdataInfo()):
|
285
293
|
"""Predict doublets using Scrublet"""
|
286
294
|
try:
|
287
295
|
result = forward_request("pp_scrublet", request, adinfo)
|
@@ -300,10 +308,10 @@ class ScanpyPreprocessingMCP(BaseMCP):
|
|
300
308
|
raise ToolError(e.__context__)
|
301
309
|
else:
|
302
310
|
raise ToolError(e)
|
303
|
-
return _scrublet
|
311
|
+
return Tool.from_function(_scrublet, name="scrublet")
|
304
312
|
|
305
313
|
def _tool_neighbors(self):
|
306
|
-
def _neighbors(request:
|
314
|
+
def _neighbors(request: NeighborsParams, adinfo: self.AdataInfo=self.AdataInfo()):
|
307
315
|
"""Compute nearest neighbors distance matrix and neighborhood graph"""
|
308
316
|
try:
|
309
317
|
result = forward_request("pp_neighbors", request, adinfo)
|
@@ -322,4 +330,4 @@ class ScanpyPreprocessingMCP(BaseMCP):
|
|
322
330
|
raise ToolError(e.__context__)
|
323
331
|
else:
|
324
332
|
raise ToolError(e)
|
325
|
-
return _neighbors
|
333
|
+
return Tool.from_function(_neighbors, name="neighbors")
|