sufy-mcp-server 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mcp_server/__init__.py +11 -0
- mcp_server/application.py +53 -0
- mcp_server/config/__init__.py +0 -0
- mcp_server/config/config.py +52 -0
- mcp_server/consts/__init__.py +0 -0
- mcp_server/consts/consts.py +1 -0
- mcp_server/core/__init__.py +17 -0
- mcp_server/core/media_processing/__init__.py +11 -0
- mcp_server/core/media_processing/tools.py +179 -0
- mcp_server/core/media_processing/utils.py +157 -0
- mcp_server/core/storage/__init__.py +13 -0
- mcp_server/core/storage/resource.py +158 -0
- mcp_server/core/storage/storage.py +203 -0
- mcp_server/core/storage/tools.py +154 -0
- mcp_server/core/version/__init__.py +9 -0
- mcp_server/core/version/tools.py +31 -0
- mcp_server/core/version/version.py +2 -0
- mcp_server/resource/__init__.py +0 -0
- mcp_server/resource/resource.py +61 -0
- mcp_server/server.py +72 -0
- mcp_server/tools/__init__.py +0 -0
- mcp_server/tools/tools.py +138 -0
- sufy_mcp_server-1.0.0.dist-info/METADATA +13 -0
- sufy_mcp_server-1.0.0.dist-info/RECORD +27 -0
- sufy_mcp_server-1.0.0.dist-info/WHEEL +4 -0
- sufy_mcp_server-1.0.0.dist-info/entry_points.txt +2 -0
- sufy_mcp_server-1.0.0.dist-info/licenses/LICENSE.txt +21 -0
@@ -0,0 +1,203 @@
|
|
1
|
+
import aioboto3
|
2
|
+
import asyncio
|
3
|
+
import logging
|
4
|
+
|
5
|
+
from typing import List, Dict, Any, Optional
|
6
|
+
from botocore.config import Config as S3Config
|
7
|
+
|
8
|
+
from ...config import config
|
9
|
+
from ...consts import consts
|
10
|
+
|
11
|
+
# Package-wide logger; the logger name is taken from consts.LOGGER_NAME.
logger = logging.getLogger(consts.LOGGER_NAME)
|
12
|
+
|
13
|
+
|
14
|
+
class StorageService:
    """Async access to an S3-compatible object store via aioboto3.

    Every operation opens a short-lived client built from the credentials,
    endpoint and region held in ``cfg``. When ``cfg.buckets`` is non-empty it
    acts as an allow-list: operations on other buckets return an empty result.
    """

    # Maximum number of buckets returned by ``list_buckets``.
    _MAX_BUCKETS = 50
    # Upper bound applied to ``max_keys`` in ``list_objects``.
    _MAX_KEYS_LIMIT = 100

    _TEXT_EXTENSIONS = (
        ".txt",
        ".log",
        ".json",
        ".xml",
        ".yml",
        ".yaml",
        ".md",
        ".csv",
        ".ini",
        ".conf",
        ".py",
        ".js",
        ".html",
        ".css",
        ".sh",
        ".bash",
        ".cfg",
        ".properties",
    )
    _IMAGE_EXTENSIONS = (
        ".png",
        ".jpeg",
        ".jpg",
        ".gif",
        ".bmp",
        ".tiff",
        ".svg",
        ".webp",
    )
    _MARKDOWN_EXTENSIONS = (".md",)

    def __init__(self, cfg: Optional["config.Config"] = None):
        # Retry/timeout policy shared by every client this service opens.
        self.s3_config = S3Config(
            retries=dict(max_attempts=3, mode="adaptive"),
            connect_timeout=5,
            read_timeout=60,
            max_pool_connections=50,
        )
        self.config = cfg
        self.s3_session = aioboto3.Session()

    def _client(self):
        """Return an async context manager yielding a configured S3 client."""
        return self.s3_session.client(
            "s3",
            aws_access_key_id=self.config.access_key,
            aws_secret_access_key=self.config.secret_key,
            endpoint_url=self.config.endpoint_url,
            region_name=self.config.region_name,
            config=self.s3_config,
        )

    def _is_bucket_allowed(self, bucket: str) -> bool:
        """Return True if *bucket* passes the configured allow-list.

        An empty or missing ``config.buckets`` allows every bucket.
        """
        allowed = self.config.buckets
        return not allowed or bucket in allowed

    @staticmethod
    def _filter_buckets(
        all_buckets: List[dict],
        allowed: List[str],
        prefix: Optional[str],
        limit: int,
    ) -> List[dict]:
        """Keep buckets whose name is in *allowed* and starts with *prefix*.

        Args:
            all_buckets: Bucket dicts, each with a ``"Name"`` key.
            allowed: Bucket names permitted by configuration.
            prefix: Optional name prefix; falsy means "no prefix filter".
            limit: Maximum number of buckets to return.

        Returns:
            The matching buckets in their original order, truncated to *limit*.
        """
        allowed_names = set(allowed)
        selected = [
            bucket
            for bucket in all_buckets
            if bucket["Name"] in allowed_names
            and (not prefix or bucket["Name"].startswith(prefix))
        ]
        return selected[:limit]

    async def get_object_url(
        self, bucket: str, key: str, disable_ssl: bool = False, expires: int = 3600
    ) -> list[dict[str, Any]]:
        """Create a presigned GET URL for ``bucket/key``.

        Args:
            bucket: Bucket name.
            key: Object key.
            disable_ssl: If True, downgrade the URL scheme from https to http.
            expires: URL validity in seconds.

        Returns:
            A one-element list ``[{"object_url": url}]``.
        """
        async with self._client() as s3:
            object_url = await s3.generate_presigned_url(
                "get_object",
                Params={"Bucket": bucket, "Key": key},
                ExpiresIn=expires,
            )

        # Only rewrite the scheme; leave the rest of the signed URL untouched.
        if disable_ssl and object_url.startswith("https://"):
            object_url = "http://" + object_url[len("https://"):]

        return [{"object_url": object_url}]

    async def list_buckets(self, prefix: Optional[str] = None) -> List[dict]:
        """List configured buckets, optionally restricted to a name prefix.

        Returns an empty list when no buckets are configured. At most
        ``_MAX_BUCKETS`` entries are returned.
        """
        if not self.config.buckets:
            return []

        async with self._client() as s3:
            response = await s3.list_buckets()

        return self._filter_buckets(
            response.get("Buckets", []),
            self.config.buckets,
            prefix,
            self._MAX_BUCKETS,
        )

    async def list_objects(
        self, bucket: str, prefix: str = "", max_keys: int = 20, start_after: str = ""
    ) -> List[dict]:
        """List up to *max_keys* objects (capped at 100) from *bucket*.

        Returns the ``Contents`` entries of a ListObjectsV2 response, or an
        empty list if the bucket is not in the configured allow-list.
        """
        if not self._is_bucket_allowed(bucket):
            logger.warning("Bucket %s not in configured bucket list", bucket)
            return []

        # Tool arguments may arrive as strings; clamp to [0, _MAX_KEYS_LIMIT].
        max_keys = max(0, min(int(max_keys), self._MAX_KEYS_LIMIT))

        async with self._client() as s3:
            response = await s3.list_objects_v2(
                Bucket=bucket,
                Prefix=prefix,
                MaxKeys=max_keys,
                StartAfter=start_after,
            )
        return response.get("Contents", [])

    async def get_object(
        self, bucket: str, key: str, max_retries: int = 3
    ) -> Dict[str, Any]:
        """Download an object fully into memory.

        The returned response dict has ``"Body"`` replaced by the complete
        object bytes. Returns ``{}`` if the bucket is not allowed.

        Raises:
            Exception: the last error after *max_retries* failed attempts, or
                immediately when the error message mentions ``NoSuchKey``.
        """
        if not self._is_bucket_allowed(bucket):
            logger.warning("Bucket %s not in configured bucket list", bucket)
            return {}

        last_exception: Optional[Exception] = None
        for attempt in range(1, max_retries + 1):
            try:
                async with self._client() as s3:
                    response = await s3.get_object(Bucket=bucket, Key=key)
                    # Drain the streaming body while the client is still open.
                    chunks = [chunk async for chunk in response["Body"]]
                    response["Body"] = b"".join(chunks)
                    return response
            except Exception as e:
                last_exception = e
                # A missing key will not appear on retry; fail fast.
                if "NoSuchKey" in str(e):
                    raise
                if attempt < max_retries:
                    wait_time = 2 ** attempt  # exponential backoff: 2s, 4s, ...
                    logger.warning(
                        "Attempt %d failed, retrying in %d seconds: %s",
                        attempt,
                        wait_time,
                        e,
                    )
                    await asyncio.sleep(wait_time)

        raise last_exception or Exception("Failed to get object after all retries")

    def is_text_file(self, key: str) -> bool:
        """Return True if *key* ends with a known text extension (case-insensitive)."""
        return key.lower().endswith(self._TEXT_EXTENSIONS)

    def is_image_file(self, key: str) -> bool:
        """Return True if *key* ends with a known image extension (case-insensitive)."""
        return key.lower().endswith(self._IMAGE_EXTENSIONS)

    def is_markdown_file(self, key: str) -> bool:
        """Return True if *key* ends with a Markdown extension (case-insensitive)."""
        return key.lower().endswith(self._MARKDOWN_EXTENSIONS)
