doc-page-extractor 0.0.10__tar.gz → 0.1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of doc-page-extractor might be problematic. Click here for more details.
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/PKG-INFO +6 -2
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/README.md +3 -1
- doc_page_extractor-0.1.1/doc_page_extractor/__init__.py +15 -0
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/doc_page_extractor/extractor.py +54 -10
- doc_page_extractor-0.1.1/doc_page_extractor/latex.py +57 -0
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/doc_page_extractor/ocr.py +1 -1
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/doc_page_extractor/rectangle.py +6 -0
- doc_page_extractor-0.1.1/doc_page_extractor/struct_eqtable/__init__.py +49 -0
- doc_page_extractor-0.1.1/doc_page_extractor/struct_eqtable/internvl/__init__.py +2 -0
- doc_page_extractor-0.1.1/doc_page_extractor/struct_eqtable/internvl/conversation.py +394 -0
- doc_page_extractor-0.1.1/doc_page_extractor/struct_eqtable/internvl/internvl.py +198 -0
- doc_page_extractor-0.1.1/doc_page_extractor/struct_eqtable/internvl/internvl_lmdeploy.py +81 -0
- doc_page_extractor-0.1.1/doc_page_extractor/struct_eqtable/pix2s/__init__.py +3 -0
- doc_page_extractor-0.1.1/doc_page_extractor/struct_eqtable/pix2s/pix2s.py +76 -0
- doc_page_extractor-0.1.1/doc_page_extractor/struct_eqtable/pix2s/pix2s_trt.py +1047 -0
- doc_page_extractor-0.1.1/doc_page_extractor/table.py +71 -0
- doc_page_extractor-0.1.1/doc_page_extractor/types.py +67 -0
- doc_page_extractor-0.1.1/doc_page_extractor/utils.py +32 -0
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/doc_page_extractor.egg-info/PKG-INFO +6 -2
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/doc_page_extractor.egg-info/SOURCES.txt +10 -0
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/doc_page_extractor.egg-info/requires.txt +3 -1
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/setup.py +7 -2
- doc_page_extractor-0.0.10/doc_page_extractor/__init__.py +0 -5
- doc_page_extractor-0.0.10/doc_page_extractor/types.py +0 -36
- doc_page_extractor-0.0.10/doc_page_extractor/utils.py +0 -10
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/LICENSE +0 -0
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/doc_page_extractor/clipper.py +0 -0
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/doc_page_extractor/downloader.py +0 -0
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/doc_page_extractor/layout_order.py +0 -0
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/doc_page_extractor/layoutreader.py +0 -0
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/doc_page_extractor/ocr_corrector.py +0 -0
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/doc_page_extractor/onnxocr/__init__.py +0 -0
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/doc_page_extractor/onnxocr/cls_postprocess.py +0 -0
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/doc_page_extractor/onnxocr/db_postprocess.py +0 -0
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/doc_page_extractor/onnxocr/imaug.py +0 -0
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/doc_page_extractor/onnxocr/operators.py +0 -0
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/doc_page_extractor/onnxocr/predict_base.py +0 -0
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/doc_page_extractor/onnxocr/predict_cls.py +0 -0
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/doc_page_extractor/onnxocr/predict_det.py +0 -0
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/doc_page_extractor/onnxocr/predict_rec.py +0 -0
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/doc_page_extractor/onnxocr/predict_system.py +0 -0
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/doc_page_extractor/onnxocr/rec_postprocess.py +0 -0
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/doc_page_extractor/onnxocr/utils.py +0 -0
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/doc_page_extractor/overlap.py +0 -0
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/doc_page_extractor/plot.py +0 -0
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/doc_page_extractor/raw_optimizer.py +0 -0
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/doc_page_extractor/rotation.py +0 -0
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/doc_page_extractor.egg-info/dependency_links.txt +0 -0
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/doc_page_extractor.egg-info/top_level.txt +0 -0
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/setup.cfg +0 -0
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/tests/__init__.py +0 -0
- {doc_page_extractor-0.0.10 → doc_page_extractor-0.1.1}/tests/test_history_bus.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: doc-page-extractor
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.1.1
|
|
4
4
|
Summary: doc page extractor can identify text and format in images and return structured data.
|
|
5
5
|
Home-page: https://github.com/Moskize91/doc-page-extractor
|
|
6
6
|
Author: Tao Zeyu
|
|
@@ -12,8 +12,10 @@ Requires-Dist: pillow<11.0,>=10.3
|
|
|
12
12
|
Requires-Dist: pyclipper<2.0,>=1.2.0
|
|
13
13
|
Requires-Dist: numpy<2.0,>=1.24.0
|
|
14
14
|
Requires-Dist: shapely<3.0,>=2.0.0
|
|
15
|
-
Requires-Dist: transformers
|
|
15
|
+
Requires-Dist: transformers<=4.47,>=4.42.4
|
|
16
16
|
Requires-Dist: doclayout_yolo>=0.0.3
|
|
17
|
+
Requires-Dist: pix2tex<=0.2.0,>=0.1.4
|
|
18
|
+
Requires-Dist: accelerate<2.0,>=1.6.0
|
|
17
19
|
Dynamic: author
|
|
18
20
|
Dynamic: author-email
|
|
19
21
|
Dynamic: description
|
|
@@ -78,3 +80,5 @@ The code of `doc_page_extractor/onnxocr` in this repo comes from [OnnxOCR](https
|
|
|
78
80
|
- [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO)
|
|
79
81
|
- [OnnxOCR](https://github.com/jingsongliujing/OnnxOCR)
|
|
80
82
|
- [layoutreader](https://github.com/ppaanngggg/layoutreader)
|
|
83
|
+
- [StructEqTable](https://github.com/Alpha-Innovator/StructEqTable-Deploy)
|
|
84
|
+
- [LaTeX-OCR](https://github.com/lukas-blecher/LaTeX-OCR)
|
|
@@ -52,4 +52,6 @@ The code of `doc_page_extractor/onnxocr` in this repo comes from [OnnxOCR](https
|
|
|
52
52
|
|
|
53
53
|
- [DocLayout-YOLO](https://github.com/opendatalab/DocLayout-YOLO)
|
|
54
54
|
- [OnnxOCR](https://github.com/jingsongliujing/OnnxOCR)
|
|
55
|
-
- [layoutreader](https://github.com/ppaanngggg/layoutreader)
|
|
55
|
+
- [layoutreader](https://github.com/ppaanngggg/layoutreader)
|
|
56
|
+
- [StructEqTable](https://github.com/Alpha-Innovator/StructEqTable-Deploy)
|
|
57
|
+
- [LaTeX-OCR](https://github.com/lukas-blecher/LaTeX-OCR)
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from .extractor import DocExtractor
|
|
2
|
+
from .clipper import clip, clip_from_image
|
|
3
|
+
from .plot import plot
|
|
4
|
+
from .rectangle import Point, Rectangle
|
|
5
|
+
from .types import (
|
|
6
|
+
ExtractedResult,
|
|
7
|
+
OCRFragment,
|
|
8
|
+
LayoutClass,
|
|
9
|
+
TableLayoutParsedFormat,
|
|
10
|
+
Layout,
|
|
11
|
+
BaseLayout,
|
|
12
|
+
PlainLayout,
|
|
13
|
+
FormulaLayout,
|
|
14
|
+
TableLayout,
|
|
15
|
+
)
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
import os
|
|
2
2
|
|
|
3
|
-
from typing import Literal
|
|
3
|
+
from typing import Literal, Generator
|
|
4
4
|
from pathlib import Path
|
|
5
5
|
from PIL.Image import Image
|
|
6
6
|
from doclayout_yolo import YOLOv10
|
|
@@ -9,10 +9,22 @@ from .ocr import OCR
|
|
|
9
9
|
from .ocr_corrector import correct_fragments
|
|
10
10
|
from .raw_optimizer import RawOptimizer
|
|
11
11
|
from .rectangle import intersection_area, Rectangle
|
|
12
|
-
from .types import ExtractedResult, OCRFragment, LayoutClass, Layout
|
|
13
12
|
from .downloader import download
|
|
13
|
+
from .table import Table
|
|
14
|
+
from .latex import LaTeX
|
|
14
15
|
from .layout_order import LayoutOrder
|
|
15
16
|
from .overlap import merge_fragments_as_line, remove_overlap_layouts
|
|
17
|
+
from .clipper import clip_from_image
|
|
18
|
+
from .types import (
|
|
19
|
+
ExtractedResult,
|
|
20
|
+
OCRFragment,
|
|
21
|
+
TableLayoutParsedFormat,
|
|
22
|
+
Layout,
|
|
23
|
+
LayoutClass,
|
|
24
|
+
PlainLayout,
|
|
25
|
+
TableLayout,
|
|
26
|
+
FormulaLayout,
|
|
27
|
+
)
|
|
16
28
|
|
|
17
29
|
|
|
18
30
|
class DocExtractor:
|
|
@@ -21,12 +33,26 @@ class DocExtractor:
|
|
|
21
33
|
model_dir_path: str,
|
|
22
34
|
device: Literal["cpu", "cuda"] = "cpu",
|
|
23
35
|
ocr_for_each_layouts: bool = True,
|
|
36
|
+
extract_formula: bool = True,
|
|
37
|
+
extract_table_format: TableLayoutParsedFormat | None = None,
|
|
24
38
|
):
|
|
25
39
|
self._model_dir_path: str = model_dir_path
|
|
26
40
|
self._device: Literal["cpu", "cuda"] = device
|
|
27
41
|
self._ocr_for_each_layouts: bool = ocr_for_each_layouts
|
|
28
|
-
self.
|
|
42
|
+
self._extract_formula: bool = extract_formula
|
|
43
|
+
self._extract_table_format: TableLayoutParsedFormat | None = extract_table_format
|
|
29
44
|
self._yolo: YOLOv10 | None = None
|
|
45
|
+
self._ocr: OCR = OCR(
|
|
46
|
+
device=device,
|
|
47
|
+
model_dir_path=os.path.join(model_dir_path, "onnx_ocr"),
|
|
48
|
+
)
|
|
49
|
+
self._table: Table = Table(
|
|
50
|
+
device=device,
|
|
51
|
+
model_path=os.path.join(model_dir_path, "struct_eqtable"),
|
|
52
|
+
)
|
|
53
|
+
self._latex: LaTeX = LaTeX(
|
|
54
|
+
model_path=os.path.join(model_dir_path, "latex"),
|
|
55
|
+
)
|
|
30
56
|
self._layout_order: LayoutOrder = LayoutOrder(
|
|
31
57
|
model_path=os.path.join(model_dir_path, "layoutreader"),
|
|
32
58
|
)
|
|
@@ -40,7 +66,7 @@ class DocExtractor:
|
|
|
40
66
|
raw_optimizer = RawOptimizer(image, adjust_points)
|
|
41
67
|
fragments = list(self._ocr.search_fragments(raw_optimizer.image_np))
|
|
42
68
|
raw_optimizer.receive_raw_fragments(fragments)
|
|
43
|
-
layouts = self.
|
|
69
|
+
layouts = list(self._yolo_extract_layouts(raw_optimizer.image))
|
|
44
70
|
layouts = self._layouts_matched_by_fragments(fragments, layouts)
|
|
45
71
|
layouts = remove_overlap_layouts(layouts)
|
|
46
72
|
|
|
@@ -50,6 +76,8 @@ class DocExtractor:
|
|
|
50
76
|
layouts = self._layout_order.sort(layouts, raw_optimizer.image.size)
|
|
51
77
|
layouts = [layout for layout in layouts if self._should_keep_layout(layout)]
|
|
52
78
|
|
|
79
|
+
self._parse_table_and_formula_layouts(layouts, raw_optimizer)
|
|
80
|
+
|
|
53
81
|
for layout in layouts:
|
|
54
82
|
layout.fragments = merge_fragments_as_line(layout.fragments)
|
|
55
83
|
|
|
@@ -62,7 +90,7 @@ class DocExtractor:
|
|
|
62
90
|
adjusted_image=raw_optimizer.adjusted_image,
|
|
63
91
|
)
|
|
64
92
|
|
|
65
|
-
def
|
|
93
|
+
def _yolo_extract_layouts(self, source: Image) -> Generator[Layout, None, None]:
|
|
66
94
|
# about source parameter to see:
|
|
67
95
|
# https://github.com/opendatalab/DocLayout-YOLO/blob/7c4be36bc61f11b67cf4a44ee47f3c41e9800a91/doclayout_yolo/data/build.py#L157-L175
|
|
68
96
|
det_res = self._get_yolo().predict(
|
|
@@ -72,7 +100,6 @@ class DocExtractor:
|
|
|
72
100
|
device=self._device # Device to use (e.g., "cuda" or "cpu")
|
|
73
101
|
)
|
|
74
102
|
boxes = det_res[0].__dict__["boxes"]
|
|
75
|
-
layouts: list[Layout] = []
|
|
76
103
|
|
|
77
104
|
for cls_id, rect in zip(boxes.cls, boxes.xyxy):
|
|
78
105
|
cls_id = cls_id.item()
|
|
@@ -89,9 +116,13 @@ class DocExtractor:
|
|
|
89
116
|
lb=(x1, y2),
|
|
90
117
|
rb=(x2, y2),
|
|
91
118
|
)
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
119
|
+
if rect.is_valid:
|
|
120
|
+
if cls == LayoutClass.TABLE:
|
|
121
|
+
yield TableLayout(cls=cls, rect=rect, fragments=[], parsed=None)
|
|
122
|
+
elif cls == LayoutClass.ISOLATE_FORMULA:
|
|
123
|
+
yield FormulaLayout(cls=cls, rect=rect, fragments=[], latex=None)
|
|
124
|
+
else:
|
|
125
|
+
yield PlainLayout(cls=cls, rect=rect, fragments=[])
|
|
95
126
|
|
|
96
127
|
def _layouts_matched_by_fragments(self, fragments: list[OCRFragment], layouts: list[Layout]):
|
|
97
128
|
layouts_group = self._split_layouts_by_group(layouts)
|
|
@@ -107,6 +138,17 @@ class DocExtractor:
|
|
|
107
138
|
for layout in layouts:
|
|
108
139
|
correct_fragments(self._ocr, source, layout)
|
|
109
140
|
|
|
141
|
+
def _parse_table_and_formula_layouts(self, layouts: list[Layout], raw_optimizer: RawOptimizer):
|
|
142
|
+
for layout in layouts:
|
|
143
|
+
if isinstance(layout, FormulaLayout) and self._extract_formula:
|
|
144
|
+
image = clip_from_image(raw_optimizer.image, layout.rect)
|
|
145
|
+
layout.latex = self._latex.extract(image)
|
|
146
|
+
elif isinstance(layout, TableLayout) and self._extract_table_format is not None:
|
|
147
|
+
image = clip_from_image(raw_optimizer.image, layout.rect)
|
|
148
|
+
parsed = self._table.predict(image, self._extract_table_format)
|
|
149
|
+
if parsed is not None:
|
|
150
|
+
layout.parsed = (parsed, self._extract_table_format)
|
|
151
|
+
|
|
110
152
|
def _split_layouts_by_group(self, layouts: list[Layout]):
|
|
111
153
|
texts_layouts: list[Layout] = []
|
|
112
154
|
abandon_layouts: list[Layout] = []
|
|
@@ -149,9 +191,11 @@ class DocExtractor:
|
|
|
149
191
|
|
|
150
192
|
def _get_yolo(self) -> YOLOv10:
|
|
151
193
|
if self._yolo is None:
|
|
194
|
+
base_path = os.path.join(self._model_dir_path, "yolo")
|
|
195
|
+
os.makedirs(base_path, exist_ok=True)
|
|
152
196
|
yolo_model_url = "https://huggingface.co/opendatalab/PDF-Extract-Kit-1.0/resolve/main/models/Layout/YOLO/doclayout_yolo_ft.pt"
|
|
153
197
|
yolo_model_name = "doclayout_yolo_ft.pt"
|
|
154
|
-
yolo_model_path = Path(os.path.join(
|
|
198
|
+
yolo_model_path = Path(os.path.join(base_path, yolo_model_name))
|
|
155
199
|
if not yolo_model_path.exists():
|
|
156
200
|
download(yolo_model_url, yolo_model_path)
|
|
157
201
|
self._yolo = YOLOv10(str(yolo_model_path))
|
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import torch
|
|
3
|
+
import requests
|
|
4
|
+
|
|
5
|
+
from munch import Munch
|
|
6
|
+
from pix2tex.cli import LatexOCR
|
|
7
|
+
from PIL.Image import Image
|
|
8
|
+
from .utils import expand_image
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class LaTeX:
|
|
12
|
+
def __init__(self, model_path: str):
|
|
13
|
+
self._model_path: str = model_path
|
|
14
|
+
self._model: LatexOCR | None = None
|
|
15
|
+
|
|
16
|
+
def extract(self, image: Image) -> str | None:
|
|
17
|
+
image = expand_image(image, 0.1) # 添加边缘提高识别准确率
|
|
18
|
+
model = self._get_model()
|
|
19
|
+
with torch.no_grad():
|
|
20
|
+
return model(image)
|
|
21
|
+
|
|
22
|
+
def _get_model(self) -> LatexOCR:
|
|
23
|
+
if self._model is None:
|
|
24
|
+
if not os.path.exists(self._model_path):
|
|
25
|
+
self._download_model()
|
|
26
|
+
|
|
27
|
+
self._model = LatexOCR(Munch({
|
|
28
|
+
"config": os.path.join("settings", "config.yaml"),
|
|
29
|
+
"checkpoint": os.path.join(self._model_path, "weights.pth"),
|
|
30
|
+
"no_cuda": True,
|
|
31
|
+
"no_resize": False,
|
|
32
|
+
}))
|
|
33
|
+
return self._model
|
|
34
|
+
|
|
35
|
+
# from https://github.com/lukas-blecher/LaTeX-OCR/blob/5c1ac929bd19a7ecf86d5fb8d94771c8969fcb80/pix2tex/model/checkpoints/get_latest_checkpoint.py#L37-L45
|
|
36
|
+
def _download_model(self):
|
|
37
|
+
os.makedirs(self._model_path, exist_ok=True)
|
|
38
|
+
tag = "v0.0.1"
|
|
39
|
+
files: list[tuple[str, str]] = (
|
|
40
|
+
("weights.pth", f"https://github.com/lukas-blecher/LaTeX-OCR/releases/download/{tag}/weights.pth"),
|
|
41
|
+
("image_resizer.pth", f"https://github.com/lukas-blecher/LaTeX-OCR/releases/download/{tag}/image_resizer.pth")
|
|
42
|
+
)
|
|
43
|
+
for file_name, url in files:
|
|
44
|
+
file_path = os.path.join(self._model_path, file_name)
|
|
45
|
+
try:
|
|
46
|
+
with open(file_path, "wb") as file:
|
|
47
|
+
response = requests.get(url, stream=True, timeout=15)
|
|
48
|
+
response.raise_for_status()
|
|
49
|
+
for chunk in response.iter_content(chunk_size=8192):
|
|
50
|
+
if chunk: # 过滤掉保持连接的新块
|
|
51
|
+
file.write(chunk)
|
|
52
|
+
file.flush()
|
|
53
|
+
|
|
54
|
+
except BaseException as e:
|
|
55
|
+
if os.path.exists(file_path):
|
|
56
|
+
os.remove(file_path)
|
|
57
|
+
raise e
|
|
@@ -19,6 +19,10 @@ class Rectangle:
|
|
|
19
19
|
yield self.rb
|
|
20
20
|
yield self.rt
|
|
21
21
|
|
|
22
|
+
@property
|
|
23
|
+
def is_valid(self) -> bool:
|
|
24
|
+
return Polygon(self).is_valid
|
|
25
|
+
|
|
22
26
|
@property
|
|
23
27
|
def segments(self) -> Generator[tuple[Point, Point], None, None]:
|
|
24
28
|
yield (self.lt, self.lb)
|
|
@@ -60,6 +64,8 @@ class Rectangle:
|
|
|
60
64
|
def intersection_area(rect1: Rectangle, rect2: Rectangle) -> float:
|
|
61
65
|
poly1 = Polygon(rect1)
|
|
62
66
|
poly2 = Polygon(rect2)
|
|
67
|
+
if not poly1.is_valid or not poly2.is_valid:
|
|
68
|
+
return 0.0
|
|
63
69
|
intersection = poly1.intersection(poly2)
|
|
64
70
|
if intersection.is_empty:
|
|
65
71
|
return 0.0
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
from .pix2s import Pix2Struct, Pix2StructTensorRT
|
|
2
|
+
from .internvl import InternVL, InternVL_LMDeploy
|
|
3
|
+
|
|
4
|
+
from transformers import AutoConfig
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
__ALL_MODELS__ = {
|
|
8
|
+
'Pix2Struct': Pix2Struct,
|
|
9
|
+
'Pix2StructTensorRT': Pix2StructTensorRT,
|
|
10
|
+
'InternVL': InternVL,
|
|
11
|
+
'InternVL_LMDeploy': InternVL_LMDeploy,
|
|
12
|
+
}
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def get_model_name(model_path):
|
|
16
|
+
model_config = AutoConfig.from_pretrained(
|
|
17
|
+
model_path,
|
|
18
|
+
trust_remote_code=True,
|
|
19
|
+
)
|
|
20
|
+
|
|
21
|
+
if 'Pix2Struct' in model_config.architectures[0]:
|
|
22
|
+
model_name = 'Pix2Struct'
|
|
23
|
+
elif 'InternVL' in model_config.architectures[0]:
|
|
24
|
+
model_name = 'InternVL'
|
|
25
|
+
else:
|
|
26
|
+
raise ValueError(f"Unsupported model type: {model_config.architectures[0]}")
|
|
27
|
+
|
|
28
|
+
return model_name
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def build_model(
|
|
32
|
+
model_ckpt='U4R/StructTable-InternVL2-1B',
|
|
33
|
+
cache_dir=None,
|
|
34
|
+
local_files_only=None,
|
|
35
|
+
**kwargs,
|
|
36
|
+
):
|
|
37
|
+
model_name = get_model_name(model_ckpt)
|
|
38
|
+
if model_name == 'InternVL' and kwargs.get('lmdeploy', False):
|
|
39
|
+
model_name = 'InternVL_LMDeploy'
|
|
40
|
+
elif model_name == 'Pix2Struct' and kwargs.get('tensorrt_path', None):
|
|
41
|
+
model_name = 'Pix2StructTensorRT'
|
|
42
|
+
|
|
43
|
+
model = __ALL_MODELS__[model_name](
|
|
44
|
+
model_ckpt,
|
|
45
|
+
cache_dir=cache_dir,
|
|
46
|
+
local_files_only=local_files_only,
|
|
47
|
+
**kwargs
|
|
48
|
+
)
|
|
49
|
+
return model
|