docling 2.54.0__py3-none-any.whl → 2.55.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docling/backend/asciidoc_backend.py +1 -1
- docling/backend/html_backend.py +254 -136
- docling/backend/md_backend.py +4 -1
- docling/backend/msword_backend.py +1 -1
- docling/backend/xml/jats_backend.py +111 -7
- docling/backend/xml/uspto_backend.py +1 -1
- docling/cli/main.py +5 -0
- docling/datamodel/pipeline_options_vlm_model.py +13 -2
- docling/datamodel/vlm_model_specs.py +9 -0
- docling/models/api_vlm_model.py +45 -16
- docling/models/base_model.py +2 -1
- docling/models/readingorder_model.py +1 -1
- docling/models/utils/generation_utils.py +157 -0
- docling/models/utils/hf_model_download.py +6 -1
- docling/models/vlm_models_inline/hf_transformers_model.py +75 -14
- docling/models/vlm_models_inline/mlx_model.py +58 -1
- docling/models/vlm_models_inline/vllm_model.py +189 -124
- docling/utils/api_image_request.py +107 -1
- {docling-2.54.0.dist-info → docling-2.55.0.dist-info}/METADATA +2 -2
- {docling-2.54.0.dist-info → docling-2.55.0.dist-info}/RECORD +24 -23
- {docling-2.54.0.dist-info → docling-2.55.0.dist-info}/WHEEL +0 -0
- {docling-2.54.0.dist-info → docling-2.55.0.dist-info}/entry_points.txt +0 -0
- {docling-2.54.0.dist-info → docling-2.55.0.dist-info}/licenses/LICENSE +0 -0
- {docling-2.54.0.dist-info → docling-2.55.0.dist-info}/top_level.txt +0 -0
|
@@ -2,9 +2,9 @@ import logging
|
|
|
2
2
|
import traceback
|
|
3
3
|
from io import BytesIO
|
|
4
4
|
from pathlib import Path
|
|
5
|
-
from typing import Final, Optional, Union
|
|
5
|
+
from typing import Final, Optional, Union, cast
|
|
6
6
|
|
|
7
|
-
from bs4 import BeautifulSoup, Tag
|
|
7
|
+
from bs4 import BeautifulSoup, NavigableString, Tag
|
|
8
8
|
from docling_core.types.doc import (
|
|
9
9
|
DocItemLabel,
|
|
10
10
|
DoclingDocument,
|
|
@@ -12,6 +12,8 @@ from docling_core.types.doc import (
|
|
|
12
12
|
GroupItem,
|
|
13
13
|
GroupLabel,
|
|
14
14
|
NodeItem,
|
|
15
|
+
TableCell,
|
|
16
|
+
TableData,
|
|
15
17
|
TextItem,
|
|
16
18
|
)
|
|
17
19
|
from lxml import etree
|
|
@@ -350,7 +352,7 @@ class JatsDocumentBackend(DeclarativeDocumentBackend):
|
|
|
350
352
|
|
|
351
353
|
return
|
|
352
354
|
|
|
353
|
-
def _parse_element_citation(self, node: etree._Element) -> str:
|
|
355
|
+
def _parse_element_citation(self, node: etree._Element) -> str:
|
|
354
356
|
citation: Citation = {
|
|
355
357
|
"author_names": "",
|
|
356
358
|
"title": "",
|
|
@@ -535,6 +537,110 @@ class JatsDocumentBackend(DeclarativeDocumentBackend):
|
|
|
535
537
|
|
|
536
538
|
return
|
|
537
539
|
|
|
540
|
+
@staticmethod
|
|
541
|
+
def parse_table_data(element: Tag) -> Optional[TableData]:
|
|
542
|
+
# TODO, see how to implement proper support for rich tables from HTML backend
|
|
543
|
+
nested_tables = element.find("table")
|
|
544
|
+
if nested_tables is not None:
|
|
545
|
+
_log.debug("Skipping nested table.")
|
|
546
|
+
return None
|
|
547
|
+
|
|
548
|
+
# Find the number of rows and columns (taking into account spans)
|
|
549
|
+
num_rows = 0
|
|
550
|
+
num_cols = 0
|
|
551
|
+
for row in element("tr"):
|
|
552
|
+
col_count = 0
|
|
553
|
+
is_row_header = True
|
|
554
|
+
if not isinstance(row, Tag):
|
|
555
|
+
continue
|
|
556
|
+
for cell in row(["td", "th"]):
|
|
557
|
+
if not isinstance(row, Tag):
|
|
558
|
+
continue
|
|
559
|
+
cell_tag = cast(Tag, cell)
|
|
560
|
+
col_span, row_span = HTMLDocumentBackend._get_cell_spans(cell_tag)
|
|
561
|
+
col_count += col_span
|
|
562
|
+
if cell_tag.name == "td" or row_span == 1:
|
|
563
|
+
is_row_header = False
|
|
564
|
+
num_cols = max(num_cols, col_count)
|
|
565
|
+
if not is_row_header:
|
|
566
|
+
num_rows += 1
|
|
567
|
+
|
|
568
|
+
_log.debug(f"The table has {num_rows} rows and {num_cols} cols.")
|
|
569
|
+
|
|
570
|
+
grid: list = [[None for _ in range(num_cols)] for _ in range(num_rows)]
|
|
571
|
+
|
|
572
|
+
data = TableData(num_rows=num_rows, num_cols=num_cols, table_cells=[])
|
|
573
|
+
|
|
574
|
+
# Iterate over the rows in the table
|
|
575
|
+
start_row_span = 0
|
|
576
|
+
row_idx = -1
|
|
577
|
+
for row in element("tr"):
|
|
578
|
+
if not isinstance(row, Tag):
|
|
579
|
+
continue
|
|
580
|
+
|
|
581
|
+
# For each row, find all the column cells (both <td> and <th>)
|
|
582
|
+
cells = row(["td", "th"])
|
|
583
|
+
|
|
584
|
+
# Check if cell is in a column header or row header
|
|
585
|
+
col_header = True
|
|
586
|
+
row_header = True
|
|
587
|
+
for html_cell in cells:
|
|
588
|
+
if isinstance(html_cell, Tag):
|
|
589
|
+
_, row_span = HTMLDocumentBackend._get_cell_spans(html_cell)
|
|
590
|
+
if html_cell.name == "td":
|
|
591
|
+
col_header = False
|
|
592
|
+
row_header = False
|
|
593
|
+
elif row_span == 1:
|
|
594
|
+
row_header = False
|
|
595
|
+
if not row_header:
|
|
596
|
+
row_idx += 1
|
|
597
|
+
start_row_span = 0
|
|
598
|
+
else:
|
|
599
|
+
start_row_span += 1
|
|
600
|
+
|
|
601
|
+
# Extract the text content of each cell
|
|
602
|
+
col_idx = 0
|
|
603
|
+
for html_cell in cells:
|
|
604
|
+
if not isinstance(html_cell, Tag):
|
|
605
|
+
continue
|
|
606
|
+
|
|
607
|
+
# extract inline formulas
|
|
608
|
+
for formula in html_cell("inline-formula"):
|
|
609
|
+
math_parts = formula.text.split("$$")
|
|
610
|
+
if len(math_parts) == 3:
|
|
611
|
+
math_formula = f"$${math_parts[1]}$$"
|
|
612
|
+
formula.replace_with(NavigableString(math_formula))
|
|
613
|
+
|
|
614
|
+
# TODO: extract content correctly from table-cells with lists
|
|
615
|
+
text = HTMLDocumentBackend.get_text(html_cell).strip()
|
|
616
|
+
col_span, row_span = HTMLDocumentBackend._get_cell_spans(html_cell)
|
|
617
|
+
if row_header:
|
|
618
|
+
row_span -= 1
|
|
619
|
+
while (
|
|
620
|
+
col_idx < num_cols
|
|
621
|
+
and grid[row_idx + start_row_span][col_idx] is not None
|
|
622
|
+
):
|
|
623
|
+
col_idx += 1
|
|
624
|
+
for r in range(start_row_span, start_row_span + row_span):
|
|
625
|
+
for c in range(col_span):
|
|
626
|
+
if row_idx + r < num_rows and col_idx + c < num_cols:
|
|
627
|
+
grid[row_idx + r][col_idx + c] = text
|
|
628
|
+
|
|
629
|
+
table_cell = TableCell(
|
|
630
|
+
text=text,
|
|
631
|
+
row_span=row_span,
|
|
632
|
+
col_span=col_span,
|
|
633
|
+
start_row_offset_idx=start_row_span + row_idx,
|
|
634
|
+
end_row_offset_idx=start_row_span + row_idx + row_span,
|
|
635
|
+
start_col_offset_idx=col_idx,
|
|
636
|
+
end_col_offset_idx=col_idx + col_span,
|
|
637
|
+
column_header=col_header,
|
|
638
|
+
row_header=((not col_header) and html_cell.name == "th"),
|
|
639
|
+
)
|
|
640
|
+
data.table_cells.append(table_cell)
|
|
641
|
+
|
|
642
|
+
return data
|
|
643
|
+
|
|
538
644
|
def _add_table(
|
|
539
645
|
self, doc: DoclingDocument, parent: NodeItem, table_xml_component: Table
|
|
540
646
|
) -> None:
|
|
@@ -543,8 +649,7 @@ class JatsDocumentBackend(DeclarativeDocumentBackend):
|
|
|
543
649
|
if not isinstance(table_tag, Tag):
|
|
544
650
|
return
|
|
545
651
|
|
|
546
|
-
data =
|
|
547
|
-
|
|
652
|
+
data = JatsDocumentBackend.parse_table_data(table_tag)
|
|
548
653
|
# TODO: format label vs caption once styling is supported
|
|
549
654
|
label = table_xml_component["label"]
|
|
550
655
|
caption = table_xml_component["caption"]
|
|
@@ -554,7 +659,6 @@ class JatsDocumentBackend(DeclarativeDocumentBackend):
|
|
|
554
659
|
if table_text
|
|
555
660
|
else None
|
|
556
661
|
)
|
|
557
|
-
|
|
558
662
|
if data is not None:
|
|
559
663
|
doc.add_table(data=data, parent=parent, caption=table_caption)
|
|
560
664
|
|
|
@@ -609,7 +713,7 @@ class JatsDocumentBackend(DeclarativeDocumentBackend):
|
|
|
609
713
|
)
|
|
610
714
|
return
|
|
611
715
|
|
|
612
|
-
def _walk_linear(
|
|
716
|
+
def _walk_linear(
|
|
613
717
|
self, doc: DoclingDocument, parent: NodeItem, node: etree._Element
|
|
614
718
|
) -> str:
|
|
615
719
|
skip_tags = ["term"]
|
docling/cli/main.py
CHANGED
|
@@ -66,6 +66,7 @@ from docling.datamodel.vlm_model_specs import (
|
|
|
66
66
|
GRANITE_VISION_TRANSFORMERS,
|
|
67
67
|
GRANITEDOCLING_MLX,
|
|
68
68
|
GRANITEDOCLING_TRANSFORMERS,
|
|
69
|
+
GRANITEDOCLING_VLLM,
|
|
69
70
|
SMOLDOCLING_MLX,
|
|
70
71
|
SMOLDOCLING_TRANSFORMERS,
|
|
71
72
|
SMOLDOCLING_VLLM,
|
|
@@ -686,6 +687,7 @@ def convert( # noqa: C901
|
|
|
686
687
|
"To run SmolDocling faster, please install mlx-vlm:\n"
|
|
687
688
|
"pip install mlx-vlm"
|
|
688
689
|
)
|
|
690
|
+
|
|
689
691
|
elif vlm_model == VlmModelType.GRANITEDOCLING:
|
|
690
692
|
pipeline_options.vlm_options = GRANITEDOCLING_TRANSFORMERS
|
|
691
693
|
if sys.platform == "darwin":
|
|
@@ -701,6 +703,9 @@ def convert( # noqa: C901
|
|
|
701
703
|
elif vlm_model == VlmModelType.SMOLDOCLING_VLLM:
|
|
702
704
|
pipeline_options.vlm_options = SMOLDOCLING_VLLM
|
|
703
705
|
|
|
706
|
+
elif vlm_model == VlmModelType.GRANITEDOCLING_VLLM:
|
|
707
|
+
pipeline_options.vlm_options = GRANITEDOCLING_VLLM
|
|
708
|
+
|
|
704
709
|
pdf_format_option = PdfFormatOption(
|
|
705
710
|
pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
|
|
706
711
|
)
|
|
@@ -1,11 +1,13 @@
|
|
|
1
1
|
from enum import Enum
|
|
2
|
-
from typing import Any, Dict, List, Literal, Optional
|
|
2
|
+
from typing import Any, Dict, List, Literal, Optional, Union
|
|
3
3
|
|
|
4
4
|
from docling_core.types.doc.page import SegmentedPage
|
|
5
|
-
from pydantic import AnyUrl, BaseModel
|
|
5
|
+
from pydantic import AnyUrl, BaseModel, ConfigDict
|
|
6
|
+
from transformers import StoppingCriteria
|
|
6
7
|
from typing_extensions import deprecated
|
|
7
8
|
|
|
8
9
|
from docling.datamodel.accelerator_options import AcceleratorDevice
|
|
10
|
+
from docling.models.utils.generation_utils import GenerationStopper
|
|
9
11
|
|
|
10
12
|
|
|
11
13
|
class BaseVlmOptions(BaseModel):
|
|
@@ -50,9 +52,12 @@ class TransformersPromptStyle(str, Enum):
|
|
|
50
52
|
|
|
51
53
|
|
|
52
54
|
class InlineVlmOptions(BaseVlmOptions):
|
|
55
|
+
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
56
|
+
|
|
53
57
|
kind: Literal["inline_model_options"] = "inline_model_options"
|
|
54
58
|
|
|
55
59
|
repo_id: str
|
|
60
|
+
revision: str = "main"
|
|
56
61
|
trust_remote_code: bool = False
|
|
57
62
|
load_in_8bit: bool = True
|
|
58
63
|
llm_int8_threshold: float = 6.0
|
|
@@ -71,6 +76,7 @@ class InlineVlmOptions(BaseVlmOptions):
|
|
|
71
76
|
]
|
|
72
77
|
|
|
73
78
|
stop_strings: List[str] = []
|
|
79
|
+
custom_stopping_criteria: List[Union[StoppingCriteria, GenerationStopper]] = []
|
|
74
80
|
extra_generation_config: Dict[str, Any] = {}
|
|
75
81
|
extra_processor_kwargs: Dict[str, Any] = {}
|
|
76
82
|
|
|
@@ -88,6 +94,8 @@ class HuggingFaceVlmOptions(InlineVlmOptions):
|
|
|
88
94
|
|
|
89
95
|
|
|
90
96
|
class ApiVlmOptions(BaseVlmOptions):
|
|
97
|
+
model_config = ConfigDict(arbitrary_types_allowed=True)
|
|
98
|
+
|
|
91
99
|
kind: Literal["api_model_options"] = "api_model_options"
|
|
92
100
|
|
|
93
101
|
url: AnyUrl = AnyUrl(
|
|
@@ -98,3 +106,6 @@ class ApiVlmOptions(BaseVlmOptions):
|
|
|
98
106
|
timeout: float = 60
|
|
99
107
|
concurrency: int = 1
|
|
100
108
|
response_format: ResponseFormat
|
|
109
|
+
|
|
110
|
+
stop_strings: List[str] = []
|
|
111
|
+
custom_stopping_criteria: List[Union[GenerationStopper]] = []
|
|
@@ -29,12 +29,20 @@ GRANITEDOCLING_TRANSFORMERS = InlineVlmOptions(
|
|
|
29
29
|
AcceleratorDevice.CPU,
|
|
30
30
|
AcceleratorDevice.CUDA,
|
|
31
31
|
],
|
|
32
|
+
extra_generation_config=dict(skip_special_tokens=False),
|
|
32
33
|
scale=2.0,
|
|
33
34
|
temperature=0.0,
|
|
34
35
|
max_new_tokens=8192,
|
|
35
36
|
stop_strings=["</doctag>", "<|end_of_text|>"],
|
|
36
37
|
)
|
|
37
38
|
|
|
39
|
+
GRANITEDOCLING_VLLM = GRANITEDOCLING_TRANSFORMERS.model_copy()
|
|
40
|
+
GRANITEDOCLING_VLLM.inference_framework = InferenceFramework.VLLM
|
|
41
|
+
GRANITEDOCLING_VLLM.revision = (
|
|
42
|
+
"untied" # change back to "main" with next vllm relase after 0.10.2
|
|
43
|
+
)
|
|
44
|
+
|
|
45
|
+
|
|
38
46
|
GRANITEDOCLING_MLX = InlineVlmOptions(
|
|
39
47
|
repo_id="ibm-granite/granite-docling-258M-mlx",
|
|
40
48
|
prompt="Convert this page to docling.",
|
|
@@ -302,3 +310,4 @@ class VlmModelType(str, Enum):
|
|
|
302
310
|
GRANITE_VISION_OLLAMA = "granite_vision_ollama"
|
|
303
311
|
GOT_OCR_2 = "got_ocr_2"
|
|
304
312
|
GRANITEDOCLING = "granite_docling"
|
|
313
|
+
GRANITEDOCLING_VLLM = "granite_docling_vllm"
|
docling/models/api_vlm_model.py
CHANGED
|
@@ -1,12 +1,18 @@
|
|
|
1
1
|
from collections.abc import Iterable
|
|
2
2
|
from concurrent.futures import ThreadPoolExecutor
|
|
3
3
|
|
|
4
|
+
from transformers import StoppingCriteria
|
|
5
|
+
|
|
4
6
|
from docling.datamodel.base_models import Page, VlmPrediction
|
|
5
7
|
from docling.datamodel.document import ConversionResult
|
|
6
8
|
from docling.datamodel.pipeline_options_vlm_model import ApiVlmOptions
|
|
7
9
|
from docling.exceptions import OperationNotAllowed
|
|
8
10
|
from docling.models.base_model import BasePageModel
|
|
9
|
-
from docling.utils.
|
|
11
|
+
from docling.models.utils.generation_utils import GenerationStopper
|
|
12
|
+
from docling.utils.api_image_request import (
|
|
13
|
+
api_image_request,
|
|
14
|
+
api_image_request_streaming,
|
|
15
|
+
)
|
|
10
16
|
from docling.utils.profiling import TimeRecorder
|
|
11
17
|
|
|
12
18
|
|
|
@@ -41,19 +47,43 @@ class ApiVlmModel(BasePageModel):
|
|
|
41
47
|
assert page._backend is not None
|
|
42
48
|
if not page._backend.is_valid():
|
|
43
49
|
return page
|
|
44
|
-
else:
|
|
45
|
-
with TimeRecorder(conv_res, "vlm"):
|
|
46
|
-
assert page.size is not None
|
|
47
50
|
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
51
|
+
with TimeRecorder(conv_res, "vlm"):
|
|
52
|
+
assert page.size is not None
|
|
53
|
+
|
|
54
|
+
hi_res_image = page.get_image(
|
|
55
|
+
scale=self.vlm_options.scale, max_size=self.vlm_options.max_size
|
|
56
|
+
)
|
|
57
|
+
assert hi_res_image is not None
|
|
58
|
+
if hi_res_image and hi_res_image.mode != "RGB":
|
|
59
|
+
hi_res_image = hi_res_image.convert("RGB")
|
|
55
60
|
|
|
56
|
-
|
|
61
|
+
prompt = self.vlm_options.build_prompt(page.parsed_page)
|
|
62
|
+
|
|
63
|
+
if self.vlm_options.custom_stopping_criteria:
|
|
64
|
+
# Instantiate any GenerationStopper classes before passing to streaming
|
|
65
|
+
instantiated_stoppers = []
|
|
66
|
+
for criteria in self.vlm_options.custom_stopping_criteria:
|
|
67
|
+
if isinstance(criteria, GenerationStopper):
|
|
68
|
+
instantiated_stoppers.append(criteria)
|
|
69
|
+
elif isinstance(criteria, type) and issubclass(
|
|
70
|
+
criteria, GenerationStopper
|
|
71
|
+
):
|
|
72
|
+
instantiated_stoppers.append(criteria())
|
|
73
|
+
# Skip non-GenerationStopper criteria (should have been caught in validation)
|
|
74
|
+
|
|
75
|
+
# Streaming path with early abort support
|
|
76
|
+
page_tags = api_image_request_streaming(
|
|
77
|
+
image=hi_res_image,
|
|
78
|
+
prompt=prompt,
|
|
79
|
+
url=self.vlm_options.url,
|
|
80
|
+
timeout=self.timeout,
|
|
81
|
+
headers=self.vlm_options.headers,
|
|
82
|
+
generation_stoppers=instantiated_stoppers,
|
|
83
|
+
**self.params,
|
|
84
|
+
)
|
|
85
|
+
else:
|
|
86
|
+
# Non-streaming fallback (existing behavior)
|
|
57
87
|
page_tags = api_image_request(
|
|
58
88
|
image=hi_res_image,
|
|
59
89
|
prompt=prompt,
|
|
@@ -63,10 +93,9 @@ class ApiVlmModel(BasePageModel):
|
|
|
63
93
|
**self.params,
|
|
64
94
|
)
|
|
65
95
|
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
return page
|
|
96
|
+
page_tags = self.vlm_options.decode_response(page_tags)
|
|
97
|
+
page.predictions.vlm_response = VlmPrediction(text=page_tags)
|
|
98
|
+
return page
|
|
70
99
|
|
|
71
100
|
with ThreadPoolExecutor(max_workers=self.concurrency) as executor:
|
|
72
101
|
yield from executor.map(_vlm_request, page_batch)
|
docling/models/base_model.py
CHANGED
|
@@ -88,7 +88,8 @@ class BaseVlmPageModel(BasePageModel, BaseVlmModel):
|
|
|
88
88
|
|
|
89
89
|
if self.vlm_options.transformers_prompt_style == TransformersPromptStyle.RAW:
|
|
90
90
|
return user_prompt
|
|
91
|
-
|
|
91
|
+
elif self.vlm_options.transformers_prompt_style == TransformersPromptStyle.NONE:
|
|
92
|
+
return ""
|
|
92
93
|
elif self.vlm_options.repo_id == "microsoft/Phi-4-multimodal-instruct":
|
|
93
94
|
_log.debug("Using specialized prompt for Phi-4")
|
|
94
95
|
# Note: This might need adjustment for VLLM vs transformers
|
|
@@ -103,7 +103,7 @@ class ReadingOrderModel:
|
|
|
103
103
|
else:
|
|
104
104
|
doc.add_text(parent=doc_item, label=c_label, text=c_text, prov=c_prov)
|
|
105
105
|
|
|
106
|
-
def _readingorder_elements_to_docling_doc(
|
|
106
|
+
def _readingorder_elements_to_docling_doc(
|
|
107
107
|
self,
|
|
108
108
|
conv_res: ConversionResult,
|
|
109
109
|
ro_elements: List[ReadingOrderPageElement],
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import re
|
|
3
|
+
import sys
|
|
4
|
+
from abc import abstractmethod
|
|
5
|
+
from typing import List
|
|
6
|
+
|
|
7
|
+
from transformers import StoppingCriteria
|
|
8
|
+
|
|
9
|
+
_log = logging.getLogger(__name__)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class GenerationStopper:
|
|
13
|
+
"""
|
|
14
|
+
Base interface for stopping logic.
|
|
15
|
+
- should_stop(s): True to stop given the current decoded text window.
|
|
16
|
+
- lookback_tokens(): how many tokens should be considered (default: sys.maxsize).
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
@abstractmethod
|
|
20
|
+
def should_stop(self, s: str) -> bool:
|
|
21
|
+
pass
|
|
22
|
+
|
|
23
|
+
def lookback_tokens(self) -> int:
|
|
24
|
+
return sys.maxsize
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class DocTagsRepetitionStopper(GenerationStopper):
|
|
28
|
+
"""
|
|
29
|
+
Detects repetitive <tag>...<loc_x><loc_y><loc_w><loc_h>text</tag> blocks,
|
|
30
|
+
but only when repeats are **consecutive** and both tag & inner text are identical.
|
|
31
|
+
|
|
32
|
+
Performance:
|
|
33
|
+
- Heavy check runs every N calls (default 32).
|
|
34
|
+
- Only decodes the last LOOKBACK_TOKENS tokens per sequence (default 200).
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
def __init__(self, *, N: int = 32, lookback_tokens: int = 200):
|
|
38
|
+
self.N = max(1, int(N))
|
|
39
|
+
self._lookback_tokens = max(1, int(lookback_tokens))
|
|
40
|
+
self._call_count = 0
|
|
41
|
+
|
|
42
|
+
# <tag> ... <loc_x><loc_y><loc_w><loc_h> text ... </tag>
|
|
43
|
+
self._PATTERN = re.compile(
|
|
44
|
+
r"""
|
|
45
|
+
<(?P<tag>[a-zA-Z0-9_]+)>\s*
|
|
46
|
+
(?P<prefix>.*?)?
|
|
47
|
+
<loc_(?P<x>\d+)><loc_(?P<y>\d+)><loc_(?P<w>\d+)><loc_(?P<h>\d+)>
|
|
48
|
+
(?P<text>.*?)
|
|
49
|
+
</(?P=tag)>
|
|
50
|
+
""",
|
|
51
|
+
re.DOTALL | re.VERBOSE,
|
|
52
|
+
)
|
|
53
|
+
|
|
54
|
+
# --- small helper ---
|
|
55
|
+
def _regular(self, vals: List[int]) -> bool:
|
|
56
|
+
"""3+ strictly increasing values with ~regular spacing (±20%)."""
|
|
57
|
+
if len(vals) < 3:
|
|
58
|
+
return False
|
|
59
|
+
diffs = [b - a for a, b in zip(vals, vals[1:])]
|
|
60
|
+
if any(d <= 0 for d in diffs):
|
|
61
|
+
return False
|
|
62
|
+
mean = sum(diffs) / len(diffs)
|
|
63
|
+
tol = 0.2 * mean
|
|
64
|
+
return all(abs(d - mean) <= tol for d in diffs)
|
|
65
|
+
|
|
66
|
+
def should_stop(self, s: str) -> bool:
|
|
67
|
+
"""
|
|
68
|
+
Trip only on **consecutive** runs (no other matched blocks between) of ≥3 items
|
|
69
|
+
with the same <tag> and identical inner text, where within that run we see:
|
|
70
|
+
- any exact duplicate (x,y,w,h), or
|
|
71
|
+
- stable X/W with regular Y progression, or
|
|
72
|
+
- stable Y/H with regular X progression.
|
|
73
|
+
"""
|
|
74
|
+
# Stream matches and evaluate runs on-the-fly to stay compact and fast.
|
|
75
|
+
prev_tag = prev_text = None
|
|
76
|
+
run = [] # list of (x,y,w,h)
|
|
77
|
+
|
|
78
|
+
def run_repetitive(boxes: List[tuple]) -> bool:
|
|
79
|
+
if len(boxes) < 3:
|
|
80
|
+
return False
|
|
81
|
+
# duplicates?
|
|
82
|
+
if len(set(boxes)) < len(boxes):
|
|
83
|
+
return True
|
|
84
|
+
xs, ys, ws, hs = zip(*boxes)
|
|
85
|
+
x_stable = all(x == xs[0] for x in xs)
|
|
86
|
+
y_stable = all(y == ys[0] for y in ys)
|
|
87
|
+
w_stable = all(w == ws[0] for w in ws)
|
|
88
|
+
h_stable = all(h == hs[0] for h in hs)
|
|
89
|
+
# horizontal (down the page): X/W stable, Y regular
|
|
90
|
+
if (x_stable or w_stable) and self._regular(list(ys)):
|
|
91
|
+
return True
|
|
92
|
+
# vertical (across): Y/H stable, X regular
|
|
93
|
+
if (y_stable or h_stable) and self._regular(list(xs)):
|
|
94
|
+
return True
|
|
95
|
+
return False
|
|
96
|
+
|
|
97
|
+
for m in self._PATTERN.finditer(s):
|
|
98
|
+
tag, text = m.group("tag"), m.group("text")
|
|
99
|
+
box = (
|
|
100
|
+
int(m.group("x")),
|
|
101
|
+
int(m.group("y")),
|
|
102
|
+
int(m.group("w")),
|
|
103
|
+
int(m.group("h")),
|
|
104
|
+
)
|
|
105
|
+
|
|
106
|
+
if prev_tag == tag and prev_text == text:
|
|
107
|
+
run.append(box) # consecutive same-tag+text
|
|
108
|
+
else:
|
|
109
|
+
# evaluate previous run before starting a new one
|
|
110
|
+
if run_repetitive(run):
|
|
111
|
+
return True
|
|
112
|
+
prev_tag, prev_text = tag, text
|
|
113
|
+
run = [box]
|
|
114
|
+
|
|
115
|
+
# check the last run
|
|
116
|
+
return run_repetitive(run)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class HFStoppingCriteriaWrapper(StoppingCriteria):
|
|
120
|
+
"""
|
|
121
|
+
Adapts any GenerationStopper to HuggingFace Transformers.
|
|
122
|
+
Decodes exactly min(seq_len, stopper.lookback_tokens()) tokens from the end.
|
|
123
|
+
"""
|
|
124
|
+
|
|
125
|
+
def __init__(
|
|
126
|
+
self,
|
|
127
|
+
tokenizer,
|
|
128
|
+
stopper: GenerationStopper,
|
|
129
|
+
*,
|
|
130
|
+
skip_special_tokens: bool = False,
|
|
131
|
+
):
|
|
132
|
+
self.tokenizer = tokenizer
|
|
133
|
+
self.stopper = stopper
|
|
134
|
+
self.skip_special_tokens = skip_special_tokens
|
|
135
|
+
|
|
136
|
+
def __call__(self, input_ids, scores, **kwargs) -> bool:
|
|
137
|
+
lb = max(1, int(self.stopper.lookback_tokens()))
|
|
138
|
+
for seq in input_ids: # (batch, seq_len)
|
|
139
|
+
window = seq[-lb:] # slicing handles lb > len(seq)
|
|
140
|
+
try:
|
|
141
|
+
text = self.tokenizer.decode(
|
|
142
|
+
window, skip_special_tokens=self.skip_special_tokens
|
|
143
|
+
)
|
|
144
|
+
except Exception as e:
|
|
145
|
+
_log.info(f"Decoding failed for stopping check: {e}")
|
|
146
|
+
continue
|
|
147
|
+
|
|
148
|
+
try:
|
|
149
|
+
if self.stopper.should_stop(text):
|
|
150
|
+
_log.info(
|
|
151
|
+
"HF wrapper: stopping due to TextStopper.should_stop==True"
|
|
152
|
+
)
|
|
153
|
+
return True
|
|
154
|
+
except Exception as e:
|
|
155
|
+
_log.info(f"Error in TextStopper.should_stop: {e}")
|
|
156
|
+
continue
|
|
157
|
+
return False
|
|
@@ -34,7 +34,12 @@ class HuggingFaceModelDownloadMixin:
|
|
|
34
34
|
local_dir: Optional[Path] = None,
|
|
35
35
|
force: bool = False,
|
|
36
36
|
progress: bool = False,
|
|
37
|
+
revision: Optional[str] = None,
|
|
37
38
|
) -> Path:
|
|
38
39
|
return download_hf_model(
|
|
39
|
-
repo_id=repo_id,
|
|
40
|
+
repo_id=repo_id,
|
|
41
|
+
local_dir=local_dir,
|
|
42
|
+
force=force,
|
|
43
|
+
progress=progress,
|
|
44
|
+
revision=revision,
|
|
40
45
|
)
|