mineru 2.2.2__py3-none-any.whl → 2.5.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (43)
  1. mineru/backend/pipeline/pipeline_middle_json_mkcontent.py +3 -3
  2. mineru/backend/vlm/model_output_to_middle_json.py +123 -0
  3. mineru/backend/vlm/vlm_analyze.py +97 -16
  4. mineru/backend/vlm/vlm_magic_model.py +201 -135
  5. mineru/backend/vlm/vlm_middle_json_mkcontent.py +52 -11
  6. mineru/cli/client.py +6 -5
  7. mineru/cli/common.py +17 -16
  8. mineru/cli/fast_api.py +9 -7
  9. mineru/cli/gradio_app.py +15 -16
  10. mineru/cli/vlm_vllm_server.py +4 -0
  11. mineru/model/table/rec/unet_table/main.py +8 -0
  12. mineru/model/vlm_vllm_model/__init__.py +0 -0
  13. mineru/model/vlm_vllm_model/server.py +51 -0
  14. mineru/resources/header.html +10 -2
  15. mineru/utils/draw_bbox.py +32 -10
  16. mineru/utils/enum_class.py +16 -2
  17. mineru/utils/guess_suffix_or_lang.py +20 -0
  18. mineru/utils/span_block_fix.py +4 -2
  19. mineru/version.py +1 -1
  20. {mineru-2.2.2.dist-info → mineru-2.5.0.dist-info}/METADATA +70 -25
  21. {mineru-2.2.2.dist-info → mineru-2.5.0.dist-info}/RECORD +25 -38
  22. {mineru-2.2.2.dist-info → mineru-2.5.0.dist-info}/entry_points.txt +1 -1
  23. mineru/backend/vlm/base_predictor.py +0 -186
  24. mineru/backend/vlm/hf_predictor.py +0 -217
  25. mineru/backend/vlm/predictor.py +0 -111
  26. mineru/backend/vlm/sglang_client_predictor.py +0 -443
  27. mineru/backend/vlm/sglang_engine_predictor.py +0 -246
  28. mineru/backend/vlm/token_to_middle_json.py +0 -122
  29. mineru/backend/vlm/utils.py +0 -40
  30. mineru/cli/vlm_sglang_server.py +0 -4
  31. mineru/model/vlm_hf_model/__init__.py +0 -9
  32. mineru/model/vlm_hf_model/configuration_mineru2.py +0 -38
  33. mineru/model/vlm_hf_model/image_processing_mineru2.py +0 -269
  34. mineru/model/vlm_hf_model/modeling_mineru2.py +0 -449
  35. mineru/model/vlm_sglang_model/__init__.py +0 -14
  36. mineru/model/vlm_sglang_model/engine.py +0 -264
  37. mineru/model/vlm_sglang_model/image_processor.py +0 -213
  38. mineru/model/vlm_sglang_model/logit_processor.py +0 -90
  39. mineru/model/vlm_sglang_model/model.py +0 -453
  40. mineru/model/vlm_sglang_model/server.py +0 -75
  41. {mineru-2.2.2.dist-info → mineru-2.5.0.dist-info}/WHEEL +0 -0
  42. {mineru-2.2.2.dist-info → mineru-2.5.0.dist-info}/licenses/LICENSE.md +0 -0
  43. {mineru-2.2.2.dist-info → mineru-2.5.0.dist-info}/top_level.txt +0 -0
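
The headline change in this release is the VLM backend swap: the sglang predictor stack (files 23–40 above) is deleted outright, replaced by a vllm-based model server (mineru/model/vlm_vllm_model/server.py, mineru/cli/vlm_vllm_server.py) and a reworked analyze path (vlm_analyze.py, model_output_to_middle_json.py). A quick, standard-library-only sketch for checking which side of the migration an installed environment sits on (not part of this diff; assumes a plain X.Y.Z version string):

from importlib.metadata import version

v = version("mineru")
# Naive parse of an X.Y.Z version string.
major, minor, *_ = (int(x) for x in v.split("."))
if (major, minor) >= (2, 5):
    print(f"mineru {v}: vllm-based VLM backend")
else:
    print(f"mineru {v}: sglang-based VLM backend (pre-2.5)")

The two largest removals, sglang_client_predictor.py and sglang_engine_predictor.py, are reproduced in full below.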
mineru/backend/vlm/sglang_client_predictor.py (deleted: 443 lines)
@@ -1,443 +0,0 @@
-import asyncio
-import json
-import re
-from base64 import b64encode
-from typing import AsyncIterable, Iterable, List, Optional, Set, Tuple, Union
-
-import httpx
-
-from .base_predictor import (
-    DEFAULT_MAX_NEW_TOKENS,
-    DEFAULT_NO_REPEAT_NGRAM_SIZE,
-    DEFAULT_PRESENCE_PENALTY,
-    DEFAULT_REPETITION_PENALTY,
-    DEFAULT_TEMPERATURE,
-    DEFAULT_TOP_K,
-    DEFAULT_TOP_P,
-    BasePredictor,
-)
-from .utils import aio_load_resource, load_resource
-
-
-class SglangClientPredictor(BasePredictor):
-    def __init__(
-        self,
-        server_url: str,
-        temperature: float = DEFAULT_TEMPERATURE,
-        top_p: float = DEFAULT_TOP_P,
-        top_k: int = DEFAULT_TOP_K,
-        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
-        presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
-        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
-        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
-        http_timeout: int = 600,
-    ) -> None:
-        super().__init__(
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            presence_penalty=presence_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            max_new_tokens=max_new_tokens,
-        )
-        self.http_timeout = http_timeout
-
-        base_url = self.get_base_url(server_url)
-        self.check_server_health(base_url)
-        self.model_path = self.get_model_path(base_url)
-        self.server_url = f"{base_url}/generate"
-
-    @staticmethod
-    def get_base_url(server_url: str) -> str:
-        matched = re.match(r"^(https?://[^/]+)", server_url)
-        if not matched:
-            raise ValueError(f"Invalid server URL: {server_url}")
-        return matched.group(1)
-
-    def check_server_health(self, base_url: str):
-        try:
-            response = httpx.get(f"{base_url}/health_generate", timeout=self.http_timeout)
-        except httpx.ConnectError:
-            raise RuntimeError(f"Failed to connect to server {base_url}. Please check if the server is running.")
-        if response.status_code != 200:
-            raise RuntimeError(
-                f"Server {base_url} is not healthy. Status code: {response.status_code}, response body: {response.text}"
-            )
-
-    def get_model_path(self, base_url: str) -> str:
-        try:
-            response = httpx.get(f"{base_url}/get_model_info", timeout=self.http_timeout)
-        except httpx.ConnectError:
-            raise RuntimeError(f"Failed to connect to server {base_url}. Please check if the server is running.")
-        if response.status_code != 200:
-            raise RuntimeError(
-                f"Failed to get model info from {base_url}. Status code: {response.status_code}, response body: {response.text}"
-            )
-        return response.json()["model_path"]
-
-    def build_sampling_params(
-        self,
-        temperature: Optional[float],
-        top_p: Optional[float],
-        top_k: Optional[int],
-        repetition_penalty: Optional[float],
-        presence_penalty: Optional[float],
-        no_repeat_ngram_size: Optional[int],
-        max_new_tokens: Optional[int],
-    ) -> dict:
-        if temperature is None:
-            temperature = self.temperature
-        if top_p is None:
-            top_p = self.top_p
-        if top_k is None:
-            top_k = self.top_k
-        if repetition_penalty is None:
-            repetition_penalty = self.repetition_penalty
-        if presence_penalty is None:
-            presence_penalty = self.presence_penalty
-        if no_repeat_ngram_size is None:
-            no_repeat_ngram_size = self.no_repeat_ngram_size
-        if max_new_tokens is None:
-            max_new_tokens = self.max_new_tokens
-
-        # see SamplingParams for more details
-        return {
-            "temperature": temperature,
-            "top_p": top_p,
-            "top_k": top_k,
-            "repetition_penalty": repetition_penalty,
-            "presence_penalty": presence_penalty,
-            "custom_params": {
-                "no_repeat_ngram_size": no_repeat_ngram_size,
-            },
-            "max_new_tokens": max_new_tokens,
-            "skip_special_tokens": False,
-        }
-
-    def build_request_body(
-        self,
-        image: bytes,
-        prompt: str,
-        sampling_params: dict,
-    ) -> dict:
-        image_base64 = b64encode(image).decode("utf-8")
-        return {
-            "text": prompt,
-            "image_data": image_base64,
-            "sampling_params": sampling_params,
-            "modalities": ["image"],
-        }
-
-    def predict(
-        self,
-        image: str | bytes,
-        prompt: str = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> str:
-        prompt = self.build_prompt(prompt)
-
-        sampling_params = self.build_sampling_params(
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            presence_penalty=presence_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            max_new_tokens=max_new_tokens,
-        )
-
-        if isinstance(image, str):
-            image = load_resource(image)
-
-        request_body = self.build_request_body(image, prompt, sampling_params)
-        response = httpx.post(self.server_url, json=request_body, timeout=self.http_timeout)
-        response_body = response.json()
-        return response_body["text"]
-
-    def batch_predict(
-        self,
-        images: List[str] | List[bytes],
-        prompts: Union[List[str], str] = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-        max_concurrency: int = 100,
-    ) -> List[str]:
-        try:
-            loop = asyncio.get_running_loop()
-        except RuntimeError:
-            loop = None
-
-        task = self.aio_batch_predict(
-            images=images,
-            prompts=prompts,
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            presence_penalty=presence_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            max_new_tokens=max_new_tokens,
-            max_concurrency=max_concurrency,
-        )
-
-        if loop is not None:
-            return loop.run_until_complete(task)
-        else:
-            return asyncio.run(task)
-
-    def stream_predict(
-        self,
-        image: str | bytes,
-        prompt: str = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> Iterable[str]:
-        prompt = self.build_prompt(prompt)
-
-        sampling_params = self.build_sampling_params(
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            presence_penalty=presence_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            max_new_tokens=max_new_tokens,
-        )
-
-        if isinstance(image, str):
-            image = load_resource(image)
-
-        request_body = self.build_request_body(image, prompt, sampling_params)
-        request_body["stream"] = True
-
-        with httpx.stream(
-            "POST",
-            self.server_url,
-            json=request_body,
-            timeout=self.http_timeout,
-        ) as response:
-            pos = 0
-            for chunk in response.iter_lines():
-                if not (chunk or "").startswith("data:"):
-                    continue
-                if chunk == "data: [DONE]":
-                    break
-                data = json.loads(chunk[5:].strip("\n"))
-                chunk_text = data["text"][pos:]
-                # meta_info = data["meta_info"]
-                pos += len(chunk_text)
-                yield chunk_text
-
-    async def aio_predict(
-        self,
-        image: str | bytes,
-        prompt: str = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-        async_client: Optional[httpx.AsyncClient] = None,
-    ) -> str:
-        prompt = self.build_prompt(prompt)
-
-        sampling_params = self.build_sampling_params(
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            presence_penalty=presence_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            max_new_tokens=max_new_tokens,
-        )
-
-        if isinstance(image, str):
-            image = await aio_load_resource(image)
-
-        request_body = self.build_request_body(image, prompt, sampling_params)
-
-        if async_client is None:
-            async with httpx.AsyncClient(timeout=self.http_timeout) as client:
-                response = await client.post(self.server_url, json=request_body)
-                response_body = response.json()
-        else:
-            response = await async_client.post(self.server_url, json=request_body)
-            response_body = response.json()
-
-        return response_body["text"]
-
-    async def aio_batch_predict(
-        self,
-        images: List[str] | List[bytes],
-        prompts: Union[List[str], str] = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-        max_concurrency: int = 100,
-    ) -> List[str]:
-        if not isinstance(prompts, list):
-            prompts = [prompts] * len(images)
-
-        assert len(prompts) == len(images), "Length of prompts and images must match."
-
-        semaphore = asyncio.Semaphore(max_concurrency)
-        outputs = [""] * len(images)
-
-        async def predict_with_semaphore(
-            idx: int,
-            image: str | bytes,
-            prompt: str,
-            async_client: httpx.AsyncClient,
-        ):
-            async with semaphore:
-                output = await self.aio_predict(
-                    image=image,
-                    prompt=prompt,
-                    temperature=temperature,
-                    top_p=top_p,
-                    top_k=top_k,
-                    repetition_penalty=repetition_penalty,
-                    presence_penalty=presence_penalty,
-                    no_repeat_ngram_size=no_repeat_ngram_size,
-                    max_new_tokens=max_new_tokens,
-                    async_client=async_client,
-                )
-                outputs[idx] = output
-
-        async with httpx.AsyncClient(timeout=self.http_timeout) as client:
-            tasks = []
-            for idx, (prompt, image) in enumerate(zip(prompts, images)):
-                tasks.append(predict_with_semaphore(idx, image, prompt, client))
-            await asyncio.gather(*tasks)
-
-        return outputs
-
-    async def aio_batch_predict_as_iter(
-        self,
-        images: List[str] | List[bytes],
-        prompts: Union[List[str], str] = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-        max_concurrency: int = 100,
-    ) -> AsyncIterable[Tuple[int, str]]:
-        if not isinstance(prompts, list):
-            prompts = [prompts] * len(images)
-
-        assert len(prompts) == len(images), "Length of prompts and images must match."
-
-        semaphore = asyncio.Semaphore(max_concurrency)
-
-        async def predict_with_semaphore(
-            idx: int,
-            image: str | bytes,
-            prompt: str,
-            async_client: httpx.AsyncClient,
-        ):
-            async with semaphore:
-                output = await self.aio_predict(
-                    image=image,
-                    prompt=prompt,
-                    temperature=temperature,
-                    top_p=top_p,
-                    top_k=top_k,
-                    repetition_penalty=repetition_penalty,
-                    presence_penalty=presence_penalty,
-                    no_repeat_ngram_size=no_repeat_ngram_size,
-                    max_new_tokens=max_new_tokens,
-                    async_client=async_client,
-                )
-                return (idx, output)
-
-        async with httpx.AsyncClient(timeout=self.http_timeout) as client:
-            pending: Set[asyncio.Task[Tuple[int, str]]] = set()
-
-            for idx, (prompt, image) in enumerate(zip(prompts, images)):
-                pending.add(
-                    asyncio.create_task(
-                        predict_with_semaphore(idx, image, prompt, client),
-                    )
-                )
-
-            while len(pending) > 0:
-                done, pending = await asyncio.wait(
-                    pending,
-                    return_when=asyncio.FIRST_COMPLETED,
-                )
-                for task in done:
-                    yield task.result()
-
-    async def aio_stream_predict(
-        self,
-        image: str | bytes,
-        prompt: str = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> AsyncIterable[str]:
-        prompt = self.build_prompt(prompt)
-
-        sampling_params = self.build_sampling_params(
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            presence_penalty=presence_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            max_new_tokens=max_new_tokens,
-        )
-
-        if isinstance(image, str):
-            image = await aio_load_resource(image)
-
-        request_body = self.build_request_body(image, prompt, sampling_params)
-        request_body["stream"] = True
-
-        async with httpx.AsyncClient(timeout=self.http_timeout) as client:
-            async with client.stream(
-                "POST",
-                self.server_url,
-                json=request_body,
-            ) as response:
-                pos = 0
-                async for chunk in response.aiter_lines():
-                    if not (chunk or "").startswith("data:"):
-                        continue
-                    if chunk == "data: [DONE]":
-                        break
-                    data = json.loads(chunk[5:].strip("\n"))
-                    chunk_text = data["text"][pos:]
-                    # meta_info = data["meta_info"]
-                    pos += len(chunk_text)
-                    yield chunk_text
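
For reference, the deleted client predictor drove a separately launched sglang server entirely over HTTP. A minimal usage sketch of the pre-2.5 API shown above (the URL and file names are placeholders; the constructor probes /health_generate and /get_model_info before the first request):

from mineru.backend.vlm.sglang_client_predictor import SglangClientPredictor

predictor = SglangClientPredictor(server_url="http://127.0.0.1:30000")  # placeholder URL

# Single page, synchronous; a str image is fetched via load_resource, bytes are sent as-is.
text = predictor.predict(image="page_0.png", prompt="")

# Many pages, fanned out over async HTTP with at most 8 requests in flight.
texts = predictor.batch_predict(images=["page_0.png", "page_1.png"], max_concurrency=8)

Note that batch_predict is a sync wrapper around aio_batch_predict, which multiplexes all requests through one httpx.AsyncClient bounded by an asyncio.Semaphore.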
mineru/backend/vlm/sglang_engine_predictor.py (deleted: 246 lines)
@@ -1,246 +0,0 @@
-from base64 import b64encode
-from typing import AsyncIterable, Iterable, List, Optional, Union
-
-from sglang.srt.server_args import ServerArgs
-
-from ...model.vlm_sglang_model.engine import BatchEngine
-from .base_predictor import (
-    DEFAULT_MAX_NEW_TOKENS,
-    DEFAULT_NO_REPEAT_NGRAM_SIZE,
-    DEFAULT_PRESENCE_PENALTY,
-    DEFAULT_REPETITION_PENALTY,
-    DEFAULT_TEMPERATURE,
-    DEFAULT_TOP_K,
-    DEFAULT_TOP_P,
-    BasePredictor,
-)
-
-
-class SglangEnginePredictor(BasePredictor):
-    def __init__(
-        self,
-        server_args: ServerArgs,
-        temperature: float = DEFAULT_TEMPERATURE,
-        top_p: float = DEFAULT_TOP_P,
-        top_k: int = DEFAULT_TOP_K,
-        repetition_penalty: float = DEFAULT_REPETITION_PENALTY,
-        presence_penalty: float = DEFAULT_PRESENCE_PENALTY,
-        no_repeat_ngram_size: int = DEFAULT_NO_REPEAT_NGRAM_SIZE,
-        max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
-    ) -> None:
-        super().__init__(
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            presence_penalty=presence_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            max_new_tokens=max_new_tokens,
-        )
-        self.engine = BatchEngine(server_args=server_args)
-
-    def load_image_string(self, image: str | bytes) -> str:
-        if not isinstance(image, (str, bytes)):
-            raise ValueError("Image must be a string or bytes.")
-        if isinstance(image, bytes):
-            return b64encode(image).decode("utf-8")
-        if image.startswith("file://"):
-            return image[len("file://") :]
-        return image
-
-    def predict(
-        self,
-        image: str | bytes,
-        prompt: str = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> str:
-        return self.batch_predict(
-            [image],  # type: ignore
-            [prompt],
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            presence_penalty=presence_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            max_new_tokens=max_new_tokens,
-        )[0]
-
-    def batch_predict(
-        self,
-        images: List[str] | List[bytes],
-        prompts: Union[List[str], str] = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> List[str]:
-
-        if not isinstance(prompts, list):
-            prompts = [prompts] * len(images)
-
-        assert len(prompts) == len(images), "Length of prompts and images must match."
-        prompts = [self.build_prompt(prompt) for prompt in prompts]
-
-        if temperature is None:
-            temperature = self.temperature
-        if top_p is None:
-            top_p = self.top_p
-        if top_k is None:
-            top_k = self.top_k
-        if repetition_penalty is None:
-            repetition_penalty = self.repetition_penalty
-        if presence_penalty is None:
-            presence_penalty = self.presence_penalty
-        if no_repeat_ngram_size is None:
-            no_repeat_ngram_size = self.no_repeat_ngram_size
-        if max_new_tokens is None:
-            max_new_tokens = self.max_new_tokens
-
-        # see SamplingParams for more details
-        sampling_params = {
-            "temperature": temperature,
-            "top_p": top_p,
-            "top_k": top_k,
-            "repetition_penalty": repetition_penalty,
-            "presence_penalty": presence_penalty,
-            "custom_params": {
-                "no_repeat_ngram_size": no_repeat_ngram_size,
-            },
-            "max_new_tokens": max_new_tokens,
-            "skip_special_tokens": False,
-        }
-
-        image_strings = [self.load_image_string(img) for img in images]
-
-        output = self.engine.generate(
-            prompt=prompts,
-            image_data=image_strings,
-            sampling_params=sampling_params,
-        )
-        return [item["text"] for item in output]
-
-    def stream_predict(
-        self,
-        image: str | bytes,
-        prompt: str = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> Iterable[str]:
-        raise NotImplementedError("Streaming is not supported yet.")
-
-    async def aio_predict(
-        self,
-        image: str | bytes,
-        prompt: str = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> str:
-        output = await self.aio_batch_predict(
-            [image],  # type: ignore
-            [prompt],
-            temperature=temperature,
-            top_p=top_p,
-            top_k=top_k,
-            repetition_penalty=repetition_penalty,
-            presence_penalty=presence_penalty,
-            no_repeat_ngram_size=no_repeat_ngram_size,
-            max_new_tokens=max_new_tokens,
-        )
-        return output[0]
-
-    async def aio_batch_predict(
-        self,
-        images: List[str] | List[bytes],
-        prompts: Union[List[str], str] = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> List[str]:
-
-        if not isinstance(prompts, list):
-            prompts = [prompts] * len(images)
-
-        assert len(prompts) == len(images), "Length of prompts and images must match."
-        prompts = [self.build_prompt(prompt) for prompt in prompts]
-
-        if temperature is None:
-            temperature = self.temperature
-        if top_p is None:
-            top_p = self.top_p
-        if top_k is None:
-            top_k = self.top_k
-        if repetition_penalty is None:
-            repetition_penalty = self.repetition_penalty
-        if presence_penalty is None:
-            presence_penalty = self.presence_penalty
-        if no_repeat_ngram_size is None:
-            no_repeat_ngram_size = self.no_repeat_ngram_size
-        if max_new_tokens is None:
-            max_new_tokens = self.max_new_tokens
-
-        # see SamplingParams for more details
-        sampling_params = {
-            "temperature": temperature,
-            "top_p": top_p,
-            "top_k": top_k,
-            "repetition_penalty": repetition_penalty,
-            "presence_penalty": presence_penalty,
-            "custom_params": {
-                "no_repeat_ngram_size": no_repeat_ngram_size,
-            },
-            "max_new_tokens": max_new_tokens,
-            "skip_special_tokens": False,
-        }
-
-        image_strings = [self.load_image_string(img) for img in images]
-
-        output = await self.engine.async_generate(
-            prompt=prompts,
-            image_data=image_strings,
-            sampling_params=sampling_params,
-        )
-        ret = []
-        for item in output:  # type: ignore
-            ret.append(item["text"])
-        return ret
-
-    async def aio_stream_predict(
-        self,
-        image: str | bytes,
-        prompt: str = "",
-        temperature: Optional[float] = None,
-        top_p: Optional[float] = None,
-        top_k: Optional[int] = None,
-        repetition_penalty: Optional[float] = None,
-        presence_penalty: Optional[float] = None,
-        no_repeat_ngram_size: Optional[int] = None,
-        max_new_tokens: Optional[int] = None,
-    ) -> AsyncIterable[str]:
-        raise NotImplementedError("Streaming is not supported yet.")
-
-    def close(self):
-        self.engine.shutdown()
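
The engine predictor, by contrast, embedded sglang in-process via BatchEngine rather than calling out over HTTP, and exposed no streaming path. A sketch of how it was driven, per the code above (the model path and file names are placeholders; ServerArgs is sglang's own config dataclass):

from sglang.srt.server_args import ServerArgs
from mineru.backend.vlm.sglang_engine_predictor import SglangEnginePredictor

predictor = SglangEnginePredictor(server_args=ServerArgs(model_path="/models/mineru-vlm"))  # placeholder path
try:
    # A single prompt string is broadcast across all images.
    texts = predictor.batch_predict(images=["page_0.png", "page_1.png"], prompts="")
finally:
    predictor.close()  # shuts down the in-process BatchEngine

Both deleted predictors defaulted their sampling knobs (temperature, top_p, top_k, the penalties, no_repeat_ngram_size, max_new_tokens) from base_predictor constants, which is why every method repeats the same parameter list; that base module is also removed in this release.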