sufy-mcp-server 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mcp_server/__init__.py +11 -0
- mcp_server/application.py +53 -0
- mcp_server/config/__init__.py +0 -0
- mcp_server/config/config.py +52 -0
- mcp_server/consts/__init__.py +0 -0
- mcp_server/consts/consts.py +1 -0
- mcp_server/core/__init__.py +17 -0
- mcp_server/core/media_processing/__init__.py +11 -0
- mcp_server/core/media_processing/tools.py +179 -0
- mcp_server/core/media_processing/utils.py +157 -0
- mcp_server/core/storage/__init__.py +13 -0
- mcp_server/core/storage/resource.py +158 -0
- mcp_server/core/storage/storage.py +203 -0
- mcp_server/core/storage/tools.py +154 -0
- mcp_server/core/version/__init__.py +9 -0
- mcp_server/core/version/tools.py +31 -0
- mcp_server/core/version/version.py +2 -0
- mcp_server/resource/__init__.py +0 -0
- mcp_server/resource/resource.py +61 -0
- mcp_server/server.py +72 -0
- mcp_server/tools/__init__.py +0 -0
- mcp_server/tools/tools.py +138 -0
- sufy_mcp_server-1.0.0.dist-info/METADATA +13 -0
- sufy_mcp_server-1.0.0.dist-info/RECORD +27 -0
- sufy_mcp_server-1.0.0.dist-info/WHEEL +4 -0
- sufy_mcp_server-1.0.0.dist-info/entry_points.txt +2 -0
- sufy_mcp_server-1.0.0.dist-info/licenses/LICENSE.txt +21 -0
mcp_server/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
1
|
+
"""Package initialization: configure logging and expose the server entry point."""
import logging

from .consts import consts
from .server import main

# Configure logging
logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(consts.LOGGER_NAME)
logger.info("Initializing MCP server package")

# Public API of the package: the server entry point.
__all__ = ["main"]
|
@@ -0,0 +1,53 @@
|
|
1
|
+
import logging
|
2
|
+
from contextlib import aclosing
|
3
|
+
from typing import Iterable
|
4
|
+
|
5
|
+
import mcp.types as types
|
6
|
+
from mcp.types import EmptyResult, AnyUrl, Tool
|
7
|
+
|
8
|
+
from mcp import LoggingLevel
|
9
|
+
from mcp.server.lowlevel import Server, helper_types as low_types
|
10
|
+
|
11
|
+
from . import core
|
12
|
+
from .consts import consts
|
13
|
+
from .resource import resource
|
14
|
+
from .tools import tools
|
15
|
+
|
16
|
+
|
17
|
+
logger = logging.getLogger(consts.LOGGER_NAME)

# Load configuration and register every tool/resource provider before the
# low-level server instance is created.
core.load()
server = Server("sufy-mcp-server")
|
21
|
+
|
22
|
+
|
23
|
+
@server.set_logging_level()
async def set_logging_level(level: LoggingLevel) -> EmptyResult:
    """Change the server logger's level and notify the connected client.

    Fix: ``Logger.setLevel`` accepts an int or a *registered* level name,
    and registered names are upper-case ("DEBUG", "INFO", ...). The
    original passed ``level.lower()``, which raises
    ``ValueError: Unknown level``; normalize upward instead.
    """
    logger.setLevel(level.upper())
    await server.request_context.session.send_log_message(
        level="warning", data=f"Log level set to {level}", logger=consts.LOGGER_NAME
    )
    return EmptyResult()
|
30
|
+
|
31
|
+
|
32
|
+
@server.list_resources()
async def list_resources(**kwargs) -> list[types.Resource]:
    """Drain the resource providers' async generator into a plain list."""
    # aclosing guarantees the generator is closed even if iteration aborts.
    async with aclosing(resource.list_resources(**kwargs)) as results:
        return [item async for item in results]
|
39
|
+
|
40
|
+
|
41
|
+
@server.read_resource()
async def read_resource(uri: AnyUrl) -> str | bytes | Iterable[low_types.ReadResourceContents]:
    """Delegate a resource read to the provider registered for the URI scheme."""
    return await resource.read_resource(uri)
