scmcp-shared 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,153 @@
1
+ import inspect
2
+ from fastmcp import FastMCP
3
+ from ..schema import AdataInfo
4
+ from ..util import filter_tools
5
+ from collections.abc import AsyncIterator
6
+ from contextlib import asynccontextmanager
7
+ import asyncio
8
+ from typing import Optional, List, Any, Iterable
9
+
10
+
11
class BaseMCP:
    """Common scaffolding for Scanpy MCP tool servers.

    A subclass contributes a tool by defining a method named ``_tool_<name>``
    that takes no arguments and returns the callable to register (or ``None``
    to register nothing). During construction, every such method that survives
    the include/exclude filters is called once, and its result is added to
    ``self.mcp`` under ``<name>``.
    """

    def __init__(self, name: str, include_tools: list = None, exclude_tools: list = None, AdataInfo = AdataInfo):
        """
        Create the FastMCP server and register this instance's tools.

        Args:
            name (str): Name given to the underlying FastMCP server.
            include_tools (list, optional): Whitelist of tool names (without the
                ``_tool_`` prefix). ``None`` means every tool is eligible.
            exclude_tools (list, optional): Blacklist of tool names, applied after
                the whitelist. ``None`` means nothing is excluded.
            AdataInfo: Model class stored as ``self.AdataInfo`` for use by
                subclasses' tool signatures.
        """
        self.mcp = FastMCP(name)
        self.include_tools = include_tools
        self.exclude_tools = exclude_tools
        self.AdataInfo = AdataInfo
        self._register_tools()

    def _register_tools(self):
        """Add every selected ``_tool_*`` factory's result to ``self.mcp``."""
        prefix = "_tool_"
        selected = {}
        for attr, bound in inspect.getmembers(self, predicate=inspect.ismethod):
            if not attr.startswith(prefix):
                continue
            tool_name = attr[len(prefix):]
            if self.include_tools is not None and tool_name not in self.include_tools:
                continue
            if self.exclude_tools is not None and tool_name in self.exclude_tools:
                continue
            selected[tool_name] = bound

        # Factories are invoked only for tools that passed both filters.
        for tool_name, factory in selected.items():
            tool = factory()
            if tool is None:
                continue
            self.mcp.add_tool(tool, name=tool_name)
55
+
56
+
57
class AdataState:
    """In-memory registry of data objects keyed by data type and sample id.

    ``adata_dic`` maps an adtype (``"exp"``, ``"activity"``, ``"cnv"``,
    ``"splicing"``, plus any extra types passed at construction or created by
    ``set_adata``) to a dict of ``sampleid -> object``. ``active_id`` names the
    sample used when a caller does not specify one.
    """

    def __init__(self, add_adtypes=None):
        """
        Args:
            add_adtypes: Extra adtype(s) to register. Either a single string or
                an iterable of strings; ``None`` registers only the built-in types.
        """
        self.adata_dic = {"exp": {}, "activity": {}, "cnv": {}, "splicing": {}}
        if isinstance(add_adtypes, str):
            self.adata_dic[add_adtypes] = {}
        elif isinstance(add_adtypes, Iterable):
            self.adata_dic.update({adtype: {} for adtype in add_adtypes})
        self.active_id = None
        self.metadata = {}
        # Backward-compatible alias for the previously misspelled attribute name;
        # both names refer to the same dict.
        self.metadatWa = self.metadata
        self.cr_kernel = {}
        self.cr_estimator = {}

    def get_adata(self, sampleid=None, adtype="exp", adinfo=None):
        """Return the object stored under ``(adtype, sampleid)``.

        Args:
            sampleid: Sample id to look up; falls back to ``active_id`` when falsy.
            adtype: Data-type bucket to look in (default ``"exp"``).
            adinfo: Optional object with ``model_dump()``. When given, its
                ``sampleid``/``adtype`` entries replace the explicit arguments.

        Returns:
            The stored object, or ``None`` when neither ``sampleid`` nor
            ``active_id`` is set.

        Raises:
            KeyError: If ``adtype`` or the resolved ``sampleid`` is not registered.
        """
        if adinfo is not None:
            kwargs = adinfo.model_dump()
            sampleid = kwargs.get("sampleid", None)
            adtype = kwargs.get("adtype", "exp")
        sampleid = sampleid or self.active_id
        if sampleid is None:
            # Nothing has been selected yet and the caller did not name a sample.
            return None
        # Explicit membership checks (instead of catching KeyError) keep the raised
        # error free of an implicit __context__. The tool wrappers report
        # e.__context__ when it is set, which would hide this message.
        samples = self.adata_dic.get(adtype)
        if samples is None:
            raise KeyError(f"Key {adtype!r} not found in adata_dic. Please check the sampleid or adtype.")
        if sampleid not in samples:
            raise KeyError(f"Key {sampleid!r} not found in adata_dic[{adtype}]. Please check the sampleid or adtype.")
        return samples[sampleid]

    def set_adata(self, adata, sampleid=None, sdtype="exp", adinfo=None):
        """Store ``adata`` under ``(sdtype, sampleid)``, creating the bucket if needed.

        Args:
            adata: Object to store.
            sampleid: Sample id key; falls back to ``active_id`` when falsy.
            sdtype: Data-type bucket (the same concept as ``adtype`` in ``get_adata``).
            adinfo: Optional object with ``model_dump()``. When given, its
                ``sampleid``/``adtype`` entries replace the explicit arguments.
        """
        if adinfo is not None:
            kwargs = adinfo.model_dump()
            sampleid = kwargs.get("sampleid", None)
            sdtype = kwargs.get("adtype", "exp")
        sampleid = sampleid or self.active_id
        # NOTE(review): if both sampleid and active_id are None, the object is stored
        # under the key None, which get_adata never resolves to.
        self.adata_dic.setdefault(sdtype, {})[sampleid] = adata