|
@@ -0,0 +1,154 @@
|
|
1
|
+
import logging
|
2
|
+
import base64
|
3
|
+
|
4
|
+
from mcp import types
|
5
|
+
from mcp.types import ImageContent, TextContent
|
6
|
+
|
7
|
+
from .storage import StorageService
|
8
|
+
from ...consts import consts
|
9
|
+
from ...tools import tools
|
10
|
+
|
11
|
+
# Package-wide logger; the logger name is taken from consts.LOGGER_NAME.
logger = logging.getLogger(consts.LOGGER_NAME)
|
12
|
+
|
13
|
+
# JSON-schema description shared by the ``bucket`` argument of the storage tools.
# The value is passed verbatim as the S3 ``Bucket`` parameter and checked
# against the configured bucket list, so it must be a plain bucket name rather
# than a virtual-hosted endpoint hostname.
_BUCKET_DESC = """The name of the bucket (not a hostname or URL). Use ListBuckets to see the available buckets; only buckets in the server's configured bucket list can be accessed.
"""
|
15
|
+
|
16
|
+
class _ToolImpl:
    """MCP tool handlers for bucket and object operations.

    Each handler forwards its schema-validated keyword arguments to the wrapped
    ``StorageService`` and converts the result into MCP content items.
    """

    def __init__(self, storage: StorageService):
        self.storage = storage

    @staticmethod
    def _json_text(value) -> list[types.TextContent]:
        """Wrap *value* as a single JSON-encoded text content item.

        Storage responses may contain values ``json`` cannot encode natively
        (e.g. datetimes); those are rendered with ``str()`` instead of raising.
        """
        import json

        return [
            types.TextContent(
                type="text",
                text=json.dumps(value, default=str, ensure_ascii=False),
            )
        ]

    @tools.tool_meta(
        types.Tool(
            name="ListBuckets",
            description="Returns a list of all buckets of config. To grant IAM permission to use this operation, you must add the s3:ListAllMyBuckets policy action.",
            inputSchema={
                "type": "object",
                "properties": {
                    "prefix": {
                        "type": "string",
                        "description": "Bucket prefix. The listed Buckets will be filtered based on this prefix, and only those matching the prefix will be output.",
                    },
                },
                "required": [],
            },
        )
    )
    async def list_buckets(self, **kwargs) -> list[types.TextContent]:
        """Return the configured buckets as JSON text."""
        buckets = await self.storage.list_buckets(**kwargs)
        return self._json_text(buckets)

    @tools.tool_meta(
        types.Tool(
            name="ListObjects",
            description="Each request will return some or all (up to 100) objects in the bucket. You can use request parameters as selection criteria to return some objects in the bucket. If you want to continue listing, set start_after to the key of the last file in the last listing result so that you can list new content. To get a list of buckets, see ListBuckets.",
            inputSchema={
                "type": "object",
                "properties": {
                    "bucket": {
                        "type": "string",
                        "description": _BUCKET_DESC,
                    },
                    "max_keys": {
                        "type": "integer",
                        "description": "Sets the maximum number of keys returned in the response. By default, the action returns up to 20 key names. The response might contain fewer keys but will never contain more.",
                    },
                    "prefix": {
                        "type": "string",
                        "description": "Limits the response to keys that begin with the specified prefix.",
                    },
                    "start_after": {
                        "type": "string",
                        "description": "start_after is where you want S3 to start listing from. S3 starts listing after this specified key. start_after can be any key in the bucket.",
                    },
                },
                "required": ["bucket"],
            },
        )
    )
    async def list_objects(self, **kwargs) -> list[types.TextContent]:
        """Return the listed objects as JSON text."""
        objects = await self.storage.list_objects(**kwargs)
        return self._json_text(objects)

    @tools.tool_meta(
        types.Tool(
            name="GetObject",
            description="Retrieves an object from bucket. In the GetObject request, specify the full key name for the object. Path-style requests are not supported.",
            inputSchema={
                "type": "object",
                "properties": {
                    "bucket": {
                        "type": "string",
                        "description": _BUCKET_DESC,
                    },
                    "key": {
                        "type": "string",
                        "description": "Key of the object to get. Length Constraints: Minimum length of 1.",
                    },
                },
                "required": ["bucket", "key"],
            },
        )
    )
    async def get_object(self, **kwargs) -> list[ImageContent] | list[TextContent]:
        """Return an object's content: base64 image for ``image/*``, else text.

        Raises:
            ValueError: if the bucket is not in the configured bucket list.
        """
        response = await self.storage.get_object(**kwargs)
        if not response:
            # StorageService.get_object returns {} for buckets outside the
            # configured allow-list; without this check "Body" raises KeyError.
            raise ValueError(
                f"Bucket {kwargs.get('bucket')} not in configured bucket list"
            )

        file_content = response["Body"]
        content_type = response.get("ContentType") or "application/octet-stream"

        # Choose the response type based on the content type.
        if content_type.startswith("image/"):
            base64_data = base64.b64encode(file_content).decode("utf-8")
            return [
                types.ImageContent(
                    type="image", data=base64_data, mimeType=content_type
                )
            ]

        if isinstance(file_content, bytes):
            try:
                text_content = file_content.decode("utf-8")
            except UnicodeDecodeError:
                # Non-UTF-8 (binary) payloads cannot be shown as text.
                text_content = (
                    f"Object {kwargs.get('key')} is binary ({content_type}, "
                    f"{len(file_content)} bytes) and cannot be displayed as text; "
                    "use GetObjectURL to download it."
                )
        else:
            text_content = str(file_content)
        return [types.TextContent(type="text", text=text_content)]

    @tools.tool_meta(
        types.Tool(
            name="GetObjectURL",
            description="Get the object download URL",
            inputSchema={
                "type": "object",
                "properties": {
                    "bucket": {
                        "type": "string",
                        "description": _BUCKET_DESC,
                    },
                    "key": {
                        "type": "string",
                        "description": "Key of the object to get. Length Constraints: Minimum length of 1.",
                    },
                    "disable_ssl": {
                        "type": "boolean",
                        "description": "Whether to disable SSL. By default, it is not disabled (HTTPS protocol is used). If disabled, the HTTP protocol will be used.",
                    },
                    "expires": {
                        "type": "integer",
                        "description": "Token expiration time (in seconds) for download links. When the bucket is private, a signed Token is required to access file objects. Public buckets do not require Token signing.",
                    },
                },
                "required": ["bucket", "key"],
            },
        )
    )
    async def get_object_url(self, **kwargs) -> list[types.TextContent]:
        """Return the presigned download URL(s) as JSON text."""
        urls = await self.storage.get_object_url(**kwargs)
        return self._json_text(urls)