|
44
|
+
|
45
|
+
|
46
|
+
@server.list_tools()
async def handle_list_tools() -> list[Tool]:
    """Return the metadata of every registered tool."""
    return tools.all_tools()
|
49
|
+
|
50
|
+
|
51
|
+
@server.call_tool()
async def call_tool(name: str, arguments: dict):
    """Dispatch a tool invocation to the matching registered handler."""
    return await tools.call_tool(name, arguments)
|
File without changes
|
@@ -0,0 +1,52 @@
|
|
1
|
+
import logging
|
2
|
+
import os
|
3
|
+
from typing import List
|
4
|
+
from attr import dataclass
|
5
|
+
from dotenv import load_dotenv
|
6
|
+
|
7
|
+
from ..consts import consts
|
8
|
+
|
9
|
+
# Environment-variable names recognized by the server configuration.
_CONFIG_ENV_KEY_ACCESS_KEY = "SUFY_ACCESS_KEY"
_CONFIG_ENV_KEY_SECRET_KEY = "SUFY_SECRET_KEY"
_CONFIG_ENV_KEY_ENDPOINT_URL = "SUFY_ENDPOINT_URL"
_CONFIG_ENV_KEY_REGION_NAME = "SUFY_REGION_NAME"
_CONFIG_ENV_KEY_BUCKETS = "SUFY_BUCKETS"

logger = logging.getLogger(consts.LOGGER_NAME)

# Load environment variables at package initialization
load_dotenv(override=True)


# NOTE(review): `dataclass` here comes from the third-party `attr` package;
# consider the stdlib `dataclasses.dataclass` — confirm no attrs-specific
# behavior is relied on elsewhere.
@dataclass
class Config:
    """Runtime configuration for the Sufy MCP server."""

    # Credential pair used to sign S3-compatible requests.
    access_key: str
    secret_key: str
    # Service endpoint URL of the S3-compatible API.
    endpoint_url: str
    region_name: str
    # Optional whitelist of bucket names the server may expose.
    buckets: List[str]
|
28
|
+
|
29
|
+
|
30
|
+
def load_config() -> Config:
    """Assemble a :class:`Config` from the ``SUFY_*`` environment variables.

    Returns:
        Config populated from the environment; unset variables default to
        the empty string.

    Fix: the original passed the environment-variable *name* (e.g.
    "SUFY_ACCESS_KEY") as the ``os.getenv`` default, so a missing variable
    silently produced a bogus credential/endpoint value instead of an
    empty one.
    """
    config = Config(
        access_key=os.getenv(_CONFIG_ENV_KEY_ACCESS_KEY, ""),
        secret_key=os.getenv(_CONFIG_ENV_KEY_SECRET_KEY, ""),
        endpoint_url=os.getenv(_CONFIG_ENV_KEY_ENDPOINT_URL, ""),
        region_name=os.getenv(_CONFIG_ENV_KEY_REGION_NAME, ""),
        buckets=_get_configured_buckets_from_env(),
    )

    # The secret key is deliberately never logged.
    logger.info(f"Configured access_key: {config.access_key}")
    logger.info(f"Configured endpoint_url: {config.endpoint_url}")
    logger.info(f"Configured region_name: {config.region_name}")
    logger.info(f"Configured buckets: {config.buckets}")
    return config
|
44
|
+
|
45
|
+
|
46
|
+
def _get_configured_buckets_from_env() -> List[str]:
    """Read the comma-separated bucket whitelist from the environment.

    Returns:
        The configured bucket names, whitespace-stripped; an unset variable
        yields an empty list.

    Improvement: empty segments produced by stray commas (e.g. "a,,b" or a
    trailing "a,b,") are now discarded instead of appearing as "" entries.
    """
    bucket_list = os.getenv(_CONFIG_ENV_KEY_BUCKETS)
    if not bucket_list:
        return []
    return [b.strip() for b in bucket_list.split(",") if b.strip()]
