mineru 2.2.2__py3-none-any.whl → 2.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. mineru/backend/pipeline/pipeline_middle_json_mkcontent.py +3 -3
  2. mineru/backend/vlm/model_output_to_middle_json.py +123 -0
  3. mineru/backend/vlm/vlm_analyze.py +105 -16
  4. mineru/backend/vlm/vlm_magic_model.py +201 -135
  5. mineru/backend/vlm/vlm_middle_json_mkcontent.py +52 -11
  6. mineru/cli/client.py +6 -5
  7. mineru/cli/common.py +17 -16
  8. mineru/cli/fast_api.py +9 -7
  9. mineru/cli/gradio_app.py +15 -16
  10. mineru/cli/vlm_vllm_server.py +4 -0
  11. mineru/model/table/rec/unet_table/main.py +8 -0
  12. mineru/model/vlm_vllm_model/__init__.py +0 -0
  13. mineru/model/vlm_vllm_model/server.py +59 -0
  14. mineru/resources/header.html +10 -2
  15. mineru/utils/draw_bbox.py +32 -10
  16. mineru/utils/enum_class.py +16 -2
  17. mineru/utils/guess_suffix_or_lang.py +20 -0
  18. mineru/utils/span_block_fix.py +4 -2
  19. mineru/version.py +1 -1
  20. {mineru-2.2.2.dist-info → mineru-2.5.1.dist-info}/METADATA +70 -25
  21. {mineru-2.2.2.dist-info → mineru-2.5.1.dist-info}/RECORD +25 -38
  22. {mineru-2.2.2.dist-info → mineru-2.5.1.dist-info}/entry_points.txt +1 -1
  23. mineru/backend/vlm/base_predictor.py +0 -186
  24. mineru/backend/vlm/hf_predictor.py +0 -217
  25. mineru/backend/vlm/predictor.py +0 -111
  26. mineru/backend/vlm/sglang_client_predictor.py +0 -443
  27. mineru/backend/vlm/sglang_engine_predictor.py +0 -246
  28. mineru/backend/vlm/token_to_middle_json.py +0 -122
  29. mineru/backend/vlm/utils.py +0 -40
  30. mineru/cli/vlm_sglang_server.py +0 -4
  31. mineru/model/vlm_hf_model/__init__.py +0 -9
  32. mineru/model/vlm_hf_model/configuration_mineru2.py +0 -38
  33. mineru/model/vlm_hf_model/image_processing_mineru2.py +0 -269
  34. mineru/model/vlm_hf_model/modeling_mineru2.py +0 -449
  35. mineru/model/vlm_sglang_model/__init__.py +0 -14
  36. mineru/model/vlm_sglang_model/engine.py +0 -264
  37. mineru/model/vlm_sglang_model/image_processor.py +0 -213
  38. mineru/model/vlm_sglang_model/logit_processor.py +0 -90
  39. mineru/model/vlm_sglang_model/model.py +0 -453
  40. mineru/model/vlm_sglang_model/server.py +0 -75
  41. {mineru-2.2.2.dist-info → mineru-2.5.1.dist-info}/WHEEL +0 -0
  42. {mineru-2.2.2.dist-info → mineru-2.5.1.dist-info}/licenses/LICENSE.md +0 -0
  43. {mineru-2.2.2.dist-info → mineru-2.5.1.dist-info}/top_level.txt +0 -0
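
The headline change in this release is the VLM serving backend: the Hugging Face and sglang predictors (`hf_predictor.py`, `sglang_client_predictor.py`, `sglang_engine_predictor.py`) and the `vlm_hf_model`/`vlm_sglang_model` packages are deleted, while a new vllm-based entry point is added (`mineru/model/vlm_vllm_model/server.py` and `mineru/cli/vlm_vllm_server.py`). Three of the deleted sglang files are reproduced in full below.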
mineru/model/vlm_sglang_model/engine.py
@@ -1,264 +0,0 @@
- import asyncio
- import time
- from types import MethodType
- from typing import AsyncIterator, Dict, Iterator, List, Optional, Union
-
- import fastapi
- from sglang.srt.entrypoints.engine import Engine as _Engine
- from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
- from sglang.srt.managers.tokenizer_manager import (
-     TokenizerManager,
-     dataclass_to_string_truncated,
-     logger,
- )
- from sglang.srt.sampling.sampling_params import SamplingParams
- from sglang.srt.server_args import ServerArgs
-
- from ...utils.run_async import run_async
- from .logit_processor import Mineru2LogitProcessor
-
-
- class BatchEngine(_Engine):
-     """
-     The engine is patched to support batch multi-modal generation and early image preprocessing.
-     """
-
-     def __init__(self, server_args: ServerArgs, **kwargs):
-         server_args.enable_custom_logit_processor = True
-         super().__init__(server_args=server_args, **kwargs)
-         _patch_tokenizer_manager(self.tokenizer_manager)
-
-     def generate(
-         self,
-         # The input prompt. It can be a single prompt or a batch of prompts.
-         prompt: Optional[Union[List[str], str]] = None,
-         sampling_params: Optional[Union[List[Dict], Dict]] = None,
-         # The token ids for text; one can either specify text or input_ids.
-         input_ids: Optional[Union[List[List[int]], List[int]]] = None,
-         # The image input. It can be a file name, a url, or base64 encoded string.
-         # See also python/sglang/srt/utils.py:load_image.
-         image_data: Optional[Union[List[str], str]] = None,
-         return_logprob: Optional[Union[List[bool], bool]] = False,
-         logprob_start_len: Optional[Union[List[int], int]] = None,
-         top_logprobs_num: Optional[Union[List[int], int]] = None,
-         token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
-         lora_path: Optional[List[Optional[str]]] = None,
-         custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None,
-         return_hidden_states: bool = False,
-         stream: bool = False,
-     ) -> Union[Dict, Iterator[Dict]]:
-         """
-         The arguments of this function are the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
-         Please refer to `GenerateReqInput` for the documentation.
-         """
-         modalities_list = []
-
-         # EDIT
-         if isinstance(image_data, list):
-             for _ in range(len(image_data)):
-                 modalities_list.append(["image"])
-         elif image_data is not None:
-             modalities_list.append("image")
-
-         # ADD
-         if custom_logit_processor is None:
-             custom_logit_processor = Mineru2LogitProcessor().to_str()
-
-         obj = GenerateReqInput(
-             text=prompt,
-             input_ids=input_ids,
-             sampling_params=sampling_params,
-             image_data=image_data,
-             return_logprob=return_logprob,
-             logprob_start_len=logprob_start_len,
-             top_logprobs_num=top_logprobs_num,
-             token_ids_logprob=token_ids_logprob,
-             lora_path=lora_path,
-             modalities=modalities_list,
-             custom_logit_processor=custom_logit_processor,
-             return_hidden_states=return_hidden_states,
-             stream=stream,
-         )
-         generator = _generate_request(self.tokenizer_manager, obj, None)
-
-         if stream:
-
-             def generator_wrapper():
-                 while True:
-                     try:
-                         chunk = run_async(generator.__anext__())
-                         yield chunk
-                     except StopAsyncIteration:
-                         break
-
-             return generator_wrapper()
-         else:
-             ret = run_async(generator.__anext__())
-             return ret
-
-     async def async_generate(
-         self,
-         # The input prompt. It can be a single prompt or a batch of prompts.
-         prompt: Optional[Union[List[str], str]] = None,
-         sampling_params: Optional[Union[List[Dict], Dict]] = None,
-         # The token ids for text; one can either specify text or input_ids.
-         input_ids: Optional[Union[List[List[int]], List[int]]] = None,
-         # The image input. It can be a file name, a url, or base64 encoded string.
-         # See also python/sglang/srt/utils.py:load_image.
-         image_data: Optional[Union[List[str], str]] = None,
-         return_logprob: Optional[Union[List[bool], bool]] = False,
-         logprob_start_len: Optional[Union[List[int], int]] = None,
-         top_logprobs_num: Optional[Union[List[int], int]] = None,
-         token_ids_logprob: Optional[Union[List[List[int]], List[int]]] = None,
-         lora_path: Optional[List[Optional[str]]] = None,
-         custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None,
-         return_hidden_states: bool = False,
-         stream: bool = False,
-     ) -> Union[Dict, AsyncIterator[Dict], Iterator[Dict]]:
-         """
-         The arguments of this function are the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
-         Please refer to `GenerateReqInput` for the documentation.
-         """
-         modalities_list = []
-
-         # EDIT
-         if isinstance(image_data, list):
-             for _ in range(len(image_data)):
-                 modalities_list.append(["image"])
-         elif image_data is not None:
-             modalities_list.append("image")
-
-         # ADD
-         if custom_logit_processor is None:
-             custom_logit_processor = Mineru2LogitProcessor().to_str()
-
-         obj = GenerateReqInput(
-             text=prompt,
-             input_ids=input_ids,
-             sampling_params=sampling_params,
-             image_data=image_data,
-             return_logprob=return_logprob,
-             logprob_start_len=logprob_start_len,
-             top_logprobs_num=top_logprobs_num,
-             token_ids_logprob=token_ids_logprob,
-             lora_path=lora_path,
-             modalities=modalities_list,
-             custom_logit_processor=custom_logit_processor,
-             return_hidden_states=return_hidden_states,
-             stream=stream,
-         )
-         generator = _generate_request(self.tokenizer_manager, obj, None)
-
-         if stream is True:
-             return generator
-         else:
-             return await generator.__anext__()
-
-
- def _auto_create_handle_loop(self: TokenizerManager):
-     """
-     Patch the original `auto_create_handle_loop()` method to reset `no_create_loop`
-     when the event loop changes.
-     """
-     try:
-         curr_handle_loop = asyncio.get_running_loop()
-     except RuntimeError:
-         curr_handle_loop = None
-
-     last_handle_loop = getattr(self, "_last_handle_loop", None)
-     if last_handle_loop != curr_handle_loop:
-         self.no_create_loop = False
-         setattr(self, "_last_handle_loop", curr_handle_loop)
-     return TokenizerManager.auto_create_handle_loop(self)
-
-
- def _patch_tokenizer_manager(self: TokenizerManager):
-     self.auto_create_handle_loop = MethodType(_auto_create_handle_loop, self)
-
-
- async def _one_request(
-     self: TokenizerManager,
-     obj: Union[GenerateReqInput, EmbeddingReqInput],
-     request: Optional[fastapi.Request],
-     created_time: Optional[float],
- ):
-     tokenized_obj = await self._tokenize_one_request(obj)
-     state = self._send_one_request(obj, tokenized_obj, created_time)
-     async for out in self._wait_one_response(obj, state, request):
-         yield out
-
-
- async def _handle_batch_request(
-     self: TokenizerManager,
-     obj: Union[GenerateReqInput, EmbeddingReqInput],
-     request: Optional[fastapi.Request] = None,
-     created_time: Optional[float] = None,
- ):
-     batch_size = obj.batch_size
-
-     generators = []
-     rids = []
-
-     if getattr(obj, "parallel_sample_num", 1) != 1:
-         raise Exception("parallel_sample_num != 1 is not supported in this patched code.")
-
-     # Send all requests
-     for i in range(batch_size):
-         tmp_obj = obj[i]
-         generators.append(_one_request(self, tmp_obj, request, created_time))
-         rids.append(tmp_obj.rid)
-
-     # Wait for all requests
-     is_stream = hasattr(obj, "stream") and obj.stream
-     if not is_stream:
-         outputs = await asyncio.gather(*(gen.__anext__() for gen in generators))
-         yield outputs
-     else:
-         rid_to_index = {rid: i for i, rid in enumerate(rids)}
-         task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
-         while task_map:
-             done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED)
-
-             for task in done:
-                 gen = task_map.pop(task)
-                 try:
-                     result = task.result()
-                     result["index"] = rid_to_index[result["meta_info"]["id"]]
-                     yield result
-                     new_task = asyncio.create_task(gen.__anext__())
-                     task_map[new_task] = gen
-                 except StopAsyncIteration:
-                     pass
-
-
- async def _generate_request(
-     self: TokenizerManager,
-     obj: Union[GenerateReqInput, EmbeddingReqInput],
-     request: Optional[fastapi.Request] = None,
- ):
-     created_time = time.time()
-
-     self.auto_create_handle_loop()
-
-     if isinstance(obj, EmbeddingReqInput) and self.is_generation:
-         raise ValueError(
-             "This model does not appear to be an embedding model by default. "
-             "Please add `--is-embedding` when launching the server or try another model."
-         )
-
-     obj.normalize_batch_and_arguments()
-
-     if self.log_requests:
-         max_length, skip_names, _ = self.log_request_metadata
-         logger.info(f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}")
-
-     async with self.model_update_lock.reader_lock:
-         is_single = obj.is_single
-         if is_single:
-             tokenized_obj = await self._tokenize_one_request(obj)
-             state = self._send_one_request(obj, tokenized_obj, created_time)
-             async for response in self._wait_one_response(obj, state, request):
-                 yield response
-         else:
-             async for response in _handle_batch_request(self, obj, request, created_time):
-                 yield response
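
The streaming branch of `_handle_batch_request` above merges several per-request async generators into one stream by mapping each in-flight `__anext__()` task back to its source generator and re-arming a new task whenever one completes. Below is a minimal, self-contained sketch of that pattern; `fake_request` and `merge_streams` are hypothetical names for illustration, not MinerU or sglang APIs.

```python
import asyncio
from typing import AsyncIterator


async def fake_request(rid: str, delay: float, chunks: int) -> AsyncIterator[dict]:
    # Stand-in for one request's response stream.
    for i in range(chunks):
        await asyncio.sleep(delay)
        yield {"rid": rid, "chunk": i}


async def merge_streams(generators: list) -> AsyncIterator[dict]:
    # Map each pending "next chunk" task back to its source generator,
    # re-arming a fresh task whenever one completes.
    task_map = {asyncio.create_task(gen.__anext__()): gen for gen in generators}
    while task_map:
        done, _ = await asyncio.wait(task_map.keys(), return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            gen = task_map.pop(task)
            try:
                yield task.result()
                task_map[asyncio.create_task(gen.__anext__())] = gen
            except StopAsyncIteration:
                pass  # This generator is exhausted; stop polling it.


async def main() -> None:
    gens = [fake_request("a", 0.03, 3), fake_request("b", 0.01, 3)]
    async for chunk in merge_streams(gens):
        print(chunk)  # Chunks from "a" and "b" arrive interleaved.


if __name__ == "__main__":
    asyncio.run(main())
```

Chunks are yielded strictly in arrival order, which is why the deleted code stamps each result with its original batch `index` before yielding it.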
mineru/model/vlm_sglang_model/image_processor.py
@@ -1,213 +0,0 @@
- import ast
- import asyncio
- import re
- from typing import List, Optional, Union
-
- import numpy as np
-
- from sglang.version import __version__ as sglang_version
- from packaging import version
- if version.parse(sglang_version) >= version.parse("0.4.9"):
-     # sglang >= 0.4.9
-     from sglang.srt.multimodal.processors.base_processor import (
-         BaseMultimodalProcessor as BaseProcessor,
-     )
-     from sglang.srt.multimodal.mm_utils import divide_to_patches, expand2square, select_best_resolution
- else:
-     # 0.4.7 <= sglang < 0.4.9
-     from sglang.srt.managers.multimodal_processors.base_processor import (
-         BaseMultimodalProcessor as BaseProcessor,
-     )
-     from sglang.srt.mm_utils import divide_to_patches, expand2square, select_best_resolution
-
- get_global_processor = None
- from sglang.srt.utils import load_image, logger
- from sglang.utils import get_exception_traceback
-
- from .model import Mineru2QwenForCausalLM
-
-
- # image_best_res is only resized (not padded).
- def process_anyres_image(image, processor, grid_pinpoints):
-     if isinstance(grid_pinpoints, str) and "x" in grid_pinpoints:
-         patch_size = processor.crop_size["height"]
-         assert patch_size in [224, 336, 384, 448, 512], "patch_size should be in [224, 336, 384, 448, 512]"
-         matches = re.findall(r"\((\d+)x(\d+)\)", grid_pinpoints)
-         range_start = tuple(map(int, matches[0]))
-         range_end = tuple(map(int, matches[-1]))
-         grid_pinpoints = [
-             (i, j) for i in range(range_start[0], range_end[0] + 1) for j in range(range_start[1], range_end[1] + 1)
-         ]
-         grid_pinpoints = [[dim * patch_size for dim in pair] for pair in grid_pinpoints]
-
-     if type(grid_pinpoints) is list:
-         possible_resolutions = grid_pinpoints
-     else:
-         possible_resolutions = ast.literal_eval(grid_pinpoints)
-     best_resolution = select_best_resolution(image.size, possible_resolutions)
-
-     image_best_res = image.resize(best_resolution)  # <<<<<<< Here changed
-     patches = divide_to_patches(image_best_res, processor.crop_size["height"])
-     image_original_resize = image.resize((processor.crop_size["height"], processor.crop_size["height"]))
-
-     image_patches = [image_original_resize] + patches
-     image_patches = [processor.preprocess(image_patch)["pixel_values"][0] for image_patch in image_patches]
-     return np.stack(image_patches, axis=0)
-
-
- class Mineru2ImageProcessor(BaseProcessor):
-     def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
-         super().__init__(hf_config, server_args, _processor, *args, **kwargs)
-
-     @staticmethod
-     def _process_single_image_task(
-         image_data: Union[str, bytes],
-         image_aspect_ratio: Optional[str] = None,
-         image_grid_pinpoints: Optional[str] = None,
-         image_processor=None,
-     ):
-         if image_processor is None:
-             assert get_global_processor is not None
-             image_processor = get_global_processor().image_processor
-
-         try:
-             image, image_size = load_image(image_data)
-             if image_size is not None:
-                 # It is a video with multiple images
-                 image_hash = hash(image_data)
-                 pixel_values = image_processor(image)["pixel_values"]
-                 pixel_values = np.stack(pixel_values, axis=0)
-                 return pixel_values, image_hash, image_size
-             else:
-                 # It is an image
-                 image_hash = hash(image_data)
-                 if image_aspect_ratio == "pad":
-                     image = expand2square(
-                         image,
-                         tuple(int(x * 255) for x in image_processor.image_mean),
-                     )
-                     pixel_values = image_processor(image.convert("RGB"))["pixel_values"][0]
-                 elif image_aspect_ratio == "anyres" or (image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio):
-                     pixel_values = process_anyres_image(image, image_processor, image_grid_pinpoints)
-                 else:
-                     pixel_values = image_processor(image)["pixel_values"][0]
-                 return pixel_values, image_hash, image.size
-         except Exception:
-             logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
-
-     async def _process_single_image(self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str):
-         if hasattr(self, "cpu_executor"):
-             executor = self.cpu_executor
-         else:
-             executor = self.executor
-
-         if get_global_processor is not None:
-             image_processor = None  # save ipc cost
-         else:
-             image_processor = self._processor.image_processor
-
-         if executor is not None:
-             loop = asyncio.get_running_loop()
-             return await loop.run_in_executor(
-                 executor,
-                 Mineru2ImageProcessor._process_single_image_task,
-                 image_data,
-                 aspect_ratio,
-                 grid_pinpoints,
-                 image_processor,
-             )
-         else:
-             return self._process_single_image_task(
-                 image_data,
-                 aspect_ratio,
-                 grid_pinpoints,
-                 image_processor,
-             )
-
-     async def process_mm_data_async(
-         self,
-         image_data: List[Union[str, bytes]],
-         input_text,
-         request_obj,
-         *args,
-         **kwargs,
-     ):
-         from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
-
-         if not image_data:
-             return None
-
-         modalities = request_obj.modalities or ["image"]
-         aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
-         grid_pinpoints = (
-             self.hf_config.image_grid_pinpoints
-             if hasattr(self.hf_config, "image_grid_pinpoints")
-             and "anyres" in aspect_ratio
-             else None
-         )
-
-         if isinstance(image_data, str):
-             image_data = [image_data]
-
-         if isinstance(image_data, list) and len(image_data) > 0:
-             if "multi-images" in modalities or "video" in modalities:
-                 # Multiple images
-                 aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
-                 pixel_values, data_hashes, image_sizes = [], [], []
-                 res = []
-                 for img_data in image_data:
-                     res.append(
-                         self._process_single_image(
-                             img_data, aspect_ratio, grid_pinpoints
-                         )
-                     )
-
-                 res = await asyncio.gather(*res)
-                 for pixel_v, image_h, image_s in res:
-                     pixel_values.append(pixel_v)
-                     data_hashes.append(image_h)
-                     image_sizes.append(image_s)
-
-                 if isinstance(pixel_values[0], np.ndarray):
-                     pixel_values = np.stack(pixel_values, axis=0)
-             else:
-                 # A single image
-                 pixel_values, image_hash, image_size = await self._process_single_image(
-                     image_data[0], aspect_ratio, grid_pinpoints
-                 )
-                 image_sizes = [image_size]
-         else:
-             raise ValueError(f"Invalid image data: {image_data}")
-         modality = Modality.IMAGE
-         if isinstance(request_obj.modalities, list):
-             if request_obj.modalities[0] == "multi-images":
-                 modality = Modality.MULTI_IMAGES
-             elif request_obj.modalities[0] == "video":
-                 modality = Modality.VIDEO
-
-         if version.parse(sglang_version) >= version.parse("0.4.9.post3"):
-             # sglang >= 0.4.9.post3
-             return {
-                 "mm_items": [
-                     MultimodalDataItem(
-                         feature=pixel_values,
-                         model_specific_data={
-                             "image_sizes": image_sizes,
-                         },
-                         modality=modality,
-                     )
-                 ],
-             }
-         else:
-             # 0.4.7 <= sglang <= 0.4.9.post2
-             return {
-                 "mm_items": [
-                     MultimodalDataItem(
-                         pixel_values=pixel_values,
-                         image_sizes=image_sizes,
-                         modality=modality,
-                     )
-                 ],
-             }
-
- ImageProcessorMapping = {Mineru2QwenForCausalLM: Mineru2ImageProcessor}
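
The "anyres" path above feeds the model one global thumbnail plus a grid of high-resolution tiles: `select_best_resolution` picks a canvas from the grid pinpoints, the image is resized to it (the `<<<<<<< Here changed` line, which resizes instead of padding), and `divide_to_patches` cuts it up. Here is a rough, self-contained sketch of that tiling idea with simplified stand-ins for sglang's helpers; the 448-pixel patch size and 3x3 grid are illustrative, not MinerU's actual configuration.

```python
from PIL import Image


def select_best_resolution(image_size, candidates):
    # Simplified LLaVA-style heuristic: maximize usable (non-upscaled) area,
    # then minimize wasted canvas.
    ow, oh = image_size
    best, best_fit, best_waste = None, -1, float("inf")
    for cw, ch in candidates:
        scale = min(cw / ow, ch / oh)
        fit = min(int(ow * scale) * int(oh * scale), ow * oh)
        waste = cw * ch - fit
        if fit > best_fit or (fit == best_fit and waste < best_waste):
            best, best_fit, best_waste = (cw, ch), fit, waste
    return best


def divide_to_patches(image, patch_size):
    # Cut the resized image into non-overlapping patch_size x patch_size tiles.
    patches = []
    for top in range(0, image.height, patch_size):
        for left in range(0, image.width, patch_size):
            patches.append(image.crop((left, top, left + patch_size, top + patch_size)))
    return patches


patch = 448  # hypothetical patch size
grid = [(w * patch, h * patch) for w in (1, 2, 3) for h in (1, 2, 3)]

image = Image.new("RGB", (1000, 700))  # stand-in for a real page image
best = select_best_resolution(image.size, grid)
# Global thumbnail first, then the high-res tiles, as in the deleted code.
tiles = [image.resize((patch, patch))] + divide_to_patches(image.resize(best), patch)
print(best, len(tiles))
```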
mineru/model/vlm_sglang_model/logit_processor.py
@@ -1,90 +0,0 @@
- from typing import List
-
- from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
-
-
- class Mineru2LogitProcessor(CustomLogitProcessor):
-     """
-     Stateless logit processor for Mineru2.
-
-     (base-class: sglang.srt.sampling.custom_logit_processor.CustomLogitProcessor)
-
-     This processor applies token-level constraints to prevent repetition during generation.
-     It supports two main constraints:
-
-     - no_repeat_ngram_size (int):
-         Prevents repeating the same n-gram of specified size in the output.
-         Inspired by Hugging Face's NoRepeatNGramLogitsProcessor.
-         This implementation is slower due to its lack of specialized optimization.
-
-     - no_repeat_token_count (int):
-         (Placeholder for future logic)
-         Intended to prevent repeating the same token multiple times.
-         Not yet implemented in this version.
-     """
-
-     def __init__(self) -> None:
-         super().__init__()
-         self._generated_ngrams = {}  # Cache of generated n-grams by request ID
-         self._time = {}  # Timestamp of the last update for each request
-         self._gen_step = 0  # Global generation step counter
-
-     def __call__(self, logits, batch_info: List[dict]):
-         """
-         Applies repetition constraints to the logits before sampling tokens.
-
-         Args:
-             logits (FloatTensor): A tensor of shape (batch_size, vocab_size) containing raw token logits.
-             batch_info (List[dict]): A list of metadata dicts for each sample in the batch. Each dict must include:
-                 - "__req__": Request object containing request ID and output_ids.
-                 - "no_repeat_ngram_size": Size of n-gram to avoid repeating.
-
-         Returns:
-             FloatTensor: The modified logits tensor with banned token logits set to -inf.
-         """
-         from sglang.srt.managers.schedule_batch import Req
-
-         self._gen_step += 1  # Update global generation step
-
-         for idx, info in enumerate(batch_info):
-             if not isinstance(info, dict) or "__req__" not in info:
-                 continue
-
-             req: Req = info["__req__"]
-             rid = req.rid
-             output_ids = req.output_ids
-             ngram_size = info.get("no_repeat_ngram_size", 0)
-
-             # Skip if there are not enough tokens to form an n-gram
-             if ngram_size <= 0 or len(output_ids) < ngram_size:
-                 continue
-
-             # Record the current step for cache cleanup tracking
-             self._time[rid] = self._gen_step
-
-             # Initialize n-gram cache for this request if it doesn't exist
-             if rid not in self._generated_ngrams:
-                 self._generated_ngrams[rid] = {}
-
-             # Get the n-gram prefix (all but the last token)
-             prev_ngram = tuple(output_ids[-ngram_size:-1])
-             last_token = output_ids[-1]
-
-             # Store this n-gram occurrence
-             self._generated_ngrams[rid][prev_ngram] = self._generated_ngrams[rid].get(prev_ngram, []) + [last_token]
-
-             # Get the next-token candidates to ban based on current prefix
-             current_prefix = tuple(output_ids[-ngram_size + 1 :])
-             banned_tokens = self._generated_ngrams[rid].get(current_prefix, [])
-
-             # Set the logits of banned tokens to negative infinity
-             for token in banned_tokens:
-                 logits[idx][token] = -float("inf")
-
-         # Clean up cache for expired requests
-         expired_rids = [rid for rid, last_used in self._time.items() if last_used < self._gen_step]
-         for rid in expired_rids:
-             self._generated_ngrams.pop(rid, None)
-             self._time.pop(rid, None)
-
-         return logits
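
Conceptually, the rule above is the standard no-repeat-ngram constraint: a token is banned whenever emitting it would recreate an n-gram that has already appeared in the output. A minimal, framework-free sketch follows; the helper names are hypothetical, and unlike the deleted processor (which updates its cache one n-gram per decoding step), this version rescans the whole sequence for clarity.

```python
def banned_next_tokens(output_ids: list, ngram_size: int) -> set:
    # Collect every n-gram seen so far, keyed by its (n-1)-token prefix.
    seen = {}
    for i in range(len(output_ids) - ngram_size + 1):
        prefix = tuple(output_ids[i : i + ngram_size - 1])
        seen.setdefault(prefix, set()).add(output_ids[i + ngram_size - 1])
    # Ban any token that would complete a previously generated n-gram.
    current_prefix = tuple(output_ids[len(output_ids) - ngram_size + 1 :])
    return seen.get(current_prefix, set())


def apply_ban(logits: list, output_ids: list, ngram_size: int) -> list:
    # Set banned token logits to -inf so they can never be sampled.
    for token in banned_next_tokens(output_ids, ngram_size):
        logits[token] = float("-inf")
    return logits


# With ngram_size=3, the sequence already contains the 3-gram (5, 6, 7),
# so after the trailing prefix (5, 6) the token 7 must be banned.
logits = [0.0] * 10
out = apply_ban(logits, [5, 6, 7, 1, 5, 6], ngram_size=3)
assert out[7] == float("-inf")
print(out)
```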