|
143
|
+
|
144
|
+
|
145
|
+
def register_tools(storage: StorageService):
    """Build the storage tool handlers around *storage* and add them to the tool registry."""
    impl = _ToolImpl(storage)
    handlers = [
        impl.list_buckets,
        impl.list_objects,
        impl.get_object,
        impl.get_object_url,
    ]
    tools.auto_register_tools(handlers)
|
@@ -0,0 +1,31 @@
|
|
1
|
+
|
2
|
+
from mcp import types
|
3
|
+
|
4
|
+
from . import version
|
5
|
+
from ...tools import tools
|
6
|
+
|
7
|
+
|
8
|
+
class _ToolImpl:
    """Holds the MCP tool that reports the server version."""

    def __init__(self):
        """Stateless; there is nothing to initialise."""

    @tools.tool_meta(
        types.Tool(
            name="Version",
            description="Sufy MCP Server version info.",
            inputSchema={"type": "object", "required": []},
        )
    )
    def version(self, **kwargs) -> list[types.TextContent]:
        """Return the package version string as a single text content item."""
        content = types.TextContent(type="text", text=version.VERSION)
        return [content]
|
24
|
+
|
25
|
+
def register_tools():
    """Add the version tool to the global tool registry."""
    tools.auto_register_tools([_ToolImpl().version])