|
File without changes
|
@@ -0,0 +1 @@
|
|
1
|
+
# Shared logger name used by every module in this package.
LOGGER_NAME = "sufy-mcp"
|
@@ -0,0 +1,17 @@
|
|
1
|
+
from ..config import config
|
2
|
+
from .storage import load as load_storage
|
3
|
+
from .media_processing import load as load_media_processing
|
4
|
+
from .version import load as load_version
|
5
|
+
|
6
|
+
|
7
|
+
def load():
    """Load configuration and register every feature module's tools/resources."""
    # Load configuration
    cfg = config.load_config()

    # Version info
    load_version(cfg)
    # Storage features
    load_storage(cfg)
    # Media processing (image Fop) features
    load_media_processing(cfg)
|
17
|
+
|
@@ -0,0 +1,179 @@
|
|
1
|
+
import logging
|
2
|
+
from mcp import types
|
3
|
+
|
4
|
+
from . import utils
|
5
|
+
from ...consts import consts
|
6
|
+
from ...tools import tools
|
7
|
+
from ...config import config
|
8
|
+
|
9
|
+
logger = logging.getLogger(consts.LOGGER_NAME)

# Shared description for every tool's "object_url" input parameter.
_OBJECT_URL_DESC = "The URL of the image. This can be a URL obtained via the GetObjectURL tool or a URL generated by other Fop tools. Length Constraints: Minimum length of 1."