93
+
94
+
95
class BaseMCPManager:
    """Base class for MCP module management.

    Subclasses fill ``self.available_modules`` in ``_init_modules`` with a
    mapping of module name to either a ``FastMCP`` instance or a class whose
    instances expose a ``.mcp`` FastMCP server. Each selected module is then
    imported into ``self.mcp`` under its module name.
    """

    def __init__(self,
                 name: str,
                 include_modules: Optional[List[str]] = None,
                 exclude_modules: Optional[List[str]] = None,
                 include_tools: Optional[dict[str, List[str]]] = None,
                 exclude_tools: Optional[dict[str, List[str]]] = None,
                 ):
        """
        Initialize BaseMCPManager with optional module and tool filtering.

        Args:
            name (str): Name of the MCP server.
            include_modules (List[str], optional): Module names to include. If None, all modules are included.
            exclude_modules (List[str], optional): Module names to exclude. If None, no modules are excluded.
            include_tools (dict[str, List[str]], optional): Maps a module name to the tool names to keep
                in that module. Modules not present as keys are left unfiltered.
            exclude_tools (dict[str, List[str]], optional): Maps a module name to the tool names to drop
                from that module. Modules not present as keys are left unfiltered.
        """
        self.ads = AdataState()
        self.mcp = FastMCP(name, lifespan=self.adata_lifespan)
        self.include_modules = include_modules
        self.exclude_modules = exclude_modules
        self.include_tools = include_tools
        self.exclude_tools = exclude_tools
        self.available_modules = {}
        self._init_modules()
        self._register_modules()

    def _init_modules(self):
        """Populate ``self.available_modules``. To be implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement _init_modules")

    def _register_modules(self):
        """Filter ``available_modules`` and import each selected module into ``self.mcp``."""
        if self.include_modules is not None:
            self.available_modules = {k: v for k, v in self.available_modules.items() if k in self.include_modules}

        if self.exclude_modules is not None:
            self.available_modules = {k: v for k, v in self.available_modules.items() if k not in self.exclude_modules}

        for module_name, mcpi in self.available_modules.items():
            # Instantiate class-based modules first so that the per-module tool
            # filters below apply to them as well as to ready-made FastMCP servers.
            server = mcpi if isinstance(mcpi, FastMCP) else mcpi().mcp
            if self.include_tools is not None and module_name in self.include_tools:
                server = filter_tools(server, include_tools=self.include_tools[module_name])
            if self.exclude_tools is not None and module_name in self.exclude_tools:
                server = filter_tools(server, exclude_tools=self.exclude_tools[module_name])
            # NOTE: asyncio.run() raises if an event loop is already running, so the
            # manager must be constructed from synchronous code.
            asyncio.run(self.mcp.import_server(module_name, server))

    @asynccontextmanager
    async def adata_lifespan(self, server: FastMCP) -> AsyncIterator[Any]:
        """Server lifespan that yields the shared ``AdataState`` as the lifespan context."""
        yield self.ads
scmcp_shared/server/io.py CHANGED
@@ -2,73 +2,83 @@ import os
2
2
  import inspect
3
3
  from pathlib import Path
4
4
  import scanpy as sc
5
- from fastmcp import FastMCP , Context
5
+ from fastmcp import FastMCP, Context
6
+ from fastmcp.exceptions import ToolError
7
+ from ..schema import AdataInfo
6
8
  from ..schema.io import *
7
9
  from ..util import filter_args, forward_request, get_ads, generate_msg
10
+ from .base import BaseMCP
8
11
 
9
12
 
10
class ScanpyIOMCP(BaseMCP):
    """MCP server exposing Scanpy ``read`` and ``write`` tools."""

    def __init__(self, include_tools: list = None, exclude_tools: list = None, AdataInfo = AdataInfo):
        """Initialize ScanpyIOMCP with optional tool filtering.

        Args:
            include_tools (list, optional): Tool names to register (``"read"``, ``"write"``); None registers all.
            exclude_tools (list, optional): Tool names to skip; None skips none.
            AdataInfo: Model class used for the ``adinfo`` parameter of each tool.
        """
        super().__init__("SCMCP-IO-Server", include_tools, exclude_tools, AdataInfo)

    def _tool_read(self):
        """Build the ``read`` tool."""
        def _read(request: ReadModel, adinfo: self.AdataInfo=self.AdataInfo()):
            """
            Read data from 10X directory or various file formats (h5ad, 10x, text files, etc.).
            """
            try:
                res = forward_request("io_read", request, adinfo)
                if res is not None:
                    return res
                kwargs = request.model_dump()
                file = Path(kwargs.get("filename", None))
                if file.is_dir():
                    kwargs["path"] = kwargs["filename"]
                    func_kwargs = filter_args(request, sc.read_10x_mtx)
                    adata = sc.read_10x_mtx(kwargs["path"], **func_kwargs)
                elif file.is_file():
                    func_kwargs = filter_args(request, sc.read)
                    adata = sc.read(**func_kwargs)
                    if not kwargs.get("first_column_obs", True):
                        adata = adata.T
                else:
                    raise FileNotFoundError(f"{kwargs['filename']} does not exist")

                ads = get_ads()
                if adinfo.sampleid is not None:
                    ads.active_id = adinfo.sampleid
                else:
                    # .get(): adtypes not pre-registered in AdataState start at
                    # "adata0" instead of raising KeyError; set_adata creates the bucket.
                    ads.active_id = f"adata{len(ads.adata_dic.get(adinfo.adtype, {}))}"

                # Copy so that later in-place preprocessing of adata.X does not also
                # overwrite the raw counts layer (a plain assignment aliases the array).
                adata.layers["counts"] = adata.X.copy()
                adata.var_names_make_unique()
                adata.obs_names_make_unique()
                ads.set_adata(adata, adinfo=adinfo)
                return generate_msg(adinfo, adata, ads)
            except ToolError:
                raise
            except Exception as e:
                # Prefer the underlying error when this one was raised while handling another.
                raise ToolError(e.__context__ or e) from e
        return _read

    def _tool_write(self):
        """Build the ``write`` tool."""
        def _write(request: WriteModel, adinfo: self.AdataInfo=self.AdataInfo()):
            """save adata into a file."""
            try:
                res = forward_request("io_write", request, adinfo)
                if res is not None:
                    return res
                ads = get_ads()
                adata = ads.get_adata(adinfo=adinfo)
                if adata is None:
                    # get_adata returns None when no sample is selected; fail clearly
                    # instead of letting sc.write crash on None.
                    raise ValueError("No AnnData object is loaded; read data before writing.")
                kwargs = request.model_dump()
                sc.write(kwargs["filename"], adata)
                return {"filename": kwargs["filename"], "msg": "success to save file"}
            except ToolError:
                raise
            except Exception as e:
                # Prefer the underlying error when this one was raised while handling another.
                raise ToolError(e.__context__ or e) from e
        return _write


# Module-level server instance with all IO tools registered.
io_mcp = ScanpyIOMCP().mcp