|
File without changes
|
@@ -0,0 +1,61 @@
|
|
1
|
+
import logging
|
2
|
+
from abc import abstractmethod
|
3
|
+
from typing import Dict, AsyncGenerator, Iterable
|
4
|
+
|
5
|
+
from mcp import types
|
6
|
+
from mcp.server.lowlevel import helper_types as low_types
|
7
|
+
|
8
|
+
from ..consts import consts
|
9
|
+
|
10
|
+
# Package-wide logger; the logger name is taken from consts.LOGGER_NAME.
logger = logging.getLogger(consts.LOGGER_NAME)
|
11
|
+
|
12
|
+
|
13
|
+
class ResourceProvider:
    """Base class for MCP resource providers, registered by URI scheme.

    Subclasses implement ``list_resources`` and ``read_resource``.
    NOTE(review): the class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` markers are not enforced at instantiation time;
    confirm whether that is intended.
    """

    def __init__(self, scheme: str):
        # URI scheme this provider serves; used as its registry key.
        self.scheme = scheme

    @abstractmethod
    async def list_resources(self, **kwargs) -> list[types.Resource]:
        """Return the resources this provider exposes."""

    @abstractmethod
    async def read_resource(
        self, uri: types.AnyUrl, **kwargs
    ) -> str | bytes | Iterable[low_types.ReadResourceContents]:
        """Return the contents of the resource identified by *uri*."""