# Boilerplate appended to every image tool description.
_COMMON_DESC = """
The information includes the object_url of the scaled image, which users can directly use for HTTP GET requests to retrieve the image content or open in a browser to view the file.
The image must be stored in a Sufy Bucket.
Supported original image formats: psd, jpeg, png, gif, webp, tiff, bmp, avif, heic. Image width and height cannot exceed 30,000 pixels, and total pixels cannot exceed 150 million.
"""
|
18
|
+
|
19
|
+
|
20
|
+
class _ToolImpl:
    """Implementations of the image-processing (imgmogr Fop) MCP tools.

    Each method rewrites an object URL to include the processing function
    and re-presigns it via ``utils.url_add_processing_func``.
    """

    def __init__(self, cfg: config.Config):
        # Config supplies the credentials used when re-presigning URLs.
        self.config = cfg

    @tools.tool_meta(
        types.Tool(
            name="ImageScaleByPercent",
            description="""Image scaling tool that resizes images based on a percentage and returns information about the scaled image.
            """ + _COMMON_DESC,
            inputSchema={
                "type": "object",
                "properties": {
                    "object_url": {
                        "type": "string",
                        "description": _OBJECT_URL_DESC
                    },
                    "percent": {
                        "type": "integer",
                        "description": "Scaling percentage, range [1,999]. For example: 90 means the image width and height are reduced to 90% of the original; 200 means the width and height are enlarged to 200% of the original.",
                        "minimum": 1,
                        "maximum": 999
                    },
                },
                "required": ["object_url", "percent"],
            },
        )
    )
    async def image_scale_by_percent(
        self, **kwargs
    ) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
        """Scale both dimensions to the given percentage and return the new URL."""
        object_url = kwargs.get("object_url", "")
        percent = kwargs.get("percent", "")
        # "!<n>p" is the imgmogr percentage-thumbnail syntax.
        func = f"imgmogr/thumbnail/!{percent}p"
        object_url = await utils.url_add_processing_func(self.config, object_url, func)
        return [
            types.TextContent(
                type="text",
                text=str(
                    {
                        "object_url": object_url,
                    }
                ),
            )
        ]

    @tools.tool_meta(
        types.Tool(
            name="ImageScaleBySize",
            description="""Image scaling tool that resizes images based on a specified width or height and returns information about the scaled image.
            """ + _COMMON_DESC,
            inputSchema={
                "type": "object",
                "properties": {
                    "object_url": {
                        "type": "string",
                        "description": _OBJECT_URL_DESC
                    },
                    "width": {
                        "type": "integer",
                        "description": "Specifies the width for image scaling. The image will be scaled to the specified width, and the height will be adjusted proportionally.",
                        "minimum": 1
                    },
                    "height": {
                        "type": "integer",
                        "description": "Specifies the height for image scaling. The image will be scaled to the specified height, and the width will be adjusted proportionally.",
                        "minimum": 1
                    },
                },
                "required": ["object_url"],
                "anyOf": [
                    {"required": ["width"]},
                    {"required": ["height"]}
                ]
            },
        )
    )
    async def image_scale_by_size(
        self, **kwargs
    ) -> list[types.TextContent]:
        """Scale to an explicit width and/or height, keeping aspect ratio."""
        object_url = kwargs.get("object_url", "")
        width = kwargs.get("width", "")
        height = kwargs.get("height", "")

        # With both dimensions missing this is just "x" (length 1): reject.
        func = f"{width}x{height}"
        if len(func) == 1:
            return [
                types.TextContent(
                    type="text", text="At least one width or height must be set"
                )
            ]

        fop = f"imgmogr/thumbnail/{func}"
        object_url = await utils.url_add_processing_func(self.config, object_url, fop)
        return [
            types.TextContent(
                type="text",
                text=str(
                    {
                        "object_url": object_url,
                    }
                ),
            )
        ]

    @tools.tool_meta(
        types.Tool(
            name="ImageBlur",
            description="""Applies Gaussian blur to images. Important notes:
            1. Does NOT affect original GIF format images
            2. Takes effect after WebP-to-GIF conversion
            """ + _COMMON_DESC,
            inputSchema={
                "type": "object",
                "properties": {
                    "object_url": {
                        "type": "string",
                        "description": _OBJECT_URL_DESC,
                        "format": "uri"
                    },
                    "radius": {
                        "type": "integer",
                        "description": "Gaussian blur radius determining intensity (1-200)",
                        "minimum": 1,
                        "maximum": 200
                    },
                    "sigma": {
                        "type": "integer",
                        "description": "Standard deviation of normal distribution controlling smoothness (>0) ",
                        "minimum": 1
                    },
                },
                "required": ["object_url", "radius", "sigma"],
            }
        )
    )
    async def image_blur(self, **kwargs) -> list[types.TextContent]:
        """Apply a Gaussian blur of the given radius/sigma and return the new URL."""
        object_url = kwargs.get("object_url", "")
        radius = kwargs.get("radius", "")
        sigma = kwargs.get("sigma", "")
        func = f"imgmogr/blur/{radius}x{sigma}"
        object_url = await utils.url_add_processing_func(self.config, object_url, func)
        return [
            types.TextContent(
                type="text",
                text=str({
                    "object_url": object_url,
                })
            )
        ]
|
169
|
+
|
170
|
+
|
171
|
+
def register_tools(cfg: config.Config):
    """Instantiate the media-processing tool handlers and register them all."""
    impl = _ToolImpl(cfg)
    handlers = [
        impl.image_scale_by_percent,
        impl.image_scale_by_size,
        impl.image_blur,
    ]
    tools.auto_register_tools(handlers)
|
@@ -0,0 +1,157 @@
|
|
1
|
+
from urllib import parse
|
2
|
+
from botocore import auth as s3_auth
|
3
|
+
from botocore import awsrequest
|
4
|
+
from botocore import credentials as s3_credentials
|
5
|
+
|
6
|
+
from ...config import config
|
7
|
+
|
8
|
+
# Known processing-function families; a query segment starting with one of
# these is treated as a Fop pipeline.
_FUNC_PREFIXES = [
    "imgmogr/"
]
|
11
|
+
|
12
|
+
async def url_add_processing_func(cfg: config.Config, url: str, func: str) -> str:
    """Merge processing function *func* into *url*'s query and re-presign it.

    Existing S3 signature parameters are stripped first (their expiry is
    preserved and reused), the func is spliced into the query's Fop
    pipeline, then the rewritten URL is signed again.
    """
    func_items = func.split("/")
    # The family prefix, e.g. "imgmogr" for "imgmogr/blur/...".
    func_prefix = func_items[0]

    url_info = parse.urlparse(url)
    new_query, expire = _remove_query_sign_info(url_info.query)
    new_query = _query_add_processing_func(new_query, func, func_prefix)
    # Keep "&" and "=" literal so the query structure survives quoting.
    new_query = parse.quote(new_query, safe='&=')
    url_info = url_info._replace(query=new_query)
    new_url = parse.urlunparse(url_info)
    new_url = await _presigned_url(cfg, str(new_url), expire)
    return str(new_url)
