imgenx-mcp 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- imgenx/factory.py +50 -0
- imgenx/main.py +54 -0
- imgenx/operator.py +151 -0
- imgenx/oss_service.py +198 -0
- imgenx/predictor/base/base_image_analyzer.py +13 -0
- imgenx/predictor/base/base_image_generator.py +17 -0
- imgenx/predictor/base/base_video_generator.py +18 -0
- imgenx/predictor/generators/doubao_image_analyzer.py +65 -0
- imgenx/predictor/generators/doubao_image_generator.py +77 -0
- imgenx/predictor/generators/doubao_video_generator.py +122 -0
- imgenx/script.py +103 -0
- imgenx/server.py +378 -0
- imgenx_mcp-0.4.0.dist-info/METADATA +392 -0
- imgenx_mcp-0.4.0.dist-info/RECORD +17 -0
- imgenx_mcp-0.4.0.dist-info/WHEEL +4 -0
- imgenx_mcp-0.4.0.dist-info/entry_points.txt +2 -0
- imgenx_mcp-0.4.0.dist-info/licenses/LICENSE +21 -0
imgenx/factory.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
from functools import cache
|
|
3
|
+
from importlib import import_module
|
|
4
|
+
|
|
5
|
+
from imgenx.predictor.base.base_image_generator import BaseImageGenerator
|
|
6
|
+
from imgenx.predictor.base.base_video_generator import BaseVideoGenerator
|
|
7
|
+
from imgenx.predictor.base.base_image_analyzer import BaseImageAnalyzer
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
@cache
def create_image_generator(model: str, api_key: str) -> BaseImageGenerator:
    """Return the image generator for ``model``, building it on first use.

    ``model`` and ``api_key`` are forwarded unchanged to ``create_predictor``
    with the ``'image_generator'`` role. ``functools.cache`` memoises the
    result, so repeated calls with the same arguments share one instance.
    """
    generator = create_predictor(model, api_key, 'image_generator')
    return generator
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
@cache
def create_video_generator(model: str, api_key: str) -> BaseVideoGenerator:
    """Return the video generator for ``model``, building it on first use.

    ``model`` and ``api_key`` are forwarded unchanged to ``create_predictor``
    with the ``'video_generator'`` role. ``functools.cache`` memoises the
    result, so repeated calls with the same arguments share one instance.
    """
    generator = create_predictor(model, api_key, 'video_generator')
    return generator
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
@cache
def create_image_analyzer(model: str, api_key: str) -> BaseImageAnalyzer:
    """Return the image analyzer for ``model``, building it on first use.

    ``model`` and ``api_key`` are forwarded unchanged to ``create_predictor``
    with the ``'image_analyzer'`` role. ``functools.cache`` memoises the
    result, so repeated calls with the same arguments share one instance.
    """
    analyzer = create_predictor(model, api_key, 'image_analyzer')
    return analyzer
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def create_predictor(model: str, api_key: str,
                     role: str) -> 'BaseImageGenerator | BaseVideoGenerator | BaseImageAnalyzer':
    """Instantiate the predictor class that implements ``role`` for a provider.

    Args:
        model: Spec of the form ``"provider:model"``. Only the first ``:``
            separates the two parts, so the model name may itself contain
            colons. The provider part is matched case-insensitively.
        api_key: API key handed to the predictor constructor.
        role: Underscore-separated role name such as ``'image_generator'``.
            It selects the module
            ``imgenx.predictor.generators.{provider}_{role}`` and the class
            ``{Provider}{Role}`` (each underscore-separated part capitalised).

    Returns:
        ``predictor_class(model_name, api_key)`` for the resolved class.

    Raises:
        ValueError: If ``model`` is not ``"provider:model"``, the provider is
            not listed by ``get_providers()``, or the module/class for the
            role cannot be imported.
    """
    provider, separator, model_name = model.partition(':')
    if not separator or not provider or not model_name:
        raise ValueError(f'Model must be given as "provider:model", got {model!r}.')

    provider = provider.lower()
    if provider not in get_providers():
        raise ValueError(f'Provider {provider} not found.')

    role = role.lower()
    predictor_package = f'imgenx.predictor.generators.{provider}_{role}'
    # 'image_generator' -> 'ImageGenerator'; prefixed with the capitalised provider.
    predictor_class = provider.capitalize() + ''.join(
        part.capitalize() for part in role.split('_')
    )

    try:
        package = import_module(predictor_package)
        predictor = getattr(package, predictor_class)
    except (ImportError, AttributeError) as e:
        # Chain the cause so a missing third-party dependency inside the
        # provider module is not hidden behind the generic message.
        raise ValueError(f'Provider {provider} not found.') from e

    return predictor(model_name, api_key)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@cache
