scmcp-shared 0.2.5__py3-none-any.whl → 0.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
scmcp_shared/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
1
 
2
- __version__ = "0.2.5"
2
+ __version__ = "0.3.5"
3
3
 
scmcp_shared/cli.py ADDED
@@ -0,0 +1,111 @@
1
+ import argparse
2
+ from typing import Optional, Union, Type, Dict, Callable
3
+ from enum import Enum
4
+ from .util import add_figure_route, set_env
5
+ import os
6
+
7
+
8
class MCPCLI:
    """Base class for CLI applications with support for dynamic modules and parameters.

    The CLI always exposes a ``run`` sub-command that starts the MCP server.
    Additional sub-commands can be registered with :meth:`add_command`.

    Args:
        name: Program name; also passed to ``manager`` as the server name.
        help_text: Description shown in ``--help``.
        mcp: An already constructed server object exposing ``run(...)``.
        manager: Callable invoked as ``manager(name, include_modules=...)``
            whose result exposes the server as ``.mcp``. When given, it
            takes precedence over ``mcp``.
    """

    def __init__(self, name: str, help_text: str, mcp=None, manager=None):
        self.name = name
        self.mcp = mcp
        self.manager = manager
        self.parser = argparse.ArgumentParser(
            description=help_text,
            prog=name
        )
        # Maps sub-command name -> (its parser, handler called with parsed args).
        self.subcommands: Dict[str, tuple[argparse.ArgumentParser, Callable]] = {}
        self._setup_commands()

    def _setup_commands(self):
        """Setup the main commands for the CLI."""
        # Keep the sub-parsers action so add_command() does not need to dig
        # through argparse's private attributes to find it again.
        self._subparsers_action = self.parser.add_subparsers(
            dest='command', help='Available commands'
        )
        run_parser = self._subparsers_action.add_parser(
            'run', help='Start the server with the specified configuration'
        )
        self._setup_run_command(run_parser)
        self.subcommands['run'] = (run_parser, self._run_command)

    def _setup_run_command(self, parser: argparse.ArgumentParser):
        """Setup run command arguments."""
        parser.add_argument('-t', '--transport', default="stdio",
                            choices=["stdio", "shttp", "sse"],
                            help='specify transport type')
        parser.add_argument('-p', '--port', type=int, default=8000, help='transport port')
        parser.add_argument('--host', default='127.0.0.1', help='transport host')
        parser.add_argument('-f', '--forward', help='forward request to another server')
        parser.add_argument('-wd', '--working-dir', default=".", help='working directory')
        parser.add_argument('--log-file', help='log file path, use stdout if None')

    def add_command(self, name: str, help_text: str, handler: Callable) -> argparse.ArgumentParser:
        """Add a new sub-command.

        Args:
            name: Sub-command name.
            help_text: Help text shown for the sub-command.
            handler: Callable invoked with the parsed arguments when the
                sub-command is selected.

        Returns:
            ArgumentParser: Parser for the new sub-command, so the caller can
            add its own arguments.
        """
        subparsers = getattr(self, '_subparsers_action', None)
        if subparsers is None:
            # A subclass replaced _setup_commands() without recording the
            # action; fall back to the lookup the original code relied on.
            subparsers = self.parser._subparsers._group_actions[0]
        parser = subparsers.add_parser(name, help=help_text)
        self.subcommands[name] = (parser, handler)
        return parser

    def get_command_parser(self, name: str) -> Optional[argparse.ArgumentParser]:
        """Return the parser registered for ``name``.

        Args:
            name: Sub-command name.

        Returns:
            ArgumentParser: Parser for the sub-command, or ``None`` if the
            sub-command does not exist.
        """
        entry = self.subcommands.get(name)
        return entry[0] if entry is not None else None

    def _run_command(self, args):
        """Start the server with the specified configuration.

        Raises:
            ValueError: If neither ``manager`` nor ``mcp`` was provided.
        """
        os.chdir(args.working_dir)

        # ``module`` is only present when a subclass adds such an option, and
        # may be None when the option was not given. Only a non-empty list
        # that does not contain "all" restricts the modules; everything else
        # means "no restriction".
        requested = getattr(args, 'module', None)
        if isinstance(requested, list) and requested and "all" not in requested:
            modules = requested
        else:
            modules = None

        if self.manager is not None:
            self.mcp = self.manager(self.name, include_modules=modules).mcp
        elif self.mcp is None:
            raise ValueError("No manager or mcp provided")

        self.run_mcp(args.log_file, args.forward, args.transport, args.host, args.port)

    def run_mcp(self, log_file, forward, transport, host, port):
        """Configure the environment and logging, then run ``self.mcp``.

        Args:
            log_file: Log file path handed to ``set_env`` and ``setup_logger``.
            forward: Forward target handed to ``set_env``.
            transport: One of ``"stdio"``, ``"sse"`` or ``"shttp"``.
            host: Host used for the ``sse``/``shttp`` transports.
            port: Port used for the ``sse``/``shttp`` transports.

        Raises:
            ValueError: If ``transport`` is not one of the supported values.
        """
        set_env(log_file, forward, transport, host, port)
        from .logging_config import setup_logger
        setup_logger(log_file)
        if transport == "stdio":
            self.mcp.run()
        elif transport in ("sse", "shttp"):
            # The CLI spells it "shttp"; the server expects "streamable-http".
            transport = "streamable-http" if transport == "shttp" else transport
            add_figure_route(self.mcp)
            self.mcp.run(
                transport=transport,
                host=host,
                port=port,
                log_level="info"
            )
        else:
            # Previously an unknown transport returned silently without
            # starting anything.
            raise ValueError(f"Unsupported transport: {transport!r}")

    def run(self):
        """Run the CLI application: dispatch to the selected sub-command or print help."""
        args = self.parser.parse_args()
        if args.command in self.subcommands:
            handler = self.subcommands[args.command][1]
            handler(args)
        else:
            self.parser.print_help()