|
24
|
+
|
25
|
+
|
26
|
+
def _query_add_processing_func(query: str, func: str, func_prefix: str) -> str:
    """Insert *func* into *query*, merging with any existing Fop pipeline.

    Fop functions are carried in the first query item, joined by "|".
    A func of the same family (same *func_prefix*) as the pipeline's last
    entry is appended directly to it; otherwise a new "|" segment starts.
    """
    queries = query.split("&")
    if '' in queries:
        queries.remove('')

    # Query carries no data at all: the func becomes the whole query.
    if len(queries) == 0:
        return func

    # Funcs live in the first element.
    first_query = parse.unquote(queries[0])

    # No funcs present yet.
    if len(first_query) == 0:
        queries.insert(0, func)
        return "&".join(queries)

    # first_query is not a funcs pipeline.
    if not _is_func(first_query):
        queries.insert(0, func)
        return "&".join(queries)

    # Strip a trailing "=".
    first_query = first_query.removesuffix("=")
    queries.remove(queries[0])

    # No func of this family present: join with a pipe.
    if first_query.find(func_prefix) < 0:
        func = first_query + "|" + func
        queries.insert(0, func)
        return "&".join(queries)

    query_funcs = first_query.split("|")
    if '' in query_funcs:
        query_funcs.remove('')

    # Exactly one func, of the same family: append directly after it.
    if len(query_funcs) == 1:
        func = first_query + func.removeprefix(func_prefix)
        queries.insert(0, func)
        return "&".join(queries)

    # Several funcs: check whether the last one matches this family.
    last_func = query_funcs[-1]

    # Last one does not match: join with a pipe only.
    if last_func.find(func_prefix) < 0:
        func = first_query + "|" + func
        queries.insert(0, func)
        return "&".join(queries)

    # Last one matches: append right after it.
    func = first_query + func.removeprefix(func_prefix)
    queries.insert(0, func)
    return "&".join(queries)
|
81
|
+
|
82
|
+
|
83
|
+
async def _presigned_url(cfg: config.Config, original_url: str, expires: int = 3600) -> str:
    """Presign *original_url* for a GET request using SigV4 query auth.

    Args:
        cfg: Supplies the credential pair and region used for signing.
        original_url: Fully-qualified object URL to sign.
        expires: Signature lifetime in seconds.

    Raises:
        Exception: wraps any signing failure; the original error is chained
        as the cause so the root traceback is preserved (the original code
        discarded it).
    """
    try:
        # Static credentials built from the configured key pair.
        creds = s3_credentials.Credentials(
            access_key=cfg.access_key,
            secret_key=cfg.secret_key,
        )

        # Wrap the target URL in an AWS request object for the signer.
        request = awsrequest.AWSRequest(
            method="GET",
            url=original_url,
            headers={},
        )

        # Query-string signer; path escaping is disabled so the already
        # encoded object key is not escaped a second time
        # (equivalent to DisableURIPathEscaping=True).
        signer = s3_auth.S3SigV4QueryAuth(
            credentials=creds,
            region_name=cfg.region_name,
            expires=expires,
            service_name="s3",
        )
        signer.URI_ESCAPE_PATH = False

        # add_auth mutates the request in place; the signed URL is on it.
        signer.add_auth(request)
        return request.url

    except Exception as e:
        # Chain the cause so callers can see what actually failed.
        raise Exception(f"Presign url:{original_url} error: {e}") from e