|
24
|
+
|
25
|
+
|
26
|
+
# Registry of resource providers keyed by URI scheme; populated by
# register_resource_provider and consulted by list_resources/read_resource.
_all_resource_providers: Dict[str, ResourceProvider] = {}
|
27
|
+
|
28
|
+
|
29
|
+
async def list_resources(**kwargs) -> AsyncGenerator[types.Resource, None]:
    """Yield every resource from every registered provider.

    Providers are queried in registration order, and ``kwargs`` is forwarded
    unchanged to each provider's ``list_resources``. Nothing is yielded when
    no provider is registered.
    """
    for resource_provider in _all_resource_providers.values():
        provided = await resource_provider.list_resources(**kwargs)
        for item in provided:
            yield item
|
38
|
+
|
39
|
+
|
40
|
+
async def read_resource(
    uri: types.AnyUrl, **kwargs
) -> str | bytes | Iterable[low_types.ReadResourceContents]:
    """Dispatch a read to the provider registered for ``uri.scheme``.

    Returns ``""`` when no providers are registered at all.

    Raises:
        ValueError: if providers exist but none serves the URI's scheme
            (previously this surfaced as an AttributeError on ``None``).
    """
    if not _all_resource_providers:
        return ""

    provider = _all_resource_providers.get(uri.scheme)
    if provider is None:
        raise ValueError(f"No resource provider registered for scheme '{uri.scheme}'")
    return await provider.read_resource(uri=uri, **kwargs)