@@ -11,3 +11,12 @@ class AdataModel(BaseModel):
11
11
  model_config = ConfigDict(
12
12
  extra="ignore"
13
13
  )
14
+
15
class AdataInfo(BaseModel):
    """Input schema for the adata tool."""

    # Unknown keys in the incoming payload are dropped rather than rejected.
    model_config = ConfigDict(extra="ignore")

    sampleid: str | None = Field(
        default=None,
        description="adata sampleid",
    )
    adtype: str = Field(
        default="exp",
        description="The input adata.X data type for preprocess/analysis/plotting",
    )
scmcp_shared/schema/io.py CHANGED
@@ -22,10 +22,6 @@ class ReadModel(BaseModel):
22
22
  default=None,
23
23
  description="Name of sheet/table in hdf5 or Excel file."
24
24
  )
25
- ext: str = Field(
26
- default=None,
27
- description="Extension that indicates the file type. If None, uses extension of filename."
28
- )
29
25
  delimiter: str = Field(
30
26
  default=None,
31
27
  description="Delimiter that separates data within text file. If None, will split at arbitrary number of white spaces, which is different from enforcing splitting at any single white space."
scmcp_shared/schema/tl.py CHANGED
@@ -300,11 +300,6 @@ class LeidenModel(BaseModel):
300
300
  description="Which package's implementation to use."
301
301
  )
302
302
 
303
- clustering_args: Optional[Dict[str, Any]] = Field(
304
- default=None,
305
- description="Any further arguments to pass to the clustering algorithm."
306
- )
307
-
308
303
  @field_validator('resolution')
309
304
  def validate_resolution(cls, v: float) -> float:
310
305
  """Validate resolution is positive"""
@@ -943,6 +938,11 @@ class PCAModel(BaseModel):
943
938
  gt=0
944
939
  )
945
940
 
941
+ key_added: str = Field(
942
+ default="X_pca",
943
+ description="PCA embedding stored key in adata.obsm."
944
+ )
945
+
946
946
  @field_validator('n_comps', 'chunk_size')
947
947
  def validate_positive_integers(cls, v: Optional[int]) -> Optional[int]:
948
948
  """Validate positive integers"""
@@ -1,51 +1,13 @@
1
1
  from __future__ import annotations
2
2
 
