mineru 2.2.1__py3-none-any.whl → 2.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mineru/backend/pipeline/batch_analyze.py +1 -1
- mineru/backend/pipeline/pipeline_middle_json_mkcontent.py +3 -3
- mineru/backend/vlm/model_output_to_middle_json.py +123 -0
- mineru/backend/vlm/vlm_analyze.py +97 -16
- mineru/backend/vlm/vlm_magic_model.py +201 -135
- mineru/backend/vlm/vlm_middle_json_mkcontent.py +52 -11
- mineru/cli/client.py +6 -5
- mineru/cli/common.py +17 -16
- mineru/cli/fast_api.py +9 -7
- mineru/cli/gradio_app.py +15 -16
- mineru/cli/vlm_vllm_server.py +4 -0
- mineru/model/table/rec/unet_table/main.py +10 -2
- mineru/model/vlm_vllm_model/__init__.py +0 -0
- mineru/model/vlm_vllm_model/server.py +51 -0
- mineru/resources/header.html +10 -2
- mineru/utils/draw_bbox.py +32 -10
- mineru/utils/enum_class.py +16 -2
- mineru/utils/guess_suffix_or_lang.py +20 -0
- mineru/utils/span_block_fix.py +4 -2
- mineru/version.py +1 -1
- {mineru-2.2.1.dist-info → mineru-2.5.0.dist-info}/METADATA +71 -23
- {mineru-2.2.1.dist-info → mineru-2.5.0.dist-info}/RECORD +26 -39
- {mineru-2.2.1.dist-info → mineru-2.5.0.dist-info}/entry_points.txt +1 -1
- mineru/backend/vlm/base_predictor.py +0 -186
- mineru/backend/vlm/hf_predictor.py +0 -217
- mineru/backend/vlm/predictor.py +0 -111
- mineru/backend/vlm/sglang_client_predictor.py +0 -443
- mineru/backend/vlm/sglang_engine_predictor.py +0 -246
- mineru/backend/vlm/token_to_middle_json.py +0 -122
- mineru/backend/vlm/utils.py +0 -40
- mineru/cli/vlm_sglang_server.py +0 -4
- mineru/model/vlm_hf_model/__init__.py +0 -9
- mineru/model/vlm_hf_model/configuration_mineru2.py +0 -38
- mineru/model/vlm_hf_model/image_processing_mineru2.py +0 -269
- mineru/model/vlm_hf_model/modeling_mineru2.py +0 -449
- mineru/model/vlm_sglang_model/__init__.py +0 -14
- mineru/model/vlm_sglang_model/engine.py +0 -264
- mineru/model/vlm_sglang_model/image_processor.py +0 -213
- mineru/model/vlm_sglang_model/logit_processor.py +0 -90
- mineru/model/vlm_sglang_model/model.py +0 -453
- mineru/model/vlm_sglang_model/server.py +0 -75
- {mineru-2.2.1.dist-info → mineru-2.5.0.dist-info}/WHEEL +0 -0
- {mineru-2.2.1.dist-info → mineru-2.5.0.dist-info}/licenses/LICENSE.md +0 -0
- {mineru-2.2.1.dist-info → mineru-2.5.0.dist-info}/top_level.txt +0 -0
mineru/backend/vlm/sglang_client_predictor.py (removed)

```diff
@@ -1,443 +0,0 @@
-import asyncio
-import json
-import re
-from base64 import b64encode
-from typing import AsyncIterable, Iterable, List, Optional, Set, Tuple, Union
-
-import httpx
-
-from .base_predictor import (
-    DEFAULT_MAX_NEW_TOKENS,
-    DEFAULT_NO_REPEAT_NGRAM_SIZE,
-    DEFAULT_PRESENCE_PENALTY,
-    DEFAULT_REPETITION_PENALTY,
-    DEFAULT_TEMPERATURE,
-    DEFAULT_TOP_K,
-    DEFAULT_TOP_P,
-    BasePredictor,
-)
-from .utils import aio_load_resource, load_resource
-
-
-class SglangClientPredictor(BasePredictor):
-    def __init__(
-        self,
-        server_url: str,
-        temperature: float = DEFAULT_TEMPERATURE,
-        top_p: float = DEFAULT_TOP_P,
-        top_k: int = DEFAULT_TOP_K,
-        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
-        presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
-        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
-        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
-        http_timeout: int = 600,
-    ) -> None:
-        super().__init__(
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            presence_penalty=presence_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            max_new_tokens=max_new_tokens,
-        )
-        self.http_timeout = http_timeout
-
-        base_url = self.get_base_url(server_url)
-        self.check_server_health(base_url)
-        self.model_path = self.get_model_path(base_url)
-        self.server_url = f"{base_url}/generate"
-
-    @staticmethod
-    def get_base_url(server_url: str) -> str:
-        matched = re.match(r"^(https?://[^/]+)", server_url)
-        if not matched:
-            raise ValueError(f"Invalid server URL: {server_url}")
-        return matched.group(1)
-
-    def check_server_health(self, base_url: str):
-        try:
-            response = httpx.get(f"{base_url}/health_generate", timeout=self.http_timeout)
-        except httpx.ConnectError:
-            raise RuntimeError(f"Failed to connect to server {base_url}. Please check if the server is running.")
-        if response.status_code != 200:
-            raise RuntimeError(
-                f"Server {base_url} is not healthy. Status code: {response.status_code}, response body: {response.text}"
-            )
-
-    def get_model_path(self, base_url: str) -> str:
-        try:
-            response = httpx.get(f"{base_url}/get_model_info", timeout=self.http_timeout)
-        except httpx.ConnectError:
-            raise RuntimeError(f"Failed to connect to server {base_url}. Please check if the server is running.")
-        if response.status_code != 200:
-            raise RuntimeError(
-                f"Failed to get model info from {base_url}. Status code: {response.status_code}, response body: {response.text}"
-            )
-        return response.json()["model_path"]
-
-    def build_sampling_params(
-        self,
-        temperature: Optional[float],
-        top_p: Optional[float],
-        top_k: Optional[int],
-        repetition_penalty: Optional[float],
-        presence_penalty: Optional[float],
-        no_repeat_ngram_size: Optional[int],
-        max_new_tokens: Optional[int],
-    ) -> dict:
-        if temperature is None:
-            temperature = self.temperature
-        if top_p is None:
-            top_p = self.top_p
-        if top_k is None:
-            top_k = self.top_k
-        if repetition_penalty is None:
-            repetition_penalty = self.repetition_penalty
-        if presence_penalty is None:
-            presence_penalty = self.presence_penalty
-        if no_repeat_ngram_size is None:
-            no_repeat_ngram_size = self.no_repeat_ngram_size
-        if max_new_tokens is None:
-            max_new_tokens = self.max_new_tokens
-
-        # see SamplingParams for more details
-        return {
-            "temperature": temperature,
-            "top_p": top_p,
-            "top_k": top_k,
-            "repetition_penalty": repetition_penalty,
-            "presence_penalty": presence_penalty,
-            "custom_params": {
-                "no_repeat_ngram_size": no_repeat_ngram_size,
-            },
-            "max_new_tokens": max_new_tokens,
-            "skip_special_tokens": False,
-        }
-
-    def build_request_body(
-        self,
-        image: bytes,
-        prompt: str,
-        sampling_params: dict,
-    ) -> dict:
-        image_base64 = b64encode(image).decode("utf-8")
-        return {
-            "text": prompt,
-            "image_data": image_base64,
-            "sampling_params": sampling_params,
-            "modalities": ["image"],
-        }
-
-    def predict(
-        self,
-        image: str | bytes,
-        prompt: str = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> str:
-        prompt = self.build_prompt(prompt)
-
-        sampling_params = self.build_sampling_params(
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            presence_penalty=presence_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            max_new_tokens=max_new_tokens,
-        )
-
-        if isinstance(image, str):
-            image = load_resource(image)
-
-        request_body = self.build_request_body(image, prompt, sampling_params)
-        response = httpx.post(self.server_url, json=request_body, timeout=self.http_timeout)
-        response_body = response.json()
-        return response_body["text"]
-
-    def batch_predict(
-        self,
-        images: List[str] | List[bytes],
-        prompts: Union[List[str], str] = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-        max_concurrency: int = 100,
-    ) -> List[str]:
-        try:
-            loop = asyncio.get_running_loop()
-        except RuntimeError:
-            loop = None
-
-        task = self.aio_batch_predict(
-            images=images,
-            prompts=prompts,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            presence_penalty=presence_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            max_new_tokens=max_new_tokens,
-            max_concurrency=max_concurrency,
-        )
-
-        if loop is not None:
-            return loop.run_until_complete(task)
-        else:
-            return asyncio.run(task)
-
-    def stream_predict(
-        self,
-        image: str | bytes,
-        prompt: str = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> Iterable[str]:
-        prompt = self.build_prompt(prompt)
-
-        sampling_params = self.build_sampling_params(
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            presence_penalty=presence_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            max_new_tokens=max_new_tokens,
-        )
-
-        if isinstance(image, str):
-            image = load_resource(image)
-
-        request_body = self.build_request_body(image, prompt, sampling_params)
-        request_body["stream"] = True
-
-        with httpx.stream(
-            "POST",
-            self.server_url,
-            json=request_body,
-            timeout=self.http_timeout,
-        ) as response:
-            pos = 0
-            for chunk in response.iter_lines():
-                if not (chunk or "").startswith("data:"):
-                    continue
-                if chunk == "data: [DONE]":
-                    break
-                data = json.loads(chunk[5:].strip("\n"))
-                chunk_text = data["text"][pos:]
-                # meta_info = data["meta_info"]
-                pos += len(chunk_text)
-                yield chunk_text
-
-    async def aio_predict(
-        self,
-        image: str | bytes,
-        prompt: str = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-        async_client: Optional[httpx.AsyncClient] = None,
-    ) -> str:
-        prompt = self.build_prompt(prompt)
-
-        sampling_params = self.build_sampling_params(
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            presence_penalty=presence_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            max_new_tokens=max_new_tokens,
-        )
-
-        if isinstance(image, str):
-            image = await aio_load_resource(image)
-
-        request_body = self.build_request_body(image, prompt, sampling_params)
-
-        if async_client is None:
-            async with httpx.AsyncClient(timeout=self.http_timeout) as client:
-                response = await client.post(self.server_url, json=request_body)
-                response_body = response.json()
-        else:
-            response = await async_client.post(self.server_url, json=request_body)
-            response_body = response.json()
-
-        return response_body["text"]
-
-    async def aio_batch_predict(
-        self,
-        images: List[str] | List[bytes],
-        prompts: Union[List[str], str] = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-        max_concurrency: int = 100,
-    ) -> List[str]:
-        if not isinstance(prompts, list):
-            prompts = [prompts] * len(images)
-
-        assert len(prompts) == len(images), "Length of prompts and images must match."
-
-        semaphore = asyncio.Semaphore(max_concurrency)
-        outputs = [""] * len(images)
-
-        async def predict_with_semaphore(
-            idx: int,
-            image: str | bytes,
-            prompt: str,
-            async_client: httpx.AsyncClient,
-        ):
-            async with semaphore:
-                output = await self.aio_predict(
-                    image=image,
-                    prompt=prompt,
-                    temperature=temperature,
-                    top_p=top_p,
-                    top_k=top_k,
-                    repetition_penalty=repetition_penalty,
-                    presence_penalty=presence_penalty,
-                    no_repeat_ngram_size=no_repeat_ngram_size,
-                    max_new_tokens=max_new_tokens,
-                    async_client=async_client,
-                )
-                outputs[idx] = output
-
-        async with httpx.AsyncClient(timeout=self.http_timeout) as client:
-            tasks = []
-            for idx, (prompt, image) in enumerate(zip(prompts, images)):
-                tasks.append(predict_with_semaphore(idx, image, prompt, client))
-            await asyncio.gather(*tasks)
-
-        return outputs
-
-    async def aio_batch_predict_as_iter(
-        self,
-        images: List[str] | List[bytes],
-        prompts: Union[List[str], str] = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-        max_concurrency: int = 100,
-    ) -> AsyncIterable[Tuple[int, str]]:
-        if not isinstance(prompts, list):
-            prompts = [prompts] * len(images)
-
-        assert len(prompts) == len(images), "Length of prompts and images must match."
-
-        semaphore = asyncio.Semaphore(max_concurrency)
-
-        async def predict_with_semaphore(
-            idx: int,
-            image: str | bytes,
-            prompt: str,
-            async_client: httpx.AsyncClient,
-        ):
-            async with semaphore:
-                output = await self.aio_predict(
-                    image=image,
-                    prompt=prompt,
-                    temperature=temperature,
-                    top_p=top_p,
-                    top_k=top_k,
-                    repetition_penalty=repetition_penalty,
-                    presence_penalty=presence_penalty,
-                    no_repeat_ngram_size=no_repeat_ngram_size,
-                    max_new_tokens=max_new_tokens,
-                    async_client=async_client,
-                )
-                return (idx, output)
-
-        async with httpx.AsyncClient(timeout=self.http_timeout) as client:
-            pending: Set[asyncio.Task[Tuple[int, str]]] = set()
-
-            for idx, (prompt, image) in enumerate(zip(prompts, images)):
-                pending.add(
-                    asyncio.create_task(
-                        predict_with_semaphore(idx, image, prompt, client),
-                    )
-                )
-
-            while len(pending) > 0:
-                done, pending = await asyncio.wait(
-                    pending,
-                    return_when=asyncio.FIRST_COMPLETED,
-                )
-                for task in done:
-                    yield task.result()
-
-    async def aio_stream_predict(
-        self,
-        image: str | bytes,
-        prompt: str = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> AsyncIterable[str]:
-        prompt = self.build_prompt(prompt)
-
-        sampling_params = self.build_sampling_params(
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            presence_penalty=presence_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            max_new_tokens=max_new_tokens,
-        )
-
-        if isinstance(image, str):
-            image = await aio_load_resource(image)
-
-        request_body = self.build_request_body(image, prompt, sampling_params)
-        request_body["stream"] = True
-
-        async with httpx.AsyncClient(timeout=self.http_timeout) as client:
-            async with client.stream(
-                "POST",
-                self.server_url,
-                json=request_body,
-            ) as response:
-                pos = 0
-                async for chunk in response.aiter_lines():
-                    if not (chunk or "").startswith("data:"):
-                        continue
-                    if chunk == "data: [DONE]":
-                        break
-                    data = json.loads(chunk[5:].strip("\n"))
-                    chunk_text = data["text"][pos:]
-                    # meta_info = data["meta_info"]
-                    pos += len(chunk_text)
-                    yield chunk_text
```
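For orientation, the removed `SglangClientPredictor` wrapped a running sglang server over HTTP: on construction it probes `/health_generate`, reads the model path from `/get_model_info`, and then posts pages to `/generate` (optionally as a server-sent-event stream). A minimal usage sketch of that pre-2.5.0 API follows; the server URL and file names are illustrative placeholders, not values taken from this diff.

```python
# Sketch only: exercises the removed pre-2.5.0 client API; URL and paths are hypothetical.
from mineru.backend.vlm.sglang_client_predictor import SglangClientPredictor

predictor = SglangClientPredictor(server_url="http://127.0.0.1:30000")

# Single page: `image` may be a string resolved via load_resource() or raw bytes.
text = predictor.predict(image="page_0.png")

# Many pages: fans out over aio_batch_predict with a concurrency cap.
texts = predictor.batch_predict(images=["page_0.png", "page_1.png"], max_concurrency=10)
```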
mineru/backend/vlm/sglang_engine_predictor.py (removed)

```diff
@@ -1,246 +0,0 @@
-from base64 import b64encode
-from typing import AsyncIterable, Iterable, List, Optional, Union
-
-from sglang.srt.server_args import ServerArgs
-
-from ...model.vlm_sglang_model.engine import BatchEngine
-from .base_predictor import (
-    DEFAULT_MAX_NEW_TOKENS,
-    DEFAULT_NO_REPEAT_NGRAM_SIZE,
-    DEFAULT_PRESENCE_PENALTY,
-    DEFAULT_REPETITION_PENALTY,
-    DEFAULT_TEMPERATURE,
-    DEFAULT_TOP_K,
-    DEFAULT_TOP_P,
-    BasePredictor,
-)
-
-
-class SglangEnginePredictor(BasePredictor):
-    def __init__(
-        self,
-        server_args: ServerArgs,
-        temperature: float = DEFAULT_TEMPERATURE,
-        top_p: float = DEFAULT_TOP_P,
-        top_k: int = DEFAULT_TOP_K,
-        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
-        presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
-        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
-        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
-    ) -> None:
-        super().__init__(
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            presence_penalty=presence_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            max_new_tokens=max_new_tokens,
-        )
-        self.engine = BatchEngine(server_args=server_args)
-
-    def load_image_string(self, image: str | bytes) -> str:
-        if not isinstance(image, (str, bytes)):
-            raise ValueError("Image must be a string or bytes.")
-        if isinstance(image, bytes):
-            return b64encode(image).decode("utf-8")
-        if image.startswith("file://"):
-            return image[len("file://") :]
-        return image
-
-    def predict(
-        self,
-        image: str | bytes,
-        prompt: str = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> str:
-        return self.batch_predict(
-            [image],  # type: ignore
-            [prompt],
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            presence_penalty=presence_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            max_new_tokens=max_new_tokens,
-        )[0]
-
-    def batch_predict(
-        self,
-        images: List[str] | List[bytes],
-        prompts: Union[List[str], str] = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> List[str]:
-
-        if not isinstance(prompts, list):
-            prompts = [prompts] * len(images)
-
-        assert len(prompts) == len(images), "Length of prompts and images must match."
-        prompts = [self.build_prompt(prompt) for prompt in prompts]
-
-        if temperature is None:
-            temperature = self.temperature
-        if top_p is None:
-            top_p = self.top_p
-        if top_k is None:
-            top_k = self.top_k
-        if repetition_penalty is None:
-            repetition_penalty = self.repetition_penalty
-        if presence_penalty is None:
-            presence_penalty = self.presence_penalty
-        if no_repeat_ngram_size is None:
-            no_repeat_ngram_size = self.no_repeat_ngram_size
-        if max_new_tokens is None:
-            max_new_tokens = self.max_new_tokens
-
-        # see SamplingParams for more details
-        sampling_params = {
-            "temperature": temperature,
-            "top_p": top_p,
-            "top_k": top_k,
-            "repetition_penalty": repetition_penalty,
-            "presence_penalty": presence_penalty,
-            "custom_params": {
-                "no_repeat_ngram_size": no_repeat_ngram_size,
-            },
-            "max_new_tokens": max_new_tokens,
-            "skip_special_tokens": False,
-        }
-
-        image_strings = [self.load_image_string(img) for img in images]
-
-        output = self.engine.generate(
-            prompt=prompts,
-            image_data=image_strings,
-            sampling_params=sampling_params,
-        )
-        return [item["text"] for item in output]
-
-    def stream_predict(
-        self,
-        image: str | bytes,
-        prompt: str = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> Iterable[str]:
-        raise NotImplementedError("Streaming is not supported yet.")
-
-    async def aio_predict(
-        self,
-        image: str | bytes,
-        prompt: str = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> str:
-        output = await self.aio_batch_predict(
-            [image],  # type: ignore
-            [prompt],
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            presence_penalty=presence_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            max_new_tokens=max_new_tokens,
-        )
-        return output[0]
-
-    async def aio_batch_predict(
-        self,
-        images: List[str] | List[bytes],
-        prompts: Union[List[str], str] = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> List[str]:
-
-        if not isinstance(prompts, list):
-            prompts = [prompts] * len(images)
-
-        assert len(prompts) == len(images), "Length of prompts and images must match."
-        prompts = [self.build_prompt(prompt) for prompt in prompts]
-
-        if temperature is None:
-            temperature = self.temperature
-        if top_p is None:
-            top_p = self.top_p
-        if top_k is None:
-            top_k = self.top_k
-        if repetition_penalty is None:
-            repetition_penalty = self.repetition_penalty
-        if presence_penalty is None:
-            presence_penalty = self.presence_penalty
-        if no_repeat_ngram_size is None:
-            no_repeat_ngram_size = self.no_repeat_ngram_size
-        if max_new_tokens is None:
-            max_new_tokens = self.max_new_tokens
-
-        # see SamplingParams for more details
-        sampling_params = {
-            "temperature": temperature,
-            "top_p": top_p,
-            "top_k": top_k,
-            "repetition_penalty": repetition_penalty,
-            "presence_penalty": presence_penalty,
-            "custom_params": {
-                "no_repeat_ngram_size": no_repeat_ngram_size,
-            },
-            "max_new_tokens": max_new_tokens,
-            "skip_special_tokens": False,
-        }
-
-        image_strings = [self.load_image_string(img) for img in images]
-
-        output = await self.engine.async_generate(
-            prompt=prompts,
-            image_data=image_strings,
-            sampling_params=sampling_params,
-        )
-        ret = []
-        for item in output:  # type: ignore
-            ret.append(item["text"])
-        return ret
-
-    async def aio_stream_predict(
-        self,
-        image: str | bytes,
-        prompt: str = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> AsyncIterable[str]:
-        raise NotImplementedError("Streaming is not supported yet.")
-
-    def close(self):
-        self.engine.shutdown()
```
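The removed `SglangEnginePredictor` took the in-process route instead, driving a `BatchEngine` built from sglang `ServerArgs` rather than calling an HTTP server. A minimal sketch of that pre-2.5.0 API, assuming sglang is installed; the model path below is an illustrative placeholder.

```python
# Sketch only: exercises the removed pre-2.5.0 in-process engine API; model path is hypothetical.
from sglang.srt.server_args import ServerArgs

from mineru.backend.vlm.sglang_engine_predictor import SglangEnginePredictor

predictor = SglangEnginePredictor(server_args=ServerArgs(model_path="/models/mineru-vlm"))
try:
    texts = predictor.batch_predict(images=["page_0.png", "page_1.png"])
finally:
    predictor.close()  # shuts down the underlying BatchEngine
```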