|
46
|
+
|
47
|
+
|
48
|
+
def register_resource_provider(provider: ResourceProvider):
    """Register *provider* under its URI scheme; a scheme may only be registered once."""
    scheme = provider.scheme
    if scheme not in _all_resource_providers:
        _all_resource_providers[scheme] = provider
        return
    raise ValueError(f"Resource Provider {scheme} already registered")
|
54
|
+
|
55
|
+
|
56
|
+
# Public API of this module.
__all__ = ["ResourceProvider", "list_resources", "read_resource", "register_resource_provider"]
|
mcp_server/server.py
ADDED
@@ -0,0 +1,72 @@
|
|
1
|
+
import asyncio
|
2
|
+
import logging
|
3
|
+
|
4
|
+
import anyio
|
5
|
+
import click
|
6
|
+
|
7
|
+
from . import application
|
8
|
+
from .consts import consts
|
9
|
+
|
10
|
+
# Package-wide logger; the logger name is taken from consts.LOGGER_NAME.
logger = logging.getLogger(consts.LOGGER_NAME)
|
11
|
+
# Module-level statement: emitted when this module is imported, before main() runs.
logger.info("Starting MCP server")
|
12
|
+
|
13
|
+
# Static sample texts keyed by name.
# NOTE(review): not referenced elsewhere in this module; confirm whether it is still used.
SAMPLE_RESOURCES = dict(
    greeting="Hello! This is a MCP Server for Sufy.",
    help="This server provides a few resources and tools for Sufy.",
    about="This is the MCP server implementation.",
)
|
18
|
+
|
19
|
+
|
20
|
+
@click.command()
@click.option("--port", default=8000, help="Port to listen on for SSE")
@click.option(
    "--host",
    default="0.0.0.0",
    help="Interface to bind for SSE (default: all interfaces)",
)
@click.option(
    "--transport",
    type=click.Choice(["stdio", "sse"]),
    default="stdio",
    help="Transport type",
)
def main(port: int, transport: str, host: str = "0.0.0.0") -> int:
    """Run the Sufy MCP server over stdio (default) or SSE.

    Args:
        port: TCP port for the SSE transport (ignored for stdio).
        transport: ``"stdio"`` or ``"sse"``.
        host: Interface to bind for SSE; defaults to ``0.0.0.0`` as before.

    Returns:
        0 once the selected transport has stopped serving.
    """
    app = application.server

    if transport == "sse":
        # Web-stack imports happen only on the SSE path.
        import uvicorn
        from mcp.server.sse import SseServerTransport
        from starlette.applications import Starlette
        from starlette.responses import Response
        from starlette.routing import Mount, Route

        sse = SseServerTransport("/messages/")

        async def handle_sse(request):
            async with sse.connect_sse(
                request.scope, request.receive, request._send
            ) as streams:
                await app.run(
                    streams[0], streams[1], app.create_initialization_options()
                )
            # Starlette route endpoints must return a response object; returning
            # None fails with a TypeError once the SSE stream closes.
            return Response()

        starlette_app = Starlette(
            # debug=True would send tracebacks to clients of this network-facing server.
            debug=False,
            routes=[
                Route("/sse", endpoint=handle_sse),
                Mount("/messages/", app=sse.handle_post_message),
            ],
        )

        uvicorn.run(starlette_app, host=host, port=port)
    else:
        from mcp.server.stdio import stdio_server

        async def arun():
            async with stdio_server() as streams:
                await app.run(
                    streams[0], streams[1], app.create_initialization_options()
                )

        anyio.run(arun)

    return 0