|
116
|
+
|
117
|
+
|
118
|
+
_S3_SIGN_URL_QUERY_KEYS_EXPIRES = "x-amz-expires"
|
119
|
+
_S3_SIGN_URL_QUERY_KEYS = [
|
120
|
+
_S3_SIGN_URL_QUERY_KEYS_EXPIRES,
|
121
|
+
"x-amz-algorithm",
|
122
|
+
"x-amz-credential",
|
123
|
+
"x-amz-date",
|
124
|
+
"x-amz-signedheaders",
|
125
|
+
"x-amz-signature",
|
126
|
+
]
|
127
|
+
|
128
|
+
def _remove_query_sign_info(query: str) -> (str, int):
|
129
|
+
queries = query.split("&")
|
130
|
+
if '' in queries:
|
131
|
+
queries.remove('')
|
132
|
+
|
133
|
+
expire = 3600
|
134
|
+
new_queries = []
|
135
|
+
for item in queries:
|
136
|
+
# 移除签名信息
|
137
|
+
found_sign_info = ""
|
138
|
+
for sign_info in _S3_SIGN_URL_QUERY_KEYS:
|
139
|
+
if item.lower().find(sign_info) >= 0:
|
140
|
+
found_sign_info = sign_info
|
141
|
+
break
|
142
|
+
|
143
|
+
if len(found_sign_info) == 0:
|
144
|
+
# 不是签名信息
|
145
|
+
new_queries.append(item)
|
146
|
+
elif found_sign_info == _S3_SIGN_URL_QUERY_KEYS_EXPIRES:
|
147
|
+
expires = item.split("=")
|
148
|
+
if len(expires) == 2:
|
149
|
+
expire = int(expires[1])
|
150
|
+
|
151
|
+
return "&".join(new_queries), expire
|
152
|
+
|
153
|
+
def _is_func(func: str) -> bool:
    """Return True when *func* starts with a known processing-family prefix."""
    return any(func.startswith(prefix) for prefix in _FUNC_PREFIXES)
|
@@ -0,0 +1,13 @@
|
|
1
|
+
from .storage import StorageService
|
2
|
+
from .tools import register_tools
|
3
|
+
from .resource import register_resource_provider
|
4
|
+
from ...config import config
|
5
|
+
|
6
|
+
|
7
|
+
def load(cfg: config.Config):
    """Build the storage service and wire it into the tool/resource registries."""
    storage_service = StorageService(cfg)
    register_tools(storage_service)
    register_resource_provider(storage_service)