3
- from collections.abc import Iterable
4
- from typing import Any
5
-
6
- from .io import io_mcp
7
- from .util import ul_mcp
8
- from .pl import pl_mcp
9
- from .pp import pp_mcp
10
- from .tl import tl_mcp
11
-
12
-
13
-
14
- class AdataState:
15
- def __init__(self, add_adtypes=None):
16
- self.adata_dic = {"exp": {}, "activity": {}, "cnv": {}, "splicing": {}}
17
- if isinstance(add_adtypes, str):
18
- self.adata_dic[add_adtypes] = {}
19
- elif isinstance(add_adtypes, Iterable):
20
- self.adata_dic.update({adtype: {} for adtype in add_adtypes})
21
- self.active_id = None
22
- self.metadatWa = {}
23
- self.cr_kernel = {}
24
- self.cr_estimator = {}
25
-
26
- def get_adata(self, sampleid=None, adtype="exp", adinfo=None):
27
- if adinfo is not None:
28
- kwargs = adinfo.model_dump()
29
- sampleid = kwargs.get("sampleid", None)
30
- adtype = kwargs.get("adtype", "exp")
31
- try:
32
- if self.active_id is None:
33
- return None
34
- sampleid = sampleid or self.active_id
35
- return self.adata_dic[adtype][sampleid]
36
- except KeyError as e:
37
- raise KeyError(f"Key {e} not found in adata_dic[{adtype}].Please check the sampleid or adtype.")
38
- except Exception as e:
39
- raise Exception(f"fuck {e} {type(e)}")
40
-
41
- def set_adata(self, adata, sampleid=None, sdtype="exp", adinfo=None):
42
- if adinfo is not None:
43
- kwargs = adinfo.model_dump()
44
- sampleid = kwargs.get("sampleid", None)
45
- sdtype = kwargs.get("adtype", "exp")
46
- sampleid = sampleid or self.active_id
47
- if sdtype not in self.adata_dic:
48
- self.adata_dic[sdtype] = {}
49
- self.adata_dic[sdtype][sampleid] = adata
50
-
51
-
3
+ from collections.abc import Iterable, AsyncIterator
4
+ from typing import Any, Dict, List, Optional
5
+ from contextlib import asynccontextmanager
6
+ import asyncio
7
+
8
+ from .base import BaseMCP,AdataState,BaseMCPManager
9
+ from .io import ScanpyIOMCP, io_mcp
10
+ from .util import ScanpyUtilMCP
11
+ from .pl import ScanpyPlottingMCP
12
+ from .pp import ScanpyPreprocessingMCP
13
+ from .tl import ScanpyToolsMCP
@@ -0,0 +1,153 @@
1
+ import inspect
2
+ from fastmcp import FastMCP
3
+ from ..schema import AdataInfo
4
+ from ..util import filter_tools
5
+ from collections.abc import AsyncIterator
6
+ from contextlib import asynccontextmanager
7
+ import asyncio
8
+ from typing import Optional, List, Any, Iterable
9
+
10
+
11
class BaseMCP:
    """Base class for all Scanpy MCP classes.

    Subclasses declare tools as methods named ``_tool_<name>``. Each such
    method is a factory: it is called with no arguments and returns the
    function to register (or ``None`` to register nothing).
    """

    def __init__(self, name: str, include_tools: list = None, exclude_tools: list = None, AdataInfo = AdataInfo):
        """
        Create the FastMCP server and register the selected tools.

        Args:
            name (str): Name of the MCP server
            include_tools (list, optional): Tool names to keep. If None, all tools are kept.
            exclude_tools (list, optional): Tool names to drop. If None, nothing is dropped.
            AdataInfo: The AdataInfo class to use for type annotations.
        """
        self.mcp = FastMCP(name)
        self.include_tools = include_tools
        self.exclude_tools = exclude_tools
        self.AdataInfo = AdataInfo
        self._register_tools()

    def _register_tools(self):
        """Register every ``_tool_*`` factory's product that passes the include/exclude filters."""
        prefix = '_tool_'
        keep = self.include_tools
        drop = self.exclude_tools

        for attr_name, factory in inspect.getmembers(self, predicate=inspect.ismethod):
            if not attr_name.startswith(prefix):
                continue

            # The registered tool name is the method name without the prefix.
            tool_name = attr_name[len(prefix):]
            if keep is not None and tool_name not in keep:
                continue
            if drop is not None and tool_name in drop:
                continue

            tool_func = factory()
            if tool_func is not None:
                self.mcp.add_tool(tool_func, name=tool_name)
