maque 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- maque/__init__.py +30 -0
- maque/__main__.py +926 -0
- maque/ai_platform/__init__.py +0 -0
- maque/ai_platform/crawl.py +45 -0
- maque/ai_platform/metrics.py +258 -0
- maque/ai_platform/nlp_preprocess.py +67 -0
- maque/ai_platform/webpage_screen_shot.py +195 -0
- maque/algorithms/__init__.py +78 -0
- maque/algorithms/bezier.py +15 -0
- maque/algorithms/bktree.py +117 -0
- maque/algorithms/core.py +104 -0
- maque/algorithms/hilbert.py +16 -0
- maque/algorithms/rate_function.py +92 -0
- maque/algorithms/transform.py +27 -0
- maque/algorithms/trie.py +272 -0
- maque/algorithms/utils.py +63 -0
- maque/algorithms/video.py +587 -0
- maque/api/__init__.py +1 -0
- maque/api/common.py +110 -0
- maque/api/fetch.py +26 -0
- maque/api/static/icon.png +0 -0
- maque/api/static/redoc.standalone.js +1782 -0
- maque/api/static/swagger-ui-bundle.js +3 -0
- maque/api/static/swagger-ui.css +3 -0
- maque/cli/__init__.py +1 -0
- maque/cli/clean_invisible_chars.py +324 -0
- maque/cli/core.py +34 -0
- maque/cli/groups/__init__.py +26 -0
- maque/cli/groups/config.py +205 -0
- maque/cli/groups/data.py +615 -0
- maque/cli/groups/doctor.py +259 -0
- maque/cli/groups/embedding.py +222 -0
- maque/cli/groups/git.py +29 -0
- maque/cli/groups/help.py +410 -0
- maque/cli/groups/llm.py +223 -0
- maque/cli/groups/mcp.py +241 -0
- maque/cli/groups/mllm.py +1795 -0
- maque/cli/groups/mllm_simple.py +60 -0
- maque/cli/groups/quant.py +210 -0
- maque/cli/groups/service.py +490 -0
- maque/cli/groups/system.py +570 -0
- maque/cli/mllm_run.py +1451 -0
- maque/cli/script.py +52 -0
- maque/cli/tree.py +49 -0
- maque/clustering/__init__.py +52 -0
- maque/clustering/analyzer.py +347 -0
- maque/clustering/clusterers.py +464 -0
- maque/clustering/sampler.py +134 -0
- maque/clustering/visualizer.py +205 -0
- maque/constant.py +13 -0
- maque/core.py +133 -0
- maque/cv/__init__.py +1 -0
- maque/cv/image.py +219 -0
- maque/cv/utils.py +68 -0
- maque/cv/video/__init__.py +3 -0
- maque/cv/video/keyframe_extractor.py +368 -0
- maque/embedding/__init__.py +43 -0
- maque/embedding/base.py +56 -0
- maque/embedding/multimodal.py +308 -0
- maque/embedding/server.py +523 -0
- maque/embedding/text.py +311 -0
- maque/git/__init__.py +24 -0
- maque/git/pure_git.py +912 -0
- maque/io/__init__.py +29 -0
- maque/io/core.py +38 -0
- maque/io/ops.py +194 -0
- maque/llm/__init__.py +111 -0
- maque/llm/backend.py +416 -0
- maque/llm/base.py +411 -0
- maque/llm/server.py +366 -0
- maque/mcp_server.py +1096 -0
- maque/mllm_data_processor_pipeline/__init__.py +17 -0
- maque/mllm_data_processor_pipeline/core.py +341 -0
- maque/mllm_data_processor_pipeline/example.py +291 -0
- maque/mllm_data_processor_pipeline/steps/__init__.py +56 -0
- maque/mllm_data_processor_pipeline/steps/data_alignment.py +267 -0
- maque/mllm_data_processor_pipeline/steps/data_loader.py +172 -0
- maque/mllm_data_processor_pipeline/steps/data_validation.py +304 -0
- maque/mllm_data_processor_pipeline/steps/format_conversion.py +411 -0
- maque/mllm_data_processor_pipeline/steps/mllm_annotation.py +331 -0
- maque/mllm_data_processor_pipeline/steps/mllm_refinement.py +446 -0
- maque/mllm_data_processor_pipeline/steps/result_validation.py +501 -0
- maque/mllm_data_processor_pipeline/web_app.py +317 -0
- maque/nlp/__init__.py +14 -0
- maque/nlp/ngram.py +9 -0
- maque/nlp/parser.py +63 -0
- maque/nlp/risk_matcher.py +543 -0
- maque/nlp/sentence_splitter.py +202 -0
- maque/nlp/simple_tradition_cvt.py +31 -0
- maque/performance/__init__.py +21 -0
- maque/performance/_measure_time.py +70 -0
- maque/performance/_profiler.py +367 -0
- maque/performance/_stat_memory.py +51 -0
- maque/pipelines/__init__.py +15 -0
- maque/pipelines/clustering.py +252 -0
- maque/quantization/__init__.py +42 -0
- maque/quantization/auto_round.py +120 -0
- maque/quantization/base.py +145 -0
- maque/quantization/bitsandbytes.py +127 -0
- maque/quantization/llm_compressor.py +102 -0
- maque/retriever/__init__.py +35 -0
- maque/retriever/chroma.py +654 -0
- maque/retriever/document.py +140 -0
- maque/retriever/milvus.py +1140 -0
- maque/table_ops/__init__.py +1 -0
- maque/table_ops/core.py +133 -0
- maque/table_viewer/__init__.py +4 -0
- maque/table_viewer/download_assets.py +57 -0
- maque/table_viewer/server.py +698 -0
- maque/table_viewer/static/element-plus-icons.js +5791 -0
- maque/table_viewer/static/element-plus.css +1 -0
- maque/table_viewer/static/element-plus.js +65236 -0
- maque/table_viewer/static/main.css +268 -0
- maque/table_viewer/static/main.js +669 -0
- maque/table_viewer/static/vue.global.js +18227 -0
- maque/table_viewer/templates/index.html +401 -0
- maque/utils/__init__.py +56 -0
- maque/utils/color.py +68 -0
- maque/utils/color_string.py +45 -0
- maque/utils/compress.py +66 -0
- maque/utils/constant.py +183 -0
- maque/utils/core.py +261 -0
- maque/utils/cursor.py +143 -0
- maque/utils/distance.py +58 -0
- maque/utils/docker.py +96 -0
- maque/utils/downloads.py +51 -0
- maque/utils/excel_helper.py +542 -0
- maque/utils/helper_metrics.py +121 -0
- maque/utils/helper_parser.py +168 -0
- maque/utils/net.py +64 -0
- maque/utils/nvidia_stat.py +140 -0
- maque/utils/ops.py +53 -0
- maque/utils/packages.py +31 -0
- maque/utils/path.py +57 -0
- maque/utils/tar.py +260 -0
- maque/utils/untar.py +129 -0
- maque/web/__init__.py +0 -0
- maque/web/image_downloader.py +1410 -0
- maque-0.2.1.dist-info/METADATA +450 -0
- maque-0.2.1.dist-info/RECORD +143 -0
- maque-0.2.1.dist-info/WHEEL +4 -0
- maque-0.2.1.dist-info/entry_points.txt +3 -0
- maque-0.2.1.dist-info/licenses/LICENSE +21 -0
maque/llm/base.py
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
1
|
+
#! /usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
|
|
4
|
+
"""
|
|
5
|
+
LLM Backend 抽象基类
|
|
6
|
+
|
|
7
|
+
提供可扩展的 LLM 后端接口,用户可以继承并实现自己的后端。
|
|
8
|
+
|
|
9
|
+
使用示例:
|
|
10
|
+
```python
|
|
11
|
+
from maque.llm import BaseLLMBackend, ModelConfig
|
|
12
|
+
|
|
13
|
+
class MyCustomBackend(BaseLLMBackend):
|
|
14
|
+
def _load_model_impl(self, model_path: str, config: ModelConfig) -> None:
|
|
15
|
+
# 自定义模型加载逻辑
|
|
16
|
+
pass
|
|
17
|
+
|
|
18
|
+
def _generate_impl(self, messages, config) -> tuple[str, int, int]:
|
|
19
|
+
# 自定义生成逻辑
|
|
20
|
+
pass
|
|
21
|
+
|
|
22
|
+
def _generate_stream_impl(self, messages, config):
|
|
23
|
+
# 自定义流式生成
|
|
24
|
+
yield "token"
|
|
25
|
+
```
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
import asyncio
|
|
29
|
+
from abc import ABC, abstractmethod
|
|
30
|
+
from dataclasses import dataclass, field
|
|
31
|
+
from typing import AsyncGenerator, List, Literal, Optional, Union
|
|
32
|
+
|
|
33
|
+
from pydantic import BaseModel, Field
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# ============== 数据模型 ==============
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@dataclass
class ModelConfig:
    """Settings that describe which model to load and how to load it.

    Attributes:
        model_id: Model name or path.
        device: Device type; ``None`` means auto-detect.
        torch_dtype: Data type; ``None`` means choose automatically.
        trust_remote_code: Whether remote code shipped with the model is
            trusted. Defaults to ``True``.
        local_dir: Local directory that may already contain the model.
        attn_implementation: Attention implementation
            (``eager`` / ``sdpa`` / ``flash_attention_2``).
        model_class: Model class name, e.g. ``"AutoModelForCausalLM"`` or
            ``"Qwen3VLForConditionalGeneration"``; ``None`` means auto-detect.
        processor_class: Processor class name, e.g. ``"AutoTokenizer"`` or
            ``"AutoProcessor"``; ``None`` means choose automatically.
        chat_template_kwargs: Extra arguments for ``apply_chat_template``,
            e.g. ``{"enable_thinking": True}``.
        vision_processor: Vision processor type
            (``"qwen_vl"``, ``"general"`` or ``None``).
        extra: Any additional backend-specific settings.
    """

    # --- core loading options ---
    model_id: str
    device: Optional[str] = None  # None -> auto
    torch_dtype: Optional[str] = None  # None -> auto
    trust_remote_code: bool = True
    local_dir: Optional[str] = None
    attn_implementation: Optional[str] = None  # eager / sdpa / flash_attention_2

    # --- model / processor class selection ---
    model_class: Optional[str] = None  # None -> auto-detect
    processor_class: Optional[str] = None  # None -> auto-select

    # --- chat template extras ---
    chat_template_kwargs: dict = field(default_factory=dict)

    # --- vision processor type: "qwen_vl", "general" or None ---
    vision_processor: Optional[str] = None

    # --- free-form extension settings ---
    extra: dict = field(default_factory=dict)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class ImageURL(BaseModel):
    """图片 URL"""
    # Image reference model (docstring above reads "image URL").
    # NOTE(review): the docstring and the Field description are left
    # untranslated because pydantic may surface them in generated schemas -
    # confirm before changing the text.

    # Required. Per its description, holds either an image URL or base64 data.
    url: str = Field(..., description="图片 URL 或 base64")
    # Optional, defaults to None.
    # NOTE(review): presumably an OpenAI-style detail hint - confirm with callers.
    detail: Optional[str] = None
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class ContentPart(BaseModel):
    """消息内容部分"""
    # One part of a message's content (docstring above reads "message content part").
    # NOTE(review): docstring kept untranslated because pydantic may surface it
    # in generated schemas - confirm before changing the text.

    # Kind of part: either "text" or "image_url".
    type: Literal["text", "image_url"]
    # Optional text payload, defaults to None.
    # Presumably set when type == "text"; this model does not enforce that.
    text: Optional[str] = None
    # Optional image payload, defaults to None.
    # Presumably set when type == "image_url"; this model does not enforce that.
    image_url: Optional[ImageURL] = None
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class ChatMessage(BaseModel):
    """聊天消息"""
    # A single chat message (docstring above reads "chat message").
    # NOTE(review): the docstring and the Field description are left
    # untranslated because pydantic may surface them in generated schemas -
    # confirm before changing the text.

    # Speaker role; only these three values are accepted.
    role: Literal["system", "user", "assistant"]
    # Required. Either a plain string or a list of ContentPart items
    # (description reads "message content").
    content: Union[str, List[ContentPart]] = Field(..., description="消息内容")
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@dataclass
class GenerateConfig:
    """Generation settings handed to a backend's generate implementations.

    How each value is applied is up to the concrete backend; this class only
    carries the values and their defaults.

    Attributes:
        max_tokens: Token budget for the completion (default ``512``).
        temperature: Sampling temperature (default ``0.7``).
        top_p: Nucleus-sampling threshold (default ``0.9``).
        stop: Optional list of stop strings; ``None`` when unset.
        extra: Free-form extension settings; a fresh dict per instance.
    """

    max_tokens: int = 512
    temperature: float = 0.7
    top_p: float = 0.9
    stop: Optional[List[str]] = None

    # Free-form extension settings
    extra: dict = field(default_factory=dict)
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
# ============== 抽象基类 ==============
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
class BaseLLMBackend(ABC):
    """Abstract base class for pluggable LLM backends.

    Subclasses must implement:
        - ``_load_model_impl(model_path, config)``: load the model.
        - ``_generate_impl(messages, config)``: blocking generation.

    Subclasses may override:
        - ``_generate_stream_impl``: streaming generation (the default yields
          the complete ``_generate_impl`` text as a single chunk).
        - ``_detect_device``: pick the device when none is configured.
        - ``_resolve_model_path``: map a ``ModelConfig`` to a model path.
        - ``_detect_multimodal``: decide whether the model is multimodal.
        - ``_process_messages``: pre-process messages before generation.
        - ``_process_image``: turn an image reference into a PIL image.
    """

    def __init__(self):
        self._model = None
        self._config: Optional[ModelConfig] = None
        self._device: Optional[str] = None
        self._is_multimodal: bool = False
        self._vision_processor: Optional[str] = None  # "qwen_vl", "general", None
        # Serializes concurrent ``load_model`` calls (async path only).
        self._lock = asyncio.Lock()

    # ============== Properties ==============

    @property
    def device(self) -> str:
        """Device in use; auto-detected and cached on first access."""
        if self._device is None:
            self._device = self._detect_device()
        return self._device

    @property
    def is_multimodal(self) -> bool:
        """Whether the loaded model was detected as multimodal."""
        return self._is_multimodal

    @property
    def model_id(self) -> Optional[str]:
        """ID of the loaded model, or ``None`` when nothing has been loaded."""
        return self._config.model_id if self._config else None

    @property
    def is_loaded(self) -> bool:
        """Whether a model object is currently held."""
        return self._model is not None

    # ============== Public interface ==============

    async def load_model(self, config: ModelConfig) -> None:
        """Load the model without blocking the event loop.

        The blocking load runs in the default executor. The call is a no-op
        when a model is already loaded.

        Args:
            config: Model configuration.
        """
        async with self._lock:
            if self._model is not None:
                return

            from loguru import logger

            # Report the device that will actually be used: an explicit
            # config.device takes precedence over auto-detection.
            logger.info(
                f"Loading model: {config.model_id} on {config.device or self.device}"
            )

            loop = asyncio.get_running_loop()
            await loop.run_in_executor(None, self._load_model_sync, config)

            self._config = config
            logger.info(f"Model loaded: {config.model_id} (multimodal={self._is_multimodal})")

    def load_model_sync(self, config: ModelConfig) -> None:
        """Load the model synchronously.

        The call is a no-op when a model is already loaded. Unlike
        ``load_model`` this path does not take the async lock.

        Args:
            config: Model configuration.
        """
        if self._model is not None:
            return

        from loguru import logger

        # Report the device that will actually be used (see load_model).
        logger.info(
            f"Loading model: {config.model_id} on {config.device or self.device}"
        )

        self._load_model_sync(config)
        self._config = config
        logger.info(f"Model loaded: {config.model_id} (multimodal={self._is_multimodal})")

    async def generate(
        self,
        messages: List[ChatMessage],
        config: Optional[GenerateConfig] = None,
    ) -> tuple[str, int, int]:
        """Generate a response without blocking the event loop.

        Args:
            messages: Conversation messages.
            config: Generation settings; defaults to ``GenerateConfig()``.

        Returns:
            tuple: ``(generated_text, prompt_tokens, completion_tokens)``.
        """
        config = config or GenerateConfig()
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, self._generate_sync, messages, config
        )

    def generate_sync(
        self,
        messages: List[ChatMessage],
        config: Optional[GenerateConfig] = None,
    ) -> tuple[str, int, int]:
        """Generate a response synchronously.

        Args:
            messages: Conversation messages.
            config: Generation settings; defaults to ``GenerateConfig()``.

        Returns:
            tuple: ``(generated_text, prompt_tokens, completion_tokens)``.
        """
        config = config or GenerateConfig()
        return self._generate_sync(messages, config)

    async def generate_stream(
        self,
        messages: List[ChatMessage],
        config: Optional[GenerateConfig] = None,
    ) -> AsyncGenerator[str, None]:
        """Stream generated chunks as an async generator.

        The synchronous stream runs in a worker thread; each chunk is handed
        to the event loop through an ``asyncio.Queue`` so the consumer is
        woken as soon as a chunk is available, with no polling.

        Args:
            messages: Conversation messages.
            config: Generation settings; defaults to ``GenerateConfig()``.

        Yields:
            str: Chunks produced by ``_generate_stream_impl``.

        Raises:
            Exception: Any exception raised by the synchronous stream is
                re-raised in the consumer after the chunks that preceded it.
        """
        import threading

        config = config or GenerateConfig()

        loop = asyncio.get_running_loop()
        chunks: asyncio.Queue = asyncio.Queue()
        stop_event = threading.Event()
        # Unique end marker, so no value yielded by the backend can be
        # mistaken for end-of-stream.
        finished = object()

        def publish(item) -> None:
            # asyncio.Queue is not thread-safe; schedule the put on the loop.
            try:
                loop.call_soon_threadsafe(chunks.put_nowait, item)
            except RuntimeError:
                # The loop is already closed, so nobody is listening anymore.
                stop_event.set()

        def producer() -> None:
            try:
                for token in self._generate_stream_sync(messages, config):
                    if stop_event.is_set():
                        break
                    publish(token)
            except Exception as exc:
                publish(exc)
            finally:
                publish(finished)

        # Daemon thread: an abandoned stream must not keep the process alive.
        worker = threading.Thread(target=producer, daemon=True)
        worker.start()

        try:
            while True:
                item = await chunks.get()
                if item is finished:
                    return
                if isinstance(item, Exception):
                    raise item
                yield item
        finally:
            # Ask the producer to stop at its next chunk. Deliberately no
            # blocking join here: this code runs on the event loop.
            stop_event.set()

    # ============== Internal synchronous wrappers ==============

    def _load_model_sync(self, config: ModelConfig) -> None:
        """Resolve settings from ``config`` and delegate to ``_load_model_impl``."""
        model_path = self._resolve_model_path(config)

        self._is_multimodal = self._detect_multimodal(model_path, config)

        # Vision processor: explicit setting wins; multimodal models fall
        # back to "qwen_vl" for backward compatibility.
        if config.vision_processor:
            self._vision_processor = config.vision_processor
        elif self._is_multimodal:
            self._vision_processor = "qwen_vl"

        # An explicit device overrides auto-detection.
        if config.device:
            self._device = config.device

        self._load_model_impl(model_path, config)

    def _generate_sync(
        self, messages: List[ChatMessage], config: GenerateConfig
    ) -> tuple[str, int, int]:
        """Pre-process ``messages`` and delegate to ``_generate_impl``."""
        processed = self._process_messages(messages)
        return self._generate_impl(processed, config)

    def _generate_stream_sync(
        self, messages: List[ChatMessage], config: GenerateConfig
    ):
        """Pre-process ``messages`` and yield from ``_generate_stream_impl``."""
        processed = self._process_messages(messages)
        yield from self._generate_stream_impl(processed, config)

    # ============== Overridable helpers ==============

    def _detect_device(self) -> str:
        """Detect the device type; subclasses may override.

        Detection order: CUDA > MPS (Apple Silicon) > CPU. Returns ``"cpu"``
        when torch is not installed.
        """
        try:
            import torch
        except ImportError:
            return "cpu"

        if torch.cuda.is_available():
            return "cuda"

        # torch.backends.mps does not exist on older torch releases.
        mps = getattr(torch.backends, "mps", None)
        if mps is not None and mps.is_available():
            return "mps"
        return "cpu"

    def _resolve_model_path(self, config: ModelConfig) -> str:
        """Resolve the model path; subclasses may override.

        When ``config.local_dir`` is set and contains a sub-directory named
        after the last ``/``-separated component of ``config.model_id``, that
        local path is returned; otherwise ``config.model_id`` is returned.
        """
        from pathlib import Path

        model_path = config.model_id
        if config.local_dir:
            local_path = Path(config.local_dir) / config.model_id.split("/")[-1]
            if local_path.exists():
                model_path = str(local_path)
                from loguru import logger

                logger.info(f"Using local model path: {model_path}")
        return model_path

    def _detect_multimodal(self, model_path: str, config: ModelConfig) -> bool:
        """Detect whether the model is multimodal; subclasses may override.

        Returns ``True`` when any entry of the model config's
        ``architectures`` contains ``"VL"`` or, case-insensitively,
        ``"vision"``. Any failure while reading the config yields ``False``.
        """
        try:
            from transformers import AutoConfig

            model_config = AutoConfig.from_pretrained(
                model_path, trust_remote_code=config.trust_remote_code
            )
            architectures = getattr(model_config, "architectures", []) or []
            return any(
                "VL" in arch or "vision" in arch.lower()
                for arch in architectures
            )
        except Exception:
            # Best-effort probe: a missing transformers install or an
            # unreadable config means "not multimodal".
            return False

    def _process_messages(self, messages: List[ChatMessage]) -> List[ChatMessage]:
        """Pre-process messages; the default returns them unchanged."""
        return messages

    def _process_image(self, image_url: str) -> "Image.Image":
        """Open an image reference; subclasses may override.

        Args:
            image_url: A ``data:`` URL with base64 payload, an http/https
                URL, or a local file path.

        Returns:
            The opened ``PIL.Image`` object.

        Raises:
            requests.HTTPError: If the remote server answers with an error
                status.
        """
        import base64
        from io import BytesIO
        from PIL import Image

        if image_url.startswith("data:"):
            # Base64 data URL: everything after the first comma is the payload.
            _, payload = image_url.split(",", 1)
            return Image.open(BytesIO(base64.b64decode(payload)))

        if image_url.startswith(("http://", "https://")):
            import requests

            response = requests.get(image_url, timeout=30)
            # Fail on 4xx/5xx instead of handing an error page to PIL.
            response.raise_for_status()
            return Image.open(BytesIO(response.content))

        # Anything else is treated as a local file path.
        return Image.open(image_url)

    # ============== Abstract methods (subclasses must implement) ==============

    @abstractmethod
    def _load_model_impl(self, model_path: str, config: ModelConfig) -> None:
        """Load the model.

        Args:
            model_path: Resolved model path.
            config: Model configuration.

        Implementations must:
            1. Store the loaded model in ``self._model``.
            2. Store the tokenizer/processor in their own attributes.
        """
        pass

    @abstractmethod
    def _generate_impl(
        self, messages: List[ChatMessage], config: GenerateConfig
    ) -> tuple[str, int, int]:
        """Run generation.

        Args:
            messages: Pre-processed message list.
            config: Generation settings.

        Returns:
            tuple: ``(generated_text, prompt_tokens, completion_tokens)``.
        """
        pass

    def _generate_stream_impl(
        self, messages: List[ChatMessage], config: GenerateConfig
    ):
        """Run streaming generation.

        The default calls ``_generate_impl`` and yields the whole text once.
        Subclasses may override this to stream real incremental output.

        Yields:
            str: Generated chunk.
        """
        text, _, _ = self._generate_impl(messages, config)
        yield text
|