__all__ = ["load", "StorageService"]
|
@@ -0,0 +1,158 @@
|
|
1
|
+
import asyncio
|
2
|
+
import logging
|
3
|
+
import base64
|
4
|
+
from typing import Iterable
|
5
|
+
|
6
|
+
from mcp import types
|
7
|
+
from mcp.server.lowlevel import helper_types as low_types
|
8
|
+
from urllib.parse import unquote
|
9
|
+
|
10
|
+
from .storage import StorageService
|
11
|
+
from ...consts import consts
|
12
|
+
from ...resource import resource
|
13
|
+
|
14
|
+
logger = logging.getLogger(consts.LOGGER_NAME)
|
15
|
+
|
16
|
+
|
17
|
+
class _ResourceProvider(resource.ResourceProvider):
    """Exposes S3 buckets and objects from StorageService as MCP resources."""

    def __init__(self, storage: StorageService):
        # Register under the "s3" URI scheme.
        super().__init__("s3")
        self.storage = storage

    async def list_resources(
        self, prefix: str = "", max_keys: int = 20, **kwargs
    ) -> list[types.Resource]:
        """
        List S3 buckets and their contents as resources with pagination
        Args:
            prefix: Prefix listing after this bucket name
            max_keys: Returns the maximum number of keys (up to 100), default 20
        """
        resources = []
        logger.debug("Starting to list resources")
        logger.debug(f"Configured buckets: {self.storage.config.buckets}")

        try:
            # Get limited number of buckets
            buckets = await self.storage.list_buckets(prefix)
            if not buckets or len(buckets) == 0:
                logger.warning("No buckets found")
                return []

            # limit concurrent operations
            async def process_bucket(bucket):
                # Appends this bucket's objects to the shared `resources` list.
                bucket_name = bucket["Name"]
                logger.debug(f"Processing bucket: {bucket_name}")

                try:
                    # List objects in the bucket with a reasonable limit
                    objects = await self.storage.list_objects(
                        bucket_name, max_keys=max_keys
                    )

                    for obj in objects:
                        # Skip "directory" placeholder keys ending in "/".
                        if "Key" in obj and not obj["Key"].endswith("/"):
                            object_key = obj["Key"]
                            if self.storage.is_markdown_file(object_key):
                                mime_type = "text/markdown"
                            elif self.storage.is_image_file(object_key):
                                # NOTE(review): all images are reported as
                                # image/png regardless of extension — confirm
                                # this is intended.
                                mime_type = "image/png"
                            else:
                                mime_type = "text/plain"

                            resource_entry = types.Resource(
                                uri=f"s3://{bucket_name}/{object_key}",
                                name=object_key,
                                mimeType=mime_type,
                                description=str(obj),
                            )
                            resources.append(resource_entry)
                            logger.debug(f"Added resource: {resource_entry.uri}")

                except Exception as e:
                    # Best-effort: a failing bucket must not abort the others.
                    logger.error(
                        f"Error listing objects in bucket {bucket_name}: {str(e)}"
                    )

            # Use semaphore to limit concurrent bucket processing
            semaphore = asyncio.Semaphore(3)  # Limit concurrent bucket processing

            async def process_bucket_with_semaphore(bucket):
                async with semaphore:
                    await process_bucket(bucket)

            # Process buckets concurrently
            await asyncio.gather(
                *[process_bucket_with_semaphore(bucket) for bucket in buckets]
            )

        except Exception as e:
            logger.error(f"Error listing buckets: {str(e)}")
            raise

        logger.info(f"Returning {len(resources)} resources")
        return resources

    async def read_resource(self, uri: types.AnyUrl, **kwargs) -> [
            str | bytes | Iterable[low_types.ReadResourceContents]]:
        """
        Read content from an S3 resource and return structured response

        Returns:
            Dict containing 'contents' list with uri, mimeType, and text for each resource
        """
        uri_str = str(uri)
        logger.debug(f"Reading resource: {uri_str}")

        if not uri_str.startswith("s3://"):
            raise ValueError("Invalid S3 URI")

        # Parse the S3 URI
        path = uri_str[5:]  # Remove "s3://"
        path = unquote(path)  # Decode URL-encoded characters
        parts = path.split("/", 1)

        if len(parts) < 2:
            raise ValueError("Invalid S3 URI format")

        bucket = parts[0]
        key = parts[1]

        response = await self.storage.get_object(bucket, key)
        file_content = response["Body"]
        content_type = response.get("ContentType", "application/octet-stream")
        # Images are always returned base64-encoded.
        if content_type.startswith("image/"):
            base64_data = base64.b64encode(file_content).decode("utf-8")
            return [
                low_types.ReadResourceContents(
                    mime_type=content_type,
                    content=base64_data,
                )
            ]

        # NOTE(review): indentation is ambiguous in this view; the early
        # return is placed inside the isinstance branch so the text/binary
        # handling below stays reachable — confirm against the original.
        if not isinstance(file_content, bytes):
            file_content = str(file_content)
            return [
                low_types.ReadResourceContents(
                    mime_type=content_type,
                    content=file_content,
                )
            ]

        # Bytes payloads: decode text, base64-encode everything else.
        if content_type.startswith("text/"):
            file_content = file_content.decode("utf-8")
        else:
            file_content = base64.b64encode(file_content).decode("utf-8")

        return [
            low_types.ReadResourceContents(
                mime_type=content_type,
                content=file_content,
            )
        ]
|
154
|
+
|
155
|
+
|
156
|
+
def register_resource_provider(storage: StorageService):
    """Create the S3 resource provider and add it to the global registry."""
    provider = _ResourceProvider(storage)
    resource.register_resource_provider(provider)