maque 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- maque/__init__.py +30 -0
- maque/__main__.py +926 -0
- maque/ai_platform/__init__.py +0 -0
- maque/ai_platform/crawl.py +45 -0
- maque/ai_platform/metrics.py +258 -0
- maque/ai_platform/nlp_preprocess.py +67 -0
- maque/ai_platform/webpage_screen_shot.py +195 -0
- maque/algorithms/__init__.py +78 -0
- maque/algorithms/bezier.py +15 -0
- maque/algorithms/bktree.py +117 -0
- maque/algorithms/core.py +104 -0
- maque/algorithms/hilbert.py +16 -0
- maque/algorithms/rate_function.py +92 -0
- maque/algorithms/transform.py +27 -0
- maque/algorithms/trie.py +272 -0
- maque/algorithms/utils.py +63 -0
- maque/algorithms/video.py +587 -0
- maque/api/__init__.py +1 -0
- maque/api/common.py +110 -0
- maque/api/fetch.py +26 -0
- maque/api/static/icon.png +0 -0
- maque/api/static/redoc.standalone.js +1782 -0
- maque/api/static/swagger-ui-bundle.js +3 -0
- maque/api/static/swagger-ui.css +3 -0
- maque/cli/__init__.py +1 -0
- maque/cli/clean_invisible_chars.py +324 -0
- maque/cli/core.py +34 -0
- maque/cli/groups/__init__.py +26 -0
- maque/cli/groups/config.py +205 -0
- maque/cli/groups/data.py +615 -0
- maque/cli/groups/doctor.py +259 -0
- maque/cli/groups/embedding.py +222 -0
- maque/cli/groups/git.py +29 -0
- maque/cli/groups/help.py +410 -0
- maque/cli/groups/llm.py +223 -0
- maque/cli/groups/mcp.py +241 -0
- maque/cli/groups/mllm.py +1795 -0
- maque/cli/groups/mllm_simple.py +60 -0
- maque/cli/groups/quant.py +210 -0
- maque/cli/groups/service.py +490 -0
- maque/cli/groups/system.py +570 -0
- maque/cli/mllm_run.py +1451 -0
- maque/cli/script.py +52 -0
- maque/cli/tree.py +49 -0
- maque/clustering/__init__.py +52 -0
- maque/clustering/analyzer.py +347 -0
- maque/clustering/clusterers.py +464 -0
- maque/clustering/sampler.py +134 -0
- maque/clustering/visualizer.py +205 -0
- maque/constant.py +13 -0
- maque/core.py +133 -0
- maque/cv/__init__.py +1 -0
- maque/cv/image.py +219 -0
- maque/cv/utils.py +68 -0
- maque/cv/video/__init__.py +3 -0
- maque/cv/video/keyframe_extractor.py +368 -0
- maque/embedding/__init__.py +43 -0
- maque/embedding/base.py +56 -0
- maque/embedding/multimodal.py +308 -0
- maque/embedding/server.py +523 -0
- maque/embedding/text.py +311 -0
- maque/git/__init__.py +24 -0
- maque/git/pure_git.py +912 -0
- maque/io/__init__.py +29 -0
- maque/io/core.py +38 -0
- maque/io/ops.py +194 -0
- maque/llm/__init__.py +111 -0
- maque/llm/backend.py +416 -0
- maque/llm/base.py +411 -0
- maque/llm/server.py +366 -0
- maque/mcp_server.py +1096 -0
- maque/mllm_data_processor_pipeline/__init__.py +17 -0
- maque/mllm_data_processor_pipeline/core.py +341 -0
- maque/mllm_data_processor_pipeline/example.py +291 -0
- maque/mllm_data_processor_pipeline/steps/__init__.py +56 -0
- maque/mllm_data_processor_pipeline/steps/data_alignment.py +267 -0
- maque/mllm_data_processor_pipeline/steps/data_loader.py +172 -0
- maque/mllm_data_processor_pipeline/steps/data_validation.py +304 -0
- maque/mllm_data_processor_pipeline/steps/format_conversion.py +411 -0
- maque/mllm_data_processor_pipeline/steps/mllm_annotation.py +331 -0
- maque/mllm_data_processor_pipeline/steps/mllm_refinement.py +446 -0
- maque/mllm_data_processor_pipeline/steps/result_validation.py +501 -0
- maque/mllm_data_processor_pipeline/web_app.py +317 -0
- maque/nlp/__init__.py +14 -0
- maque/nlp/ngram.py +9 -0
- maque/nlp/parser.py +63 -0
- maque/nlp/risk_matcher.py +543 -0
- maque/nlp/sentence_splitter.py +202 -0
- maque/nlp/simple_tradition_cvt.py +31 -0
- maque/performance/__init__.py +21 -0
- maque/performance/_measure_time.py +70 -0
- maque/performance/_profiler.py +367 -0
- maque/performance/_stat_memory.py +51 -0
- maque/pipelines/__init__.py +15 -0
- maque/pipelines/clustering.py +252 -0
- maque/quantization/__init__.py +42 -0
- maque/quantization/auto_round.py +120 -0
- maque/quantization/base.py +145 -0
- maque/quantization/bitsandbytes.py +127 -0
- maque/quantization/llm_compressor.py +102 -0
- maque/retriever/__init__.py +35 -0
- maque/retriever/chroma.py +654 -0
- maque/retriever/document.py +140 -0
- maque/retriever/milvus.py +1140 -0
- maque/table_ops/__init__.py +1 -0
- maque/table_ops/core.py +133 -0
- maque/table_viewer/__init__.py +4 -0
- maque/table_viewer/download_assets.py +57 -0
- maque/table_viewer/server.py +698 -0
- maque/table_viewer/static/element-plus-icons.js +5791 -0
- maque/table_viewer/static/element-plus.css +1 -0
- maque/table_viewer/static/element-plus.js +65236 -0
- maque/table_viewer/static/main.css +268 -0
- maque/table_viewer/static/main.js +669 -0
- maque/table_viewer/static/vue.global.js +18227 -0
- maque/table_viewer/templates/index.html +401 -0
- maque/utils/__init__.py +56 -0
- maque/utils/color.py +68 -0
- maque/utils/color_string.py +45 -0
- maque/utils/compress.py +66 -0
- maque/utils/constant.py +183 -0
- maque/utils/core.py +261 -0
- maque/utils/cursor.py +143 -0
- maque/utils/distance.py +58 -0
- maque/utils/docker.py +96 -0
- maque/utils/downloads.py +51 -0
- maque/utils/excel_helper.py +542 -0
- maque/utils/helper_metrics.py +121 -0
- maque/utils/helper_parser.py +168 -0
- maque/utils/net.py +64 -0
- maque/utils/nvidia_stat.py +140 -0
- maque/utils/ops.py +53 -0
- maque/utils/packages.py +31 -0
- maque/utils/path.py +57 -0
- maque/utils/tar.py +260 -0
- maque/utils/untar.py +129 -0
- maque/web/__init__.py +0 -0
- maque/web/image_downloader.py +1410 -0
- maque-0.2.1.dist-info/METADATA +450 -0
- maque-0.2.1.dist-info/RECORD +143 -0
- maque-0.2.1.dist-info/WHEEL +4 -0
- maque-0.2.1.dist-info/entry_points.txt +3 -0
- maque-0.2.1.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,523 @@
|
|
|
1
|
+
#! /usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
|
|
4
|
+
"""
|
|
5
|
+
Embedding Server - 兼容 OpenAI/vLLM 的 Embedding API 服务
|
|
6
|
+
|
|
7
|
+
支持:
|
|
8
|
+
- jina-embeddings-v3 的 task 类型
|
|
9
|
+
- 多模型动态加载
|
|
10
|
+
- 批处理优化
|
|
11
|
+
- GPU/CPU 自动检测
|
|
12
|
+
"""
|
|
13
|
+
|
|
14
|
+
import asyncio
|
|
15
|
+
import base64
|
|
16
|
+
import struct
|
|
17
|
+
import time
|
|
18
|
+
from contextlib import asynccontextmanager
|
|
19
|
+
from dataclasses import dataclass, field
|
|
20
|
+
from enum import Enum
|
|
21
|
+
from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union
|
|
22
|
+
|
|
23
|
+
import numpy as np
|
|
24
|
+
from fastapi import FastAPI, HTTPException
|
|
25
|
+
from fastapi.middleware.cors import CORSMiddleware
|
|
26
|
+
from pydantic import BaseModel, Field
|
|
27
|
+
from loguru import logger
|
|
28
|
+
|
|
29
|
+
if TYPE_CHECKING:
|
|
30
|
+
from pathlib import Path
|
|
31
|
+
from sentence_transformers import SentenceTransformer
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# ============== Pydantic Models (OpenAI Compatible) ==============
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class TaskType(str, Enum):
    """Task types supported by jina-embeddings-v3.

    Subclasses ``str`` so members compare equal to, and serialize as, their
    string values. The server forwards ``.value`` to the embedding backend as
    the encode task.
    """

    TEXT_MATCHING = "text-matching"
    RETRIEVAL_QUERY = "retrieval.query"
    RETRIEVAL_PASSAGE = "retrieval.passage"
    CLASSIFICATION = "classification"
    SEPARATION = "separation"
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class EmbeddingRequest(BaseModel):
    """Embedding request: OpenAI-compatible fields plus extension fields.

    ``task`` is an extension for jina-embeddings-v3; ``dimensions`` requests
    Matryoshka-style truncation of the output vectors.
    """

    # Model id; the server loads it on demand if it is not loaded yet.
    model: str = Field(..., description="模型名称")
    # A single text or a batch of texts.
    input: Union[str, List[str]] = Field(..., description="输入文本")
    # "float" returns lists of floats; "base64" returns packed float32 bytes.
    encoding_format: Literal["float", "base64"] = Field(
        default="float", description="输出格式"
    )
    # Output dimensionality. Must be positive: 0 would otherwise be silently
    # ignored, and a negative value would be passed through as a
    # nonsensical truncation size.
    dimensions: Optional[int] = Field(
        default=None, gt=0, description="输出维度 (Matryoshka)"
    )
    # Extension fields
    task: Optional[TaskType] = Field(
        default=None, description="任务类型 (jina-v3)"
    )
    # Accepted for OpenAI compatibility; the server does not use it.
    user: Optional[str] = Field(default=None, description="用户标识")
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class EmbeddingObject(BaseModel):
    """A single embedding result inside an OpenAI-style response."""

    object: Literal["embedding"] = "embedding"
    # A list of floats, or a base64 string when encoding_format="base64".
    embedding: Union[List[float], str] = Field(..., description="向量或 base64")
    # Position of the corresponding text in the request's input list.
    index: int = Field(..., description="索引")
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class UsageInfo(BaseModel):
    """Token usage statistics attached to an embedding response.

    NOTE: the server fills these with a rough estimate (total characters // 4),
    not with a real tokenizer count.
    """

    prompt_tokens: int = 0
    total_tokens: int = 0
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
class EmbeddingResponse(BaseModel):
    """Embedding response in the OpenAI ``list`` envelope format."""

    object: Literal["list"] = "list"
    # One entry per input text.
    data: List[EmbeddingObject] = Field(default_factory=list)
    # Model id echoed back from the request.
    model: str = ""
    usage: UsageInfo = Field(default_factory=UsageInfo)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class ModelInfo(BaseModel):
    """Description of one model in the ``/v1/models`` listing."""

    # Model id as given when the model was loaded.
    id: str
    object: Literal["model"] = "model"
    # Unix timestamp; the /v1/models handler fills in the request time.
    created: int = 0
    owned_by: str = "local"
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
class ModelsResponse(BaseModel):
    """Response body of ``GET /v1/models`` (OpenAI ``list`` envelope)."""

    object: Literal["list"] = "list"
    data: List[ModelInfo] = Field(default_factory=list)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