55
+
56
+
57
class AdataState:
    """In-memory store of adata objects, keyed first by data type and then by sample id.

    Attributes are plain dictionaries; ``active_id`` is the sample id used
    when a caller does not name one explicitly.
    """

    def __init__(self, add_adtypes=None):
        """
        Args:
            add_adtypes: Extra data type name (``str``) or iterable of names
                to create empty slots for, in addition to the built-in ones.
        """
        self.adata_dic = {"exp": {}, "activity": {}, "cnv": {}, "splicing": {}}
        if isinstance(add_adtypes, str):
            self.adata_dic[add_adtypes] = {}
        elif isinstance(add_adtypes, Iterable):
            self.adata_dic.update({adtype: {} for adtype in add_adtypes})
        self.active_id = None
        # NOTE(review): the attribute name looks like a typo of "metadata";
        # it is kept as-is because outside code may already read it.
        self.metadatWa = {}
        self.cr_kernel = {}
        self.cr_estimator = {}

    def get_adata(self, sampleid=None, adtype="exp", adinfo=None):
        """Return the stored adata for ``sampleid``/``adtype``.

        Args:
            sampleid: Sample id to look up; falls back to ``active_id``.
            adtype: Data type slot to look in.
            adinfo: Optional object with ``model_dump()``; when given, its
                ``sampleid`` and ``adtype`` override the two arguments above.

        Returns:
            The stored object, or ``None`` when no sample is active yet.

        Raises:
            KeyError: If the data type or sample id is not present.
            RuntimeError: If the lookup fails for any other reason.
        """
        if adinfo is not None:
            kwargs = adinfo.model_dump()
            sampleid = kwargs.get("sampleid", None)
            adtype = kwargs.get("adtype", "exp")
        try:
            # Nothing has been loaded yet, so there is nothing to return.
            if self.active_id is None:
                return None
            sampleid = sampleid or self.active_id
            return self.adata_dic[adtype][sampleid]
        except KeyError as e:
            raise KeyError(
                f"Key {e} not found in adata_dic[{adtype}].Please check the sampleid or adtype."
            ) from e
        except Exception as e:
            # Keep the original error attached so the real cause is visible.
            raise RuntimeError(
                f"Unexpected error while looking up adata: {e} ({type(e).__name__})"
            ) from e

    def set_adata(self, adata, sampleid=None, sdtype="exp", adinfo=None):
        """Store ``adata`` under ``sdtype``/``sampleid``.

        Args:
            adata: Object to store.
            sampleid: Sample id to store under; falls back to ``active_id``.
            sdtype: Data type slot; created on demand if it does not exist.
            adinfo: Optional object with ``model_dump()``; when given, its
                ``sampleid`` and ``adtype`` override the two arguments above.
        """
        if adinfo is not None:
            kwargs = adinfo.model_dump()
            sampleid = kwargs.get("sampleid", None)
            sdtype = kwargs.get("adtype", "exp")
        sampleid = sampleid or self.active_id
        if sdtype not in self.adata_dic:
            self.adata_dic[sdtype] = {}
        self.adata_dic[sdtype][sampleid] = adata