def get_providers():
    """List the provider prefixes found in ``predictor/generators``.

    Every file matching ``*_*_*.py`` next to this module's
    ``predictor/generators`` directory contributes the part of its stem
    before the first underscore. Duplicates are collapsed; the result is
    computed once and cached.
    """
    generators_dir = Path(__file__).parent / 'predictor/generators'
    names = {module.stem.split('_')[0] for module in generators_dir.glob('*_*_*.py')}
    return list(names)
|
imgenx/main.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
from argparse import ArgumentParser
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def _build_parser() -> ArgumentParser:
    """Build the command-line parser with the server/image/video sub-commands."""
    parser = ArgumentParser()
    subparsers = parser.add_subparsers(dest='command', required=True)

    server_parser = subparsers.add_parser('server', help='启动 MCP 服务器')
    server_parser.add_argument('--transport', default='stdio', help='stdio|sse|streamable-http')
    server_parser.add_argument('--host', default='0.0.0.0', help='主机地址')
    server_parser.add_argument('--port', default=8000, type=int, help='端口')
    server_parser.add_argument('--no_tools', nargs='+', default=None, help='禁用的工具名列表(用空格分隔)')

    image_parser = subparsers.add_parser('image', help='生成图片')
    image_parser.add_argument('prompt', help='生成图片的提示词')
    image_parser.add_argument('--images', nargs='+', default=None, help='输入图片路径列表')
    image_parser.add_argument('--size', default='2K', help='生成图像的分辨率或宽高像素值,分辨率可选值:1K、2K、4K,宽高像素可选值:2048x2048、2304x1728、1728x2304、2560x1440、1440x2560、2496x1664、1664x2496、3024x1296')
    image_parser.add_argument('--output', default='imgenx.jpg', help='输出文件或目录路径')

    video_parser = subparsers.add_parser('video', help='生成视频')
    video_parser.add_argument('prompt', help='生成视频的提示词')
    video_parser.add_argument('--first_frame', default=None, help='输入视频的第一帧路径')
    video_parser.add_argument('--last_frame', default=None, help='输入视频的最后一帧路径')
    # Help text now lists "720p" (the default) instead of the mistyped "720".
    video_parser.add_argument('--resolution', default='720p', help='生成视频的分辨率,可选值:480p、720p、1080p')
    video_parser.add_argument('--ratio', default='16:9', help='生成视频的宽高比,可选值:16:9、4:3、1:1、3:4、9:16、21:9')
    video_parser.add_argument('--duration', default=5, type=int, help='生成视频的时长,单位秒')
    video_parser.add_argument('--output', default='imgenx.mp4', help='输出文件路径')

    return parser


def run():
    """Parse ``sys.argv`` and dispatch to the selected sub-command.

    * ``server`` starts the MCP server from ``imgenx.server``, first removing
      any tools named with ``--no_tools``. ``--host`` and ``--port`` are only
      passed to ``mcp.run`` for non-stdio transports.
    * ``image`` calls ``imgenx.script.gen_image``.
    * ``video`` calls ``imgenx.script.gen_video``.

    The heavy modules are imported lazily, inside the branch that needs them.

    Raises:
        ValueError: If the parsed command is not one of the three above.
    """
    args = _build_parser().parse_args()

    if args.command == 'server':
        from imgenx.server import mcp

        if args.no_tools:
            for tool in args.no_tools:
                mcp.remove_tool(tool)

        if args.transport == 'stdio':
            mcp.run(transport='stdio')
        else:
            mcp.run(transport=args.transport, host=args.host, port=args.port)
    elif args.command == 'image':
        from imgenx import script
        script.gen_image(prompt=args.prompt, size=args.size, output=args.output, images=args.images)
    elif args.command == 'video':
        from imgenx import script
        script.gen_video(prompt=args.prompt, first_frame=args.first_frame, last_frame=args.last_frame,
                         resolution=args.resolution, ratio=args.ratio, duration=args.duration, output=args.output)
    else:
        raise ValueError(f'Unknown command: {args.command}')