|
69
|
+
|
70
|
+
|
71
|
+
if __name__ == "__main__":
    # ``main`` is a synchronous click command: calling it parses sys.argv, runs
    # the chosen transport and exits. It is not a coroutine, so it must not be
    # wrapped in asyncio.run().
    main()
|
File without changes
|
@@ -0,0 +1,138 @@
|
|
1
|
+
import functools
|
2
|
+
import inspect
|
3
|
+
import asyncio
|
4
|
+
import logging
|
5
|
+
import fastjsonschema
|
6
|
+
|
7
|
+
from typing import List, Dict, Callable, Optional, Union, Awaitable
|
8
|
+
from dataclasses import dataclass
|
9
|
+
from mcp import types
|
10
|
+
from .. import consts
|
11
|
+
|
12
|
+
# Module logger.
# NOTE(review): `consts` here comes from `from .. import consts` (the package,
# whose __init__.py is empty), while sibling modules use
# `from ..consts import consts` (the module that defines LOGGER_NAME).
# Confirm that LOGGER_NAME resolves here; otherwise this line fails at import.
logger = logging.getLogger(consts.LOGGER_NAME)
|
13
|
+
|
14
|
+
# A single content item a tool may return to the MCP client.
_ToolContent = types.TextContent | types.ImageContent | types.EmbeddedResource
# What every tool returns: a list of content items.
ToolResult = list[_ToolContent]
# Signature of a synchronous tool handler.
ToolFunc = Callable[..., ToolResult]
# Signature of an asynchronous tool handler.
AsyncToolFunc = Callable[..., Awaitable[ToolResult]]
|
17
|
+
|
18
|
+
|
19
|
+
@dataclass()
class _ToolEntry:
    """Registry record for one tool: metadata, implementation and argument validator.

    Exactly one of ``func`` / ``async_func`` is set by ``register_tool``.
    """

    # MCP tool description (name, description, inputSchema).
    meta: types.Tool
    # Synchronous implementation, or None for coroutine tools.
    func: Optional[ToolFunc]
    # Coroutine implementation, or None for synchronous tools.
    async_func: Optional[AsyncToolFunc]
    # Compiled validator for the tool's inputSchema.
    input_validator: Optional[Callable[..., None]]
|
25
|
+
|
26
|
+
|
27
|
+
# Global tool registry keyed by tool name (insertion order = registration order).
_all_tools: Dict[str, _ToolEntry] = dict()
|
29
|
+
|
30
|
+
|
31
|
+
def all_tools() -> List[types.Tool]:
    """Return the metadata of every registered tool, in registration order.

    Raises:
        ValueError: if no tool has been registered yet.
    """
    if _all_tools:
        return [entry.meta for entry in _all_tools.values()]
    raise ValueError("No tools registered")
|
36
|
+
|
37
|
+
|
38
|
+
def register_tool(
    meta: types.Tool,
    func: Union[ToolFunc, AsyncToolFunc],
) -> None:
    """Add a tool to the global registry; duplicate tool names are rejected.

    Coroutine functions are stored as ``async_func`` and plain callables as
    ``func``. The tool's ``inputSchema`` is compiled with fastjsonschema so
    call arguments can be validated later.

    Raises:
        ValueError: if a tool with the same name is already registered.
    """
    tool_name = meta.name
    if tool_name in _all_tools:
        raise ValueError(f"Tool {tool_name} already registered")

    is_async = inspect.iscoroutinefunction(func)
    _all_tools[tool_name] = _ToolEntry(
        meta=meta,
        func=None if is_async else func,
        async_func=func if is_async else None,
        input_validator=fastjsonschema.compile(meta.inputSchema),
    )