93
+
94
+
95
class BaseMCPManager:
    """Base class for MCP module management.

    Subclasses fill ``self.available_modules`` in :meth:`_init_modules`; the
    selected modules are then imported into one top-level FastMCP server.
    """

    def __init__(self,
                 name: str,
                 include_modules: Optional[List[str]] = None,
                 exclude_modules: Optional[List[str]] = None,
                 include_tools: Optional[dict[str, List[str]]] = None,
                 exclude_tools: Optional[dict[str, List[str]]] = None,
                 ):
        """
        Initialize BaseMCPManager with optional module filtering.

        Args:
            name (str): Name of the MCP server
            include_modules (List[str], optional): List of module names to include. If None, all modules are included.
            exclude_modules (List[str], optional): List of module names to exclude. If None, no modules are excluded.
            include_tools (dict[str, List[str]], optional): Per-module tool names to include,
                keyed by module name. Modules without an entry keep all their tools.
            exclude_tools (dict[str, List[str]], optional): Per-module tool names to exclude,
                keyed by module name. Modules without an entry drop nothing.
        """
        self.ads = AdataState()
        self.mcp = FastMCP(name, lifespan=self.adata_lifespan)
        self.include_modules = include_modules
        self.exclude_modules = exclude_modules
        self.include_tools = include_tools
        self.exclude_tools = exclude_tools
        self.available_modules = {}
        self._init_modules()
        self._register_modules()

    def _init_modules(self):
        """Initialize available modules. To be implemented by subclasses."""
        raise NotImplementedError("Subclasses must implement _init_modules")

    def _register_modules(self):
        """Register modules based on include/exclude filters."""
        # Filter modules based on include/exclude lists
        if self.include_modules is not None:
            self.available_modules = {
                k: v for k, v in self.available_modules.items() if k in self.include_modules
            }

        if self.exclude_modules is not None:
            self.available_modules = {
                k: v for k, v in self.available_modules.items() if k not in self.exclude_modules
            }

        # Register each module
        for module_name, mcpi in self.available_modules.items():
            if isinstance(mcpi, FastMCP):
                # Tool filters are looked up by module name, which is why
                # include_tools / exclude_tools must be mappings.
                if self.include_tools is not None and module_name in self.include_tools:
                    mcpi = filter_tools(mcpi, include_tools=self.include_tools[module_name])
                if self.exclude_tools is not None and module_name in self.exclude_tools:
                    mcpi = filter_tools(mcpi, exclude_tools=self.exclude_tools[module_name])

                asyncio.run(self.mcp.import_server(module_name, mcpi))
            else:
                # Anything that is not a FastMCP instance is called with no
                # arguments and its ``.mcp`` attribute is imported; tool
                # filters are not applied on this path.
                asyncio.run(self.mcp.import_server(module_name, mcpi().mcp))

    @asynccontextmanager
    async def adata_lifespan(self, server: FastMCP) -> AsyncIterator[Any]:
        """Context manager for AdataState lifecycle; yields the shared ``self.ads``."""
        yield self.ads
scmcp_shared/server/io.py CHANGED
@@ -2,83 +2,83 @@ import os
2
2
  import inspect
3
3
  from pathlib import Path
4
4
  import scanpy as sc
5
- from fastmcp import FastMCP , Context
5
+ from fastmcp import FastMCP, Context
6
6
  from fastmcp.exceptions import ToolError
7
- from ..schema import AdataModel
7
+ from ..schema import AdataInfo
8
8
  from ..schema.io import *
9
9
  from ..util import filter_args, forward_request, get_ads, generate_msg
10
+ from .base import BaseMCP
10
11
 
11
12
 
12
- io_mcp = FastMCP("SCMCP-IO-Server")
13
+ class ScanpyIOMCP(BaseMCP):
14
+ def __init__(self, include_tools: list = None, exclude_tools: list = None, AdataInfo = AdataInfo):
15
+ """Initialize ScanpyIOMCP with optional tool filtering."""
16
+ super().__init__("SCMCP-IO-Server", include_tools, exclude_tools, AdataInfo)
13
17
 
18
+ def _tool_read(self):
19
+ def _read(request: ReadModel, adinfo: self.AdataInfo=self.AdataInfo()):
20
+ """
21
+ Read data from 10X directory or various file formats (h5ad, 10x, text files, etc.).
22
+ """
23
+ try:
24
+ res = forward_request("io_read", request, adinfo)
25
+ if res is not None:
26
+ return res
27
+ kwargs = request.model_dump()
28
+ file = Path(kwargs.get("filename", None))
29
+ if file.is_dir():
30
+ kwargs["path"] = kwargs["filename"]
31
+ func_kwargs = filter_args(request, sc.read_10x_mtx)
32
+ adata = sc.read_10x_mtx(kwargs["path"], **func_kwargs)
33
+ elif file.is_file():
34
+ func_kwargs = filter_args(request, sc.read)
35
+ adata = sc.read(**func_kwargs)
36
+ if not kwargs.get("first_column_obs", True):
37
+ adata = adata.T
38
+ else:
39
+ raise FileNotFoundError(f"{kwargs['filename']} does not exist")
14
40
 