if __name__ == '__main__':
    run()
|
imgenx/operator.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from typing import Tuple
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
from io import BytesIO
|
|
5
|
+
|
|
6
|
+
import requests
|
|
7
|
+
from PIL import Image, ImageEnhance
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def download(url: str, output: str):
    """Download ``url`` and write the response body to ``output``.

    Args:
        url: Address fetched with a 30 second timeout.
        output: Destination path. It must not exist yet; missing parent
            directories are created, matching ``_save_image``.

    Raises:
        FileExistsError: If ``output`` already exists.
        requests.RequestException: If the request fails or returns an error
            status. The original exception is attached as the cause.
    """
    output = Path(output)
    if output.exists():
        raise FileExistsError(f'Path {output} already exists.')

    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()
    except Exception as e:
        raise requests.RequestException(f'Error: {e}') from e

    # Create the target directory only after the download has succeeded.
    output.parent.mkdir(parents=True, exist_ok=True)
    output.write_bytes(response.content)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def get_image_info(image: str) -> dict:
    """Return basic metadata for an image given by path or http(s) URL.

    Args:
        image: Image location, loaded through ``_load_image``.

    Returns:
        A dict with the keys ``format`` (``str(img.format)``), ``size``
        (``"WIDTHxHEIGHT"``), ``mode`` and ``file_size``. ``file_size`` is the
        human-readable length of ``img.tobytes()``, i.e. the decoded pixel
        buffer, not the size of the encoded file.

    Raises:
        ValueError: If the image cannot be loaded. The original exception is
            attached as the cause.
    """
    try:
        img = _load_image(image)
    except Exception as e:
        raise ValueError(f'Error loading image: {e}') from e

    info = {
        'format': str(img.format),
        'size': f"{img.width}x{img.height}",
        'mode': img.mode,
        'file_size': _format_file_size(len(img.tobytes()))
    }
    return info
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def crop_image(image: str, box: str, output: str):
    """Crop a region given in fractional coordinates and save it.

    Args:
        image: Image location, loaded through ``_load_image``.
        box: ``"x1, y1, x2, y2"`` as comma-separated numbers. Each value is
            multiplied by the image width (x) or height (y), so ``0`` is the
            left/top edge and ``1`` the right/bottom edge.
        output: Destination path, written through ``_save_image``.

    Raises:
        ValueError: If ``box`` cannot be parsed into four numbers, or if the
            resulting region is empty after clamping to the image bounds.
    """
    try:
        x1, y1, x2, y2 = map(float, box.split(','))
    except Exception:
        raise ValueError(
            'box must be "x1, y1, x2, y2" as fractions of the image size, '
            'e.g. "0.1, 0.1, 0.9, 0.9"'
        ) from None

    img = _load_image(image)
    width = img.width
    height = img.height

    # Scale to pixels and clamp to the image on both sides, so the crop
    # never extends past the picture.
    left = int(min(max(0, width * x1), width))
    top = int(min(max(0, height * y1), height))
    right = int(min(max(0, width * x2), width))
    bottom = int(min(max(0, height * y2), height))

    if right <= left or bottom <= top:
        raise ValueError(f'box {box!r} selects an empty region of the image')

    cropped = img.crop((left, top, right, bottom))
    _save_image(cropped, output)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def resize_image(image: str, size: str, output: str, keep_aspect: bool = True):
    """Resize an image to ``size`` and save the result to ``output``.

    Args:
        image: Image location, loaded through ``_load_image``.
        size: Target size written as ``WIDTHxHEIGHT`` (digits only).
        output: Destination path, written through ``_save_image``.
        keep_aspect: When true, a copy is shrunk with ``thumbnail`` to fit
            within the target; otherwise the image is resized to exactly
            the target size.

    Raises:
        ValueError: If ``size`` is not of the form ``WIDTHxHEIGHT``.
    """
    parsed = re.match(r'^(\d+)x(\d+)$', size)
    if parsed is None:
        raise ValueError('size must be WIDTHxHEIGHT, e.g., "1024x768"')

    target = (int(parsed.group(1)), int(parsed.group(2)))
    source = _load_image(image)

    if not keep_aspect:
        _save_image(source.resize(target, Image.Resampling.LANCZOS), output)
        return

    # thumbnail() works in place, so operate on a copy of the loaded image.
    shrunk = source.copy()
    shrunk.thumbnail(target, Image.Resampling.LANCZOS)
    _save_image(shrunk, output)
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def convert_image(image: str, format: str, output: str, quality: int = 90):
    """Re-encode an image in another format and save it to ``output``.

    Args:
        image: Image location, loaded through ``_load_image``.
        format: Target format, case-insensitive; one of PNG, JPEG, JPG, WEBP.
        output: Destination path, written through ``_save_image``.
        quality: Quality setting forwarded to ``_save_image``.

    Raises:
        ValueError: If ``format`` is not one of the supported names.
    """
    target_format = format.upper()
    if target_format not in {'PNG', 'JPEG', 'JPG', 'WEBP'}:
        raise ValueError('format must be one of PNG/JPEG/JPG/WEBP')

    _save_image(_load_image(image), output, format_hint=target_format, quality=quality)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def adjust_image(image: str, output: str, brightness: float = 1.0, contrast: float = 1.0, saturation: float = 1.0):
    """Apply brightness, contrast and saturation factors, then save.

    The three ``ImageEnhance`` adjustments run in that fixed order, each on
    the result of the previous one. The image is read through
    ``_load_image`` and written through ``_save_image``.
    """
    adjustments = (
        (ImageEnhance.Brightness, brightness),
        (ImageEnhance.Contrast, contrast),
        (ImageEnhance.Color, saturation),
    )

    result = _load_image(image)
    for enhancer, factor in adjustments:
        result = enhancer(result).enhance(factor)

    _save_image(result, output)
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def paste_image(front_image: str, backgroud_image: str, position: Tuple[int, int], output: str):
    """Paste ``front_image`` onto ``backgroud_image`` and save the result.

    Args:
        front_image: Location of the image to paste (path or http(s) URL).
        backgroud_image: Location of the background image.
        position: ``(x, y)`` of the pasted image's upper-left corner.
        output: Destination path, written through ``_save_image``.

    When the front image has an alpha channel, that channel is used as the
    paste mask so transparent areas leave the background visible instead of
    overwriting it with opaque pixels.
    """
    front_img = _load_image(front_image)
    back_img = _load_image(backgroud_image)

    # Pillow takes the mask from the alpha band of an "RGBA"/"LA" image.
    mask = front_img if front_img.mode in ('RGBA', 'LA') else None

    back_img.paste(front_img, tuple(position), mask)
    _save_image(back_img, output)
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def _load_image(image: str) -> Image.Image:
|
|
107
|
+
if re.match(r'^https?://', image):
|
|
108
|
+
try:
|
|
109
|
+
resp = requests.get(image, timeout=30)
|
|
110
|
+
resp.raise_for_status()
|
|
111
|
+
except Exception as e:
|
|
112
|
+
raise requests.RequestException(f'Error: {e}')
|
|
113
|
+
return Image.open(BytesIO(resp.content)).convert('RGBA')
|
|
114
|
+
|
|
115
|
+
path = Path(image)
|
|
116
|
+
if not path.exists():
|
|
117
|
+
raise FileNotFoundError(f'Image not found: {image}')
|
|
118
|
+
return Image.open(path)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def _save_image(img: Image.Image, output: str, format_hint: str | None = None, quality: int = 90) -> None:
|
|
122
|
+
out = Path(output)
|
|
123
|
+
if out.exists():
|
|
124
|
+
raise FileExistsError(f'Path {out} already exists.')
|
|
125
|
+
out.parent.mkdir(parents=True, exist_ok=True)
|
|
126
|
+
|
|
127
|
+
fmt = format_hint or (out.suffix[1:].upper() if out.suffix else 'PNG')
|
|
128
|
+
|
|
129
|
+
if fmt.upper() == 'JPG':
|
|
130
|
+
fmt = 'JPEG'
|
|
131
|
+
|
|
132
|
+
save_kwargs = {}
|
|
133
|
+
|
|
134
|
+
if fmt.upper() in ('JPEG', 'WEBP'):
|
|
135
|
+
save_kwargs['quality'] = quality
|
|
136
|
+
|
|
137
|
+
if fmt.upper() == 'JPEG':
|
|
138
|
+
img = img.convert('RGB')
|
|
139
|
+
|
|
140
|
+
img.save(out, format=fmt, **save_kwargs)
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def _format_file_size(size_bytes: int) -> str:
|
|
144
|
+
if size_bytes < 1024:
|
|
145
|
+
return f"{size_bytes} B"
|
|
146
|
+
elif size_bytes < 1024 * 1024:
|
|
147
|
+
return f"{size_bytes / 1024:.1f} KB"
|
|
148
|
+
elif size_bytes < 1024 * 1024 * 1024:
|
|
149
|
+
return f"{size_bytes / (1024 * 1024):.1f} MB"
|
|
150
|
+
else:
|
|
151
|
+
return f"{size_bytes / (1024 * 1024 * 1024):.1f} GB"
|
imgenx/oss_service.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
1
|
+
"""阿里云 OSS 上传服务"""
|
|
2
|
+
import os
|
|
3
|
+
import uuid
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
import oss2
|
|
9
|
+
from dotenv import load_dotenv
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
load_dotenv()
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class OSSService:
    """Thin wrapper around an Alibaba Cloud OSS bucket configured from env vars."""

    def __init__(self):
        """Create the OSS client from environment variables.

        Reads ``OSS_ACCESS_KEY_ID``, ``OSS_ACCESS_KEY_SECRET``, ``OSS_BUCKET``
        and ``OSS_ENDPOINT`` (all required) and ``OSS_CDN_URL`` (optional;
        trailing slashes are stripped).

        Raises:
            ValueError: If any required variable is missing or empty.
        """
        self.access_key_id = os.getenv('OSS_ACCESS_KEY_ID')
        self.access_key_secret = os.getenv('OSS_ACCESS_KEY_SECRET')
        self.bucket_name = os.getenv('OSS_BUCKET')
        self.endpoint = os.getenv('OSS_ENDPOINT')
        self.cdn_url = os.getenv('OSS_CDN_URL', '').rstrip('/')

        if not all([self.access_key_id, self.access_key_secret, self.bucket_name, self.endpoint]):
            raise ValueError('OSS 配置不完整,请检查环境变量')

        # Credentials object used to sign requests.
        auth = oss2.Auth(self.access_key_id, self.access_key_secret)

        # Bucket handle used by all upload/delete/exists operations.
        self.bucket = oss2.Bucket(auth, self.endpoint, self.bucket_name)

    def generate_object_key(self, filename: str, business_dir: str = 'data') -> str:
        """Generate a unique object key (storage path) for ``filename``.

        Format: ``{business_dir}/{YYYYMM}/{timestamp_ms}_{8 hex chars}{ext}``,
        where ``ext`` is the lower-cased suffix of ``filename`` including the
        leading dot (empty if there is none).

        Args:
            filename: Original file name; only its suffix is used.
            business_dir: Top-level directory of the key, default ``'data'``.

        Returns:
            The generated object key.
        """
        ext = Path(filename).suffix.lower()

        now = datetime.now()
        date_dir = now.strftime('%Y%m')

        # Millisecond timestamp plus a short random id keeps names unique.
        timestamp = int(now.timestamp() * 1000)
        unique_id = uuid.uuid4().hex[:8]

        return f'{business_dir}/{date_dir}/{timestamp}_{unique_id}{ext}'

    def upload_file(self, file_path: str, object_key: Optional[str] = None,
                    business_dir: str = 'data') -> dict:
        """Upload a local file to OSS.

        Args:
            file_path: Path of the local file.
            object_key: Object key to store under; generated from the file
                name when ``None``.
            business_dir: Directory used when the key is generated.

        Returns:
            A dict with the keys ``object_key``, ``oss_url``, ``cdn_url``
            (equal to ``oss_url`` when no CDN is configured) and ``status``.

        Raises:
            FileNotFoundError: If ``file_path`` does not exist.
            Exception: If the upload response status is not 200.
        """
        file_path = Path(file_path)

        if not file_path.exists():
            raise FileNotFoundError(f'文件不存在: {file_path}')

        if object_key is None:
            object_key = self.generate_object_key(file_path.name, business_dir)

        with open(file_path, 'rb') as f:
            result = self.bucket.put_object(object_key, f)

        return self._build_upload_result(object_key, result)

    def upload_bytes(self, data: bytes, filename: str,
                     object_key: Optional[str] = None,
                     business_dir: str = 'data') -> dict:
        """Upload in-memory bytes to OSS.

        Args:
            data: Payload to upload.
            filename: File name, used only to derive the key's extension.
            object_key: Object key to store under; generated when ``None``.
            business_dir: Directory used when the key is generated.

        Returns:
            A dict with the keys ``object_key``, ``oss_url``, ``cdn_url``
            (equal to ``oss_url`` when no CDN is configured) and ``status``.

        Raises:
            Exception: If the upload response status is not 200.
        """
        if object_key is None:
            object_key = self.generate_object_key(filename, business_dir)

        result = self.bucket.put_object(object_key, data)

        return self._build_upload_result(object_key, result)

    def delete_file(self, object_key: str) -> bool:
        """Delete an object from OSS.

        Args:
            object_key: Key of the object to delete.

        Returns:
            ``True`` if the delete response status is 204. ``False`` for any
            other status or if the call raises; the error is printed.
        """
        try:
            result = self.bucket.delete_object(object_key)
            return result.status == 204
        except Exception as e:
            print(f'删除文件失败: {e}')
            return False

    def file_exists(self, object_key: str) -> bool:
        """Check whether an object exists.

        Args:
            object_key: Key of the object to look up.

        Returns:
            The result of ``bucket.object_exists``; ``False`` if the call raises.
        """
        try:
            return self.bucket.object_exists(object_key)
        except Exception:
            return False

    def get_file_url(self, object_key: str, use_cdn: bool = True) -> str:
        """Return the access URL for ``object_key``.

        Args:
            object_key: Key of the object.
            use_cdn: Prefer the CDN base URL when one is configured.

        Returns:
            ``{cdn_url}/{object_key}`` when ``use_cdn`` is true and a CDN URL
            is configured, otherwise the bucket URL.
        """
        if use_cdn and self.cdn_url:
            return f'{self.cdn_url}/{object_key}'
        return f'{self._bucket_base_url()}/{object_key}'

    def _bucket_base_url(self) -> str:
        """Return ``scheme://{bucket}.{endpoint host}`` for direct bucket access."""
        # OSS_ENDPOINT may be configured with or without a scheme. Inserting
        # the bucket name in front of "https://..." would give a malformed
        # URL, so split any scheme off first; bare hosts default to https.
        scheme, separator, host = self.endpoint.partition('://')
        if not separator:
            scheme, host = 'https', self.endpoint
        return f'{scheme}://{self.bucket_name}.{host.rstrip("/")}'

    def _build_upload_result(self, object_key: str, result) -> dict:
        """Validate a ``put_object`` result and build the upload summary dict."""
        if result.status != 200:
            raise Exception(f'上传失败: {result.status}')

        return {
            'object_key': object_key,
            'oss_url': self.get_file_url(object_key, use_cdn=False),
            'cdn_url': self.get_file_url(object_key),
            'status': result.status
        }
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
# Module-level singleton, created lazily by get_oss_service().
_oss_service_instance = None