# ============== Model Backend ==============
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
@dataclass
class ModelConfig:
    """Options controlling how the embedding backend loads a single model."""

    # Model identifier passed to SentenceTransformer
    # (e.g. "jinaai/jina-embeddings-v3"); its normalized form is the registry key.
    model_id: str
    # Forwarded to SentenceTransformer(trust_remote_code=...).
    trust_remote_code: bool = True
    device: Optional[str] = None  # None -> use the backend's auto-detected device
    # Task used by encode() when the request does not specify one.
    default_task: Optional[str] = None
    # Truncation size used by encode() when the request does not specify one.
    default_dimensions: Optional[int] = None
    # Directory checked for <local_dir>/<last segment of model_id> before
    # falling back to model_id itself.
    local_dir: Optional[str] = None
    torch_dtype: Optional[str] = None  # torch dtype name: float16/bfloat16/float32
    attn_implementation: Optional[str] = None  # eager/sdpa/flash_attention_2
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
class EmbeddingBackend:
    """Embedding model backend built on SentenceTransformers.

    Loaded models and their configs are stored in dicts keyed by a normalized
    model id (see ``_get_model_key``). Model loading and encoding run in the
    event loop's default thread-pool executor, so the blocking
    SentenceTransformers calls do not stall the async caller.
    """

    # Short precision aliases (as advertised by the CLI ``--dtype`` help)
    # mapped to the corresponding ``torch`` attribute names.
    _DTYPE_ALIASES = {
        "bf16": "bfloat16",
        "fp16": "float16",
        "f16": "float16",
        "fp32": "float32",
        "f32": "float32",
    }

    def __init__(self):
        self._models: Dict[str, "SentenceTransformer"] = {}
        self._configs: Dict[str, "ModelConfig"] = {}
        # Explicit device override; None/"auto" means "detect on first use".
        self._device: Optional[str] = None
        # Serializes load_model so concurrent requests for the same model do
        # not load it twice.
        self._lock = asyncio.Lock()

    @property
    def device(self) -> str:
        """Default device; detected once in the order CUDA > MPS > CPU."""
        if self._device in (None, "auto"):
            self._device = self._detect_device()
        return self._device

    @staticmethod
    def _detect_device() -> str:
        """Probe torch for an accelerator; return "cpu" if torch is missing."""
        try:
            import torch
        except ImportError:
            return "cpu"
        if torch.cuda.is_available():
            return "cuda"
        # torch.backends.mps does not exist on older torch builds.
        mps = getattr(torch.backends, "mps", None)
        if mps is not None and mps.is_available():
            return "mps"
        return "cpu"

    def _resolve_device(self, requested: Optional[str]) -> str:
        """Return ``requested`` unless it is None/"auto", else ``self.device``."""
        if requested and requested != "auto":
            return requested
        return self.device

    def _get_model_key(self, model_id: str) -> str:
        """Normalize a model id into a registry key (lower-case, '/' -> '_')."""
        return model_id.lower().replace("/", "_")

    @classmethod
    def _normalize_dtype_name(cls, dtype: str) -> str:
        """Map a precision alias such as "bf16" to its torch attribute name.

        Names that are not aliases are returned lower-cased and stripped, so
        full names like "bfloat16" pass straight through.
        """
        name = dtype.strip().lower()
        return cls._DTYPE_ALIASES.get(name, name)

    async def load_model(self, config: "ModelConfig") -> None:
        """Load ``config.model_id`` unless a model with the same key is loaded.

        The blocking load runs in the default executor while holding
        ``self._lock``, so loads are serialized.
        """
        key = self._get_model_key(config.model_id)

        async with self._lock:
            if key in self._models:
                return

            device = self._resolve_device(config.device)
            logger.info(f"Loading model: {config.model_id} on {device}")

            loop = asyncio.get_running_loop()
            model = await loop.run_in_executor(
                None, self._load_model_sync, config
            )

            self._models[key] = model
            self._configs[key] = config
            logger.info(f"Model loaded: {config.model_id}")

    def _load_model_sync(self, config: "ModelConfig") -> "SentenceTransformer":
        """Construct the SentenceTransformer for ``config`` (blocking)."""
        from pathlib import Path

        from sentence_transformers import SentenceTransformer

        device = self._resolve_device(config.device)
        model_path = config.model_id

        # Prefer <local_dir>/<last segment of model_id> when it exists;
        # otherwise fall back to the id itself.
        if config.local_dir:
            local_path = Path(config.local_dir) / config.model_id.split("/")[-1]
            if local_path.exists():
                model_path = str(local_path)
                logger.info(f"Using local model path: {model_path}")

        model_kwargs = {}
        if config.torch_dtype:
            import torch

            dtype_name = self._normalize_dtype_name(config.torch_dtype)
            torch_dtype = getattr(torch, dtype_name, None)
            # Guard against names that resolve to non-dtype torch attributes
            # (e.g. "cuda") as well as unknown names.
            if isinstance(torch_dtype, torch.dtype):
                model_kwargs["torch_dtype"] = torch_dtype
                logger.info(f"Using dtype: {config.torch_dtype}")
            else:
                logger.warning(
                    f"Unknown torch dtype {config.torch_dtype!r}; "
                    "using the model's default precision"
                )
        if config.attn_implementation:
            model_kwargs["attn_implementation"] = config.attn_implementation
            logger.info(f"Using attn: {config.attn_implementation}")

        return SentenceTransformer(
            model_path,
            trust_remote_code=config.trust_remote_code,
            device=device,
            model_kwargs=model_kwargs if model_kwargs else None,
        )

    def get_model(self, model_id: str) -> Optional["SentenceTransformer"]:
        """Return the loaded model for ``model_id``, or None."""
        return self._models.get(self._get_model_key(model_id))

    def get_config(self, model_id: str) -> Optional["ModelConfig"]:
        """Return the config the model was loaded with, or None."""
        return self._configs.get(self._get_model_key(model_id))

    def list_models(self) -> List[str]:
        """Return the ids of all loaded models."""
        return [cfg.model_id for cfg in self._configs.values()]

    async def encode(
        self,
        model_id: str,
        texts: List[str],
        task: Optional[str] = None,
        dimensions: Optional[int] = None,
    ) -> np.ndarray:
        """Encode ``texts`` with a loaded model into L2-normalized vectors.

        Args:
            model_id: Id of an already-loaded model.
            texts: Texts to encode.
            task: Task name; falls back to the model config's ``default_task``.
            dimensions: Truncation size; falls back to ``default_dimensions``.

        Returns:
            The array returned by ``SentenceTransformer.encode`` with
            ``convert_to_numpy=True``.

        Raises:
            ValueError: if ``model_id`` is not loaded.
        """
        model = self.get_model(model_id)
        if model is None:
            raise ValueError(f"Model not loaded: {model_id}")

        config = self.get_config(model_id)
        task = task or (config.default_task if config else None)
        dimensions = dimensions or (config.default_dimensions if config else None)

        encode_kwargs = {}
        if task:
            # Forwarded to the model (used by jina-embeddings-v3).
            encode_kwargs["task"] = task
            # SentenceTransformer.encode raises ValueError for a prompt_name
            # missing from model.prompts, so only select prompts that exist.
            if task in (getattr(model, "prompts", None) or {}):
                encode_kwargs["prompt_name"] = task
        if dimensions:
            encode_kwargs["truncate_dim"] = dimensions

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            lambda: model.encode(
                texts,
                convert_to_numpy=True,
                normalize_embeddings=True,
                **encode_kwargs,
            ),
        )

    def unload_model(self, model_id: str) -> bool:
        """Drop a loaded model; return True if it was loaded."""
        key = self._get_model_key(model_id)
        if key not in self._models:
            return False
        del self._models[key]
        self._configs.pop(key, None)
        return True
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
# ============== Server ==============
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
class EmbeddingServer:
    """OpenAI/vLLM-compatible embedding HTTP service.

    Wraps an ``EmbeddingBackend`` in a FastAPI app exposing ``GET /health``,
    ``GET /v1/models``, ``POST /v1/embeddings`` and the vLLM-style alias
    ``POST /embeddings``.
    """

    def __init__(
        self,
        models: Optional[List[str]] = None,
        default_model: Optional[str] = None,
        device: Optional[str] = None,
        local_dir: Optional[str] = None,
        dtype: Optional[str] = None,
        attn: Optional[str] = None,
    ):
        """Initialize the service.

        Args:
            models: Models to preload when the app starts.
            default_model: Model used when a request's ``model`` is empty.
                If unset, the first successfully preloaded model is used.
            device: Device (cuda/cpu/auto); None or "auto" lets the backend
                auto-detect.
            local_dir: Directory searched for ``<local_dir>/<model name>``
                before falling back to the model id.
            dtype: Data type (float16/bfloat16/float32).
            attn: Attention implementation (eager/sdpa/flash_attention_2).
        """
        self.backend = EmbeddingBackend()
        # "auto" is accepted here but is not a real torch device string, so
        # leave device detection to the backend in that case.
        if device and device != "auto":
            self.backend._device = device

        self._preload_models = models or []
        self._default_model = default_model
        self._local_dir = local_dir
        self._dtype = dtype
        self._attn = attn
        self.app = self._create_app()

    def _model_config(self, model_id: str) -> ModelConfig:
        """Build a ModelConfig that carries this server's loading options.

        Used for both startup preloading and on-demand loads, so both honour
        local_dir, dtype and attn.
        """
        return ModelConfig(
            model_id=model_id,
            local_dir=self._local_dir,
            torch_dtype=self._dtype,
            attn_implementation=self._attn,
        )

    def _create_app(self) -> FastAPI:
        """Create the FastAPI app with startup preloading, CORS and routes."""

        @asynccontextmanager
        async def lifespan(app: FastAPI):
            # Preload models at startup. Failures are logged and skipped so
            # the server still starts; the model can be loaded on demand.
            for model_id in self._preload_models:
                try:
                    await self.backend.load_model(self._model_config(model_id))
                except Exception as e:
                    logger.error(f"Failed to load {model_id}: {e}")

            if not self._default_model and self._preload_models:
                # Prefer a model that actually loaded.
                loaded = [
                    m for m in self._preload_models if self.backend.get_model(m)
                ]
                self._default_model = (loaded or self._preload_models)[0]

            yield
            # Nothing to release on shutdown.

        app = FastAPI(
            title="Embedding Server",
            description="OpenAI Compatible Embedding API with Task Support",
            version="1.0.0",
            lifespan=lifespan,
        )

        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

        self._register_routes(app)
        return app

    def _register_routes(self, app: FastAPI) -> None:
        """Register the HTTP routes on ``app``."""

        @app.get("/health")
        async def health():
            return {"status": "ok"}

        @app.get("/v1/models", response_model=ModelsResponse)
        async def list_models():
            models = self.backend.list_models()
            return ModelsResponse(
                data=[
                    ModelInfo(id=m, created=int(time.time()))
                    for m in models
                ]
            )

        @app.post("/v1/embeddings", response_model=EmbeddingResponse)
        async def create_embeddings(request: EmbeddingRequest):
            return await self._handle_embedding(request)

        # vLLM-compatible alias route
        @app.post("/embeddings", response_model=EmbeddingResponse)
        async def create_embeddings_alt(request: EmbeddingRequest):
            return await self._handle_embedding(request)

    async def _handle_embedding(
        self, request: EmbeddingRequest
    ) -> EmbeddingResponse:
        """Serve one embedding request, loading the model on demand.

        Raises:
            HTTPException: 400 if no model is available or loading fails,
                500 if encoding fails.
        """
        # Resolve the model: fall back to the default for an empty name.
        model_id = request.model or self._default_model
        if not model_id:
            raise HTTPException(
                status_code=400,
                detail="No model specified and no default model configured",
            )
        if not self.backend.get_model(model_id):
            # Try loading on demand with the server's loading options.
            try:
                await self.backend.load_model(self._model_config(model_id))
            except Exception as e:
                raise HTTPException(
                    status_code=400, detail=f"Failed to load model: {e}"
                ) from e

        texts = request.input if isinstance(request.input, list) else [request.input]
        task = request.task.value if request.task else None

        try:
            embeddings = await self.backend.encode(
                model_id=model_id,
                texts=texts,
                task=task,
                dimensions=request.dimensions,
            )

            data = []
            for i, emb in enumerate(embeddings):
                if request.encoding_format == "base64":
                    # Pack explicitly as little-endian float32, so the payload
                    # does not depend on the host's byte order.
                    emb_bytes = np.asarray(emb, dtype="<f4").tobytes()
                    emb_value = base64.b64encode(emb_bytes).decode("utf-8")
                else:
                    emb_value = emb.tolist()

                data.append(EmbeddingObject(embedding=emb_value, index=i))

            # Rough token estimate (~4 characters per token); no tokenizer used.
            total_chars = sum(len(t) for t in texts)
            estimated_tokens = total_chars // 4

            return EmbeddingResponse(
                data=data,
                model=model_id,
                usage=UsageInfo(
                    prompt_tokens=estimated_tokens,
                    total_tokens=estimated_tokens,
                ),
            )

        except Exception as e:
            logger.exception(f"Embedding error: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

    def run(
        self,
        host: str = "0.0.0.0",
        port: int = 8000,
        workers: int = 1,
        **kwargs,
    ) -> None:
        """Run the service with uvicorn (blocking).

        NOTE(review): the app object, not an import string, is passed to
        uvicorn. uvicorn generally refuses ``workers > 1`` (and ``reload``)
        in that case; confirm this against the installed uvicorn version.
        """
        import uvicorn

        uvicorn.run(
            self.app,
            host=host,
            port=port,
            workers=workers,
            **kwargs,
        )
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
def create_server(
    models: Optional[List[str]] = None,
    default_model: Optional[str] = None,
    device: Optional[str] = None,
    local_dir: Optional[str] = None,
    dtype: Optional[str] = None,
    attn: Optional[str] = None,
) -> EmbeddingServer:
    """Build an EmbeddingServer; a thin factory over its constructor.

    Every argument is forwarded unchanged, by keyword.

    Args:
        models: Models to preload at startup.
        default_model: Default model name.
        device: Device (cuda/cpu).
        local_dir: Local models directory.
        dtype: Data type (float16/bfloat16/float32).
        attn: Attention implementation (eager/sdpa/flash_attention_2).

    Returns:
        A new, not-yet-running EmbeddingServer.
    """
    options = {
        "models": models,
        "default_model": default_model,
        "device": device,
        "local_dir": local_dir,
        "dtype": dtype,
        "attn": attn,
    }
    server = EmbeddingServer(**options)
    return server
|
|
477
|
+
|
|
478
|
+
|
|
479
|
+
# CLI 入口
|
|
480
|
+
def main():
    """Command-line entry point: parse arguments and start the server.

    Blocks until the uvicorn server exits.
    """
    import argparse

    # Short precision aliases accepted by --dtype, mapped to torch dtype names.
    dtype_aliases = {
        "bf16": "bfloat16",
        "fp16": "float16",
        "f16": "float16",
        "fp32": "float32",
        "f32": "float32",
    }

    parser = argparse.ArgumentParser(description="Embedding Server")
    parser.add_argument(
        "--model",
        "-m",
        type=str,
        nargs="+",
        default=["jinaai/jina-embeddings-v3"],
        help="Models to load",
    )
    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host")
    parser.add_argument("--port", "-p", type=int, default=8000, help="Port")
    parser.add_argument(
        "--device", type=str, default=None, help="Device (cuda/cpu)"
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default=None,
        help="Model precision (bf16/fp16/f16/fp32/f32, default: fp32)",
    )
    parser.add_argument(
        "--attn",
        type=str,
        default=None,
        help="Attention implementation (eager/sdpa/flash_attention_2)",
    )
    parser.add_argument(
        "--local-dir",
        type=str,
        default=None,
        help="Local models directory (uses <dir>/<model name> when it exists)",
    )

    args = parser.parse_args()

    # Translate aliases such as "bf16" into names torch understands.
    dtype = args.dtype
    if dtype:
        dtype = dtype_aliases.get(dtype.strip().lower(), dtype)

    # create_server's precision parameter is named `dtype`; the previous
    # `torch_dtype=` keyword raised TypeError on every run.
    server = create_server(
        models=args.model,
        device=args.device,
        local_dir=args.local_dir,
        dtype=dtype,
        attn=args.attn,
    )
    server.run(host=args.host, port=args.port)
|
|
520
|
+
|
|
521
|
+
|
|
522
|
+
# Script entry point: run the command-line interface when this module is
# executed directly rather than imported.
if __name__ == "__main__":
    main()
|