15
- @io_mcp.tool()
16
- async def read(
17
- request: ReadModel,
18
- adinfo: AdataModel = AdataModel()
19
- ):
20
- """
21
- Read data from 10X directory or various file formats (h5ad, 10x, text files, etc.).
22
- """
23
- try:
24
- result = await forward_request("io_read", request, adinfo)
25
- if result is not None:
26
- return result
27
- kwargs = request.model_dump()
41
+ ads = get_ads()
42
+ if adinfo.sampleid is not None:
43
+ ads.active_id = adinfo.sampleid
44
+ else:
45
+ ads.active_id = f"adata{len(ads.adata_dic[adinfo.adtype])}"
46
+
47
+ adata.layers["counts"] = adata.X
48
+ adata.var_names_make_unique()
49
+ adata.obs_names_make_unique()
50
+ ads.set_adata(adata, adinfo=adinfo)
51
+ return generate_msg(adinfo, adata, ads)
52
+ except ToolError as e:
53
+ raise ToolError(e)
54
+ except Exception as e:
55
+ if hasattr(e, '__context__') and e.__context__:
56
+ raise ToolError(e.__context__)
57
+ else:
58
+ raise ToolError(e)
59
+ return _read
28
60
 
29
- file = Path(kwargs.get("filename", None))
30
- if file.is_dir():
31
- kwargs["path"] = kwargs["filename"]
32
- func_kwargs = filter_args(request, sc.read_10x_mtx)
33
- adata = sc.read_10x_mtx(kwargs["path"], **func_kwargs)
34
- elif file.is_file():
35
- func_kwargs = filter_args(request, sc.read)
36
- adata = sc.read(**func_kwargs)
37
- if not kwargs.get("first_column_obs", True):
38
- adata = adata.T
39
- else:
40
- raise FileNotFoundError(f"{kwargs['filename']} does not exist")
61
+ def _tool_write(self):
62
+ def _write(request: WriteModel, adinfo: self.AdataInfo=self.AdataInfo()):
63
+ """save adata into a file."""
64
+ try:
65
+ res = forward_request("io_write", request, adinfo)
66
+ if res is not None:
67
+ return res
68
+ ads = get_ads()
69
+ adata = ads.get_adata(adinfo=adinfo)
70
+ kwargs = request.model_dump()
71
+ sc.write(kwargs["filename"], adata)
72
+ return {"filename": kwargs["filename"], "msg": "success to save file"}
73
+ except ToolError as e:
74
+ raise ToolError(e)
75
+ except Exception as e:
76
+ if hasattr(e, '__context__') and e.__context__:
77
+ raise ToolError(e.__context__)
78
+ else:
79
+ raise ToolError(e)
80
+ return _write
41
81
 
42
- ads = get_ads()
43
- if adinfo.sampleid is not None:
44
- ads.active_id = adinfo.sampleid
45
- else:
46
- ads.active_id = f"adata{len(ads.adata_dic[adinfo.adtype])}"
47
-
48
- adata.layers["counts"] = adata.X
49
- adata.var_names_make_unique()
50
- adata.obs_names_make_unique()
51
- ads.set_adata(adata, adinfo=adinfo)
52
- return generate_msg(adinfo, adata, ads)
53
- except ToolError as e:
54
- raise ToolError(e)
55
- except Exception as e:
56
- if hasattr(e, '__context__') and e.__context__:
57
- raise ToolError(e.__context__)
58
- else:
59
- raise ToolError(e)
60
82
 
61
-
62
- @io_mcp.tool()
63
- async def write(
64
- request: WriteModel,
65
- adinfo: AdataModel = AdataModel()
66
- ):
67
- """save adata into a file.
68
- """
69
- try:
70
- result = await forward_request("io_write", request, adinfo)
71
- if result is not None:
72
- return result
73
- ads = get_ads()
74
- adata = ads.get_adata(adinfo=adinfo)
75
- kwargs = request.model_dump()
76
- sc.write(kwargs["filename"], adata)
77
- return {"filename": kwargs["filename"], "msg": "success to save file"}
78
- except ToolError as e:
79
- raise ToolError(e)
80
- except Exception as e:
81
- if hasattr(e, '__context__') and e.__context__:
82
- raise ToolError(e.__context__)
83
- else:
84
- raise ToolError(e)
83
+ # Create an instance of the class
84
+ io_mcp = ScanpyIOMCP().mcp