def get_oss_service() -> OSSService:
    """Return the shared ``OSSService``, constructing it on the first call."""
    global _oss_service_instance
    if _oss_service_instance is not None:
        return _oss_service_instance

    _oss_service_instance = OSSService()
    return _oss_service_instance
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from typing import List, Dict
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class BaseImageAnalyzer(ABC):
    """Abstract interface for predictors that answer a prompt about an image."""

    @abstractmethod
    def __init__(self, model: str, api_key: str):
        """Store the model name and credentials; implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def analyze(self, prompt: str, image: str) -> str:
        """Analyze ``image`` according to ``prompt`` and return the answer text.

        The return annotation is ``str`` to match the concrete analyzer,
        which returns the message content of the model response.
        """
        raise NotImplementedError
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from typing import List, Dict
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class BaseImageGenerator(ABC):
    """Abstract interface for predictors that generate images."""

    @abstractmethod
    def __init__(self, model: str, api_key: str):
        """Store the model name and credentials; implemented by subclasses."""
        raise NotImplementedError

    @abstractmethod
    def text_to_image(self, prompt: str, size: str) -> List[Dict[str, str]]:
        """Generate images from a text prompt at the requested size."""
        raise NotImplementedError

    @abstractmethod
    def image_to_image(self, prompt: str, images: List[str], size: str) -> List[Dict[str, str]]:
        """Generate images from a prompt plus one or more input images."""
        raise NotImplementedError
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from typing import List, Dict
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
class BaseVideoGenerator(ABC):
|
|
6
|
+
|
|
7
|
+
@abstractmethod
|
|
8
|
+
def __init__(self, model: str, api_key: str):
|
|
9
|
+
raise NotImplementedError
|
|
10
|
+
|
|
11
|
+
@abstractmethod
|
|
12
|
+
def text_to_video(self, prompt: str, resolution: str, ratio: str, duration: int) -> Dict[str, str]:
|
|
13
|
+
raise NotImplementedError
|
|
14
|
+
|
|
15
|
+
@abstractmethod
|
|
16
|
+
def image_to_video(self, prompt: str, first_frame: str, last_frame: str|None,
|
|
17
|
+
resolution: str, ratio: str, duration: int) -> Dict[str, str]:
|
|
18
|
+
raise NotImplementedError
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
sys.path.insert(0, '../../..')
|
|
3
|
+
|
|
4
|
+
import base64
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import List, Dict
|
|
7
|
+
from volcenginesdkarkruntime import Ark
|
|
8
|
+
|
|
9
|
+
from imgenx.predictor.base.base_image_analyzer import BaseImageAnalyzer
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class DoubaoImageAnalyzer(BaseImageAnalyzer):
    """Image analyzer backed by the Volcengine Ark chat-completions API."""

    def __init__(self, model: str, api_key: str):
        """Create an Ark client for ``model`` against the cn-beijing endpoint."""
        self.model = model
        self.client = Ark(
            base_url='https://ark.cn-beijing.volces.com/api/v3',
            api_key=api_key,
        )

    def analyze(self, prompt: str, image: str) -> str:
        """Ask the model ``prompt`` about ``image`` and return its reply text.

        Args:
            prompt: Question or instruction sent with the image.
            image: An ``http://``/``https://`` URL, sent as is, or a local
                file path, which is inlined as a base64 data URL.

        Returns:
            The message content of the first choice in the response.
        """
        # Match the full scheme so a local path such as "http_cache/a.png"
        # is not mistaken for a URL.
        if not image.startswith(('http://', 'https://')):
            image = self._image_to_base64(image)

        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {
                    'role': 'user',
                    "content": [
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": image
                            },
                        },
                        {"type": "text", "text": prompt},
                    ],
                }
            ],
        )
        return response.choices[0].message.content

    def _image_to_base64(self, image_path: str) -> str:
        """Read a local image and return it as a ``data:image/...`` URL."""
        image_path = Path(image_path)

        # Media subtypes are lower-case and ".jpg" files are "image/jpeg".
        image_format = image_path.suffix.lstrip('.').lower()
        if image_format == 'jpg':
            image_format = 'jpeg'

        with open(image_path, 'rb') as image_file:
            encoded = base64.b64encode(image_file.read()).decode('utf-8')

        return f'data:image/{image_format};base64,{encoded}'
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
if __name__ == '__main__':
    # Manual smoke test: describe one image with the Doubao vision model.
    # Usage: python doubao_image_analyzer.py [IMAGE_PATH_OR_URL]
    import os
    import sys

    from dotenv import load_dotenv

    load_dotenv()

    api_key = os.getenv('IMGENX_API_KEY')
    model = 'doubao-seed-1-6-vision-250815'

    # Take the image from the command line so the check runs on any machine;
    # the original developer-local sample path remains the fallback.
    if len(sys.argv) > 1:
        image = sys.argv[1]
    else:
        image = '/Volumes/DATA/个人/project/imgenx-mcp-server/logo.jpg'

    analyzer = DoubaoImageAnalyzer(model, api_key)
    result = analyzer.analyze('请描述这张图片', image)
    print(result)
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import List, Dict
|
|
4
|
+
from volcenginesdkarkruntime import Ark
|
|
5
|
+
|
|
6
|
+
from imgenx.predictor.base.base_image_generator import BaseImageGenerator
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class DoubaoImageGenerator(BaseImageGenerator):
    """Image generator backed by the Volcengine Ark image-generation API."""

    def __init__(self, model: str, api_key: str):
        """Create an Ark client for ``model`` against the cn-beijing endpoint."""
        self.model = model
        self.client = Ark(
            base_url='https://ark.cn-beijing.volces.com/api/v3',
            api_key=api_key,
        )

    def text_to_image(self, prompt: str, size: str) -> List[Dict[str, str]]:
        """Generate images from ``prompt`` at ``size``.

        Returns:
            The ``data`` items of the response, each without ``b64_json``.
        """
        return self._generate(prompt=prompt, size=size)

    def image_to_image(self, prompt: str, images: List[str], size: str) -> List[Dict[str, str]]:
        """Generate images from ``prompt`` and reference ``images``.

        Args:
            prompt: Text prompt.
            images: ``http://``/``https://`` URLs, sent as is, and/or local
                file paths, which are inlined as base64 data URLs. An empty
                list or ``None`` sends ``image=None``.
            size: Requested output size.

        Returns:
            The ``data`` items of the response, each without ``b64_json``.
        """
        if images:
            # Match the full scheme so a local path such as
            # "http_cache/a.png" is not mistaken for a URL.
            pass_images = [
                item if item.startswith(('http://', 'https://')) else self._image_to_base64(item)
                for item in images
            ]
        else:
            pass_images = None

        return self._generate(prompt=prompt, size=size, image=pass_images)

    def _generate(self, **params) -> List[Dict[str, str]]:
        """Call ``images.generate`` with the shared options and strip ``b64_json``."""
        response = self.client.images.generate(
            model=self.model,
            sequential_image_generation='auto',
            response_format='url',
            stream=False,
            watermark=False,
            **params,
        )

        result = []
        for item in response.model_dump()['data']:
            # URLs were requested; drop the base64 field if it is present.
            item.pop('b64_json', None)
            result.append(item)

        return result

    def _image_to_base64(self, image_path: str) -> str:
        """Read a local image and return it as a ``data:image/...`` URL."""
        image_path = Path(image_path)

        # Media subtypes are lower-case and ".jpg" files are "image/jpeg".
        image_format = image_path.suffix.lstrip('.').lower()
        if image_format == 'jpg':
            image_format = 'jpeg'

        with open(image_path, 'rb') as image_file:
            encoded = base64.b64encode(image_file.read()).decode('utf-8')

        return f'data:image/{image_format};base64,{encoded}'
|