|
60
|
+
|
61
|
+
|
62
|
+
def tool_meta(meta: types.Tool):
    """Decorator factory that tags a tool handler with its MCP metadata.

    The decorated callable is replaced by a ``functools.wraps`` pass-through
    wrapper (async for coroutine functions, sync otherwise) that carries a
    ``tool_meta`` attribute, which ``auto_register_tools`` looks up.
    """

    def decorator(func):
        if not inspect.iscoroutinefunction(func):

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)

        else:

            @functools.wraps(func)
            async def wrapper(*args, **kwargs):
                return await func(*args, **kwargs)

        wrapper.tool_meta = meta
        return wrapper

    return decorator
|
86
|
+
|
87
|
+
|
88
|
+
def auto_register_tools(func_list: list[Union[ToolFunc, AsyncToolFunc]]):
    """Register each callable in *func_list* using its ``tool_meta`` attribute.

    Raises:
        ValueError: if a callable was not decorated with ``tool_meta``.
            Callables earlier in the list remain registered.
    """
    for candidate in func_list:
        if not hasattr(candidate, "tool_meta"):
            raise ValueError("func must have tool_meta attribute")
        register_tool(meta=candidate.tool_meta, func=candidate)
|
96
|
+
|
97
|
+
|
98
|
+
async def call_tool(name: str, arguments: dict) -> ToolResult:
    """Validate *arguments* against the tool's schema and execute the tool.

    Arguments whose value is ``None`` are dropped before validation, and a
    ``None`` arguments mapping is treated as empty. Synchronous tools run in
    the event loop's default thread pool so they do not block the loop.

    Raises:
        ValueError: if the tool is unknown or the arguments fail validation.
        RuntimeError: if the tool itself raises (the original error is chained).
    """
    # Tool existence check.
    if (tool_entry := _all_tools.get(name)) is None:
        raise ValueError(f"Tool {name} not found")

    # Input argument validation.
    arguments = {k: v for k, v in (arguments or {}).items() if v is not None}
    try:
        tool_entry.input_validator(arguments)
    except fastjsonschema.JsonSchemaException as e:
        raise ValueError(f"Invalid arguments for tool {name}: {e}") from e

    try:
        if tool_entry.async_func is not None:
            return await tool_entry.async_func(**arguments)
        if tool_entry.func is not None:
            # We are inside a coroutine, so a loop is guaranteed to be running;
            # get_event_loop() is deprecated for this use.
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(
                None,  # default thread pool
                functools.partial(tool_entry.func, **arguments),
            )
        raise ValueError(f"Unexpected tool entry: {tool_entry}")
    except Exception as e:
        raise RuntimeError(f"Tool {name} execution error: {str(e)}") from e
|
129
|
+
|
130
|
+
|
131
|
+
# Explicitly exported interface of this module.
__all__ = [
    "all_tools", "register_tool", "call_tool", "tool_meta", "auto_register_tools",
]
|
@@ -0,0 +1,13 @@
|
|
1
|
+
Metadata-Version: 2.4
|
2
|
+
Name: sufy-mcp-server
|
3
|
+
Version: 1.0.0
|
4
|
+
Summary: An MCP server project for Sufy.
|
5
|
+
License-File: LICENSE.txt
|
6
|
+
Requires-Python: >=3.12
|
7
|
+
Requires-Dist: aioboto3>=13.2.0
|
8
|
+
Requires-Dist: fastjsonschema>=2.21.1
|
9
|
+
Requires-Dist: httpx>=0.28.1
|
10
|
+
Requires-Dist: mcp[cli]>=1.0.0
|
11
|
+
Requires-Dist: openai>=1.66.3
|
12
|
+
Requires-Dist: pip>=25.0.1
|
13
|
+
Requires-Dist: python-dotenv>=1.0.1
|