doc-page-extractor 0.1.1__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. doc_page_extractor/__init__.py +5 -14
  2. doc_page_extractor/check_env.py +40 -0
  3. doc_page_extractor/extractor.py +87 -212
  4. doc_page_extractor/model.py +97 -0
  5. doc_page_extractor/parser.py +51 -0
  6. doc_page_extractor/plot.py +52 -79
  7. doc_page_extractor/redacter.py +111 -0
  8. doc_page_extractor-1.0.2.dist-info/METADATA +120 -0
  9. doc_page_extractor-1.0.2.dist-info/RECORD +11 -0
  10. {doc_page_extractor-0.1.1.dist-info → doc_page_extractor-1.0.2.dist-info}/WHEEL +1 -2
  11. doc_page_extractor-1.0.2.dist-info/licenses/LICENSE +21 -0
  12. doc_page_extractor/clipper.py +0 -119
  13. doc_page_extractor/downloader.py +0 -16
  14. doc_page_extractor/latex.py +0 -57
  15. doc_page_extractor/layout_order.py +0 -240
  16. doc_page_extractor/layoutreader.py +0 -126
  17. doc_page_extractor/ocr.py +0 -175
  18. doc_page_extractor/ocr_corrector.py +0 -126
  19. doc_page_extractor/onnxocr/__init__.py +0 -1
  20. doc_page_extractor/onnxocr/cls_postprocess.py +0 -26
  21. doc_page_extractor/onnxocr/db_postprocess.py +0 -246
  22. doc_page_extractor/onnxocr/imaug.py +0 -32
  23. doc_page_extractor/onnxocr/operators.py +0 -187
  24. doc_page_extractor/onnxocr/predict_base.py +0 -52
  25. doc_page_extractor/onnxocr/predict_cls.py +0 -89
  26. doc_page_extractor/onnxocr/predict_det.py +0 -120
  27. doc_page_extractor/onnxocr/predict_rec.py +0 -321
  28. doc_page_extractor/onnxocr/predict_system.py +0 -97
  29. doc_page_extractor/onnxocr/rec_postprocess.py +0 -896
  30. doc_page_extractor/onnxocr/utils.py +0 -71
  31. doc_page_extractor/overlap.py +0 -167
  32. doc_page_extractor/raw_optimizer.py +0 -104
  33. doc_page_extractor/rectangle.py +0 -72
  34. doc_page_extractor/rotation.py +0 -158
  35. doc_page_extractor/struct_eqtable/__init__.py +0 -49
  36. doc_page_extractor/struct_eqtable/internvl/__init__.py +0 -2
  37. doc_page_extractor/struct_eqtable/internvl/conversation.py +0 -394
  38. doc_page_extractor/struct_eqtable/internvl/internvl.py +0 -198
  39. doc_page_extractor/struct_eqtable/internvl/internvl_lmdeploy.py +0 -81
  40. doc_page_extractor/struct_eqtable/pix2s/__init__.py +0 -3
  41. doc_page_extractor/struct_eqtable/pix2s/pix2s.py +0 -76
  42. doc_page_extractor/struct_eqtable/pix2s/pix2s_trt.py +0 -1047
  43. doc_page_extractor/table.py +0 -71
  44. doc_page_extractor/types.py +0 -67
  45. doc_page_extractor/utils.py +0 -32
  46. doc_page_extractor-0.1.1.dist-info/METADATA +0 -84
  47. doc_page_extractor-0.1.1.dist-info/RECORD +0 -44
  48. doc_page_extractor-0.1.1.dist-info/licenses/LICENSE +0 -661
  49. doc_page_extractor-0.1.1.dist-info/top_level.txt +0 -2
  50. tests/__init__.py +0 -0
  51. tests/test_history_bus.py +0 -55
@@ -1,71 +0,0 @@
1
- import numpy as np
2
- import cv2
3
-
4
def get_rotate_crop_image(img, points):
    """Warp the quadrilateral described by ``points`` into an upright crop.

    The crop width is the longer of the edges 0-1 and 2-3, and the crop
    height is the longer of the edges 0-3 and 1-2. If the warped result is
    at least 1.5 times taller than it is wide, it is rotated with
    ``np.rot90`` before being returned.

    Args:
        img: Source image passed to ``cv2.warpPerspective``.
        points: Exactly four 2-D points. NOTE(review): the destination
            corners are (0, 0), (w, 0), (w, h), (0, h), so the points are
            presumably ordered to match; confirm against callers.

    Returns:
        The perspective-corrected (and possibly rotated) crop.
    """
    assert len(points) == 4, "shape of points must be 4*2"

    edge_01 = np.linalg.norm(points[0] - points[1])
    edge_23 = np.linalg.norm(points[2] - points[3])
    edge_03 = np.linalg.norm(points[0] - points[3])
    edge_12 = np.linalg.norm(points[1] - points[2])
    crop_w = int(max(edge_01, edge_23))
    crop_h = int(max(edge_03, edge_12))

    target_corners = np.float32(
        [
            [0, 0],
            [crop_w, 0],
            [crop_w, crop_h],
            [0, crop_h],
        ]
    )
    transform = cv2.getPerspectiveTransform(points, target_corners)
    warped = cv2.warpPerspective(
        img,
        transform,
        (crop_w, crop_h),
        borderMode=cv2.BORDER_REPLICATE,
        flags=cv2.INTER_CUBIC,
    )

    out_h, out_w = warped.shape[0:2]
    if out_h * 1.0 / out_w >= 1.5:
        return np.rot90(warped)
    return warped
46
-
47
-
48
def get_minarea_rect_crop(img, points):
    """Crop ``img`` to the minimum-area rectangle enclosing ``points``.

    The rectangle corners from ``cv2.boxPoints`` are sorted by x. Within the
    two lowest-x corners and within the two highest-x corners, the one with
    the smaller y comes first. The resulting order (low-x/small-y,
    high-x/small-y, high-x/large-y, low-x/large-y) is handed to
    ``get_rotate_crop_image``.

    Args:
        img: Source image.
        points: Points converted to an int32 array for ``cv2.minAreaRect``.

    Returns:
        Whatever ``get_rotate_crop_image`` returns for the ordered corners.
    """
    min_rect = cv2.minAreaRect(np.array(points).astype(np.int32))
    corners = sorted(cv2.boxPoints(min_rect), key=lambda corner: corner[0])
    low_x_pair = corners[:2]
    high_x_pair = corners[2:]

    if low_x_pair[1][1] > low_x_pair[0][1]:
        first, last = low_x_pair[0], low_x_pair[1]
    else:
        first, last = low_x_pair[1], low_x_pair[0]

    if high_x_pair[1][1] > high_x_pair[0][1]:
        second, third = high_x_pair[0], high_x_pair[1]
    else:
        second, third = high_x_pair[1], high_x_pair[0]

    return get_rotate_crop_image(img, np.array([first, second, third, last]))
69
-
70
def str2bool(v):
    """Return True when ``v`` is "true", "t" or "1", ignoring case."""
    lowered = v.lower()
    return lowered == "true" or lowered == "t" or lowered == "1"
@@ -1,167 +0,0 @@
1
- from typing import Generator
2
- from shapely.geometry import Polygon
3
- from .types import Layout, OCRFragment
4
- from .rectangle import Rectangle
5
-
6
-
7
- _INCLUDES_MIN_RATE = 0.99
8
-
9
def remove_overlap_layouts(layouts: list[Layout]) -> list[Layout]:
    """Drop layouts that are (almost) fully nested with other layouts.

    A layout is only processed when every non-zero overlap rate it has with
    the remaining layouts is at least ``_INCLUDES_MIN_RATE`` and at least one
    rate is non-zero. Such a layout is removed if it has no fragments;
    otherwise the layouts it includes are removed and their fragments are
    appended to it.

    Args:
        layouts: Layouts to filter. ``fragments`` lists of kept layouts may
            be extended in place.

    Returns:
        The layouts that were not removed, in their original order.
    """
    ctx = _OverlapMatrixContext(layouts)
    removed = ctx.removed_indexes

    # the reason for repeating this multiple times is that deleting a layout
    # may cause its parent layout to change from an originally non-deletable
    # state to a deletable state.
    changed = True
    while changed:
        count_before = len(removed)
        for index, layout in enumerate(layouts):
            if index in removed:
                continue
            rates = list(ctx.rates_with_other(index))
            if any(0.0 < rate < _INCLUDES_MIN_RATE for rate in rates):
                continue
            if all(rate == 0.0 for rate in rates):
                continue

            if not layout.fragments:
                removed.add(index)
                continue
            for included in ctx.search_includes_indexes(index):
                removed.add(included)
                layout.fragments.extend(layouts[included].fragments)
        changed = len(removed) != count_before

    return [
        layout
        for index, layout in enumerate(layouts)
        if index not in removed
    ]
36
-
37
- class _OverlapMatrixContext:
38
- def __init__(self, layouts: list[Layout]):
39
- length: int = len(layouts)
40
- polygons: list[Polygon] = [Polygon(layout.rect) for layout in layouts]
41
- self.rate_matrix: list[list[float]] = [[1.0 for _ in range(length)] for _ in range(length)]
42
- self.removed_indexes: set[int] = set()
43
- for i in range(length):
44
- polygon1 = polygons[i]
45
- rates = self.rate_matrix[i]
46
- for j in range(length):
47
- if i != j:
48
- polygon2 = polygons[j]
49
- rates[j] = overlap_rate(polygon1, polygon2)
50
-
51
- def rates_with_other(self, index: int):
52
- for i, rate in enumerate(self.rate_matrix[index]):
53
- if i != index and i not in self.removed_indexes:
54
- yield rate
55
-
56
- def search_includes_indexes(self, index: int):
57
- for i, rate in enumerate(self.rate_matrix[index]):
58
- if i != index and \
59
- i not in self.removed_indexes and \
60
- rate >= _INCLUDES_MIN_RATE:
61
- yield i
62
-
63
def merge_fragments_as_line(origin_fragments: list[OCRFragment]) -> list[OCRFragment]:
    """Merge the fragments of each group into a single fragment.

    Groups come from ``_split_fragments_into_groups``. A group of one is kept
    as is. A larger group is sorted by ``rect.lt[0] + rect.lb[0]`` and becomes
    one fragment: texts joined with a space, the smallest ``order``, a rank
    weighted by text length, and the axis-aligned rectangle covering every
    corner. Groups whose texts are all empty are dropped.

    Args:
        origin_fragments: Fragments to merge.

    Returns:
        The merged fragments, one per surviving group.
    """
    merged: list[OCRFragment] = []
    for line in _split_fragments_into_groups(origin_fragments):
        if len(line) == 1:
            merged.append(line[0])
            continue

        ordered = sorted(line, key=lambda item: item.rect.lt[0] + item.rect.lb[0])
        texts: list[str] = [item.text for item in ordered]
        total_chars: int = sum(len(text) for text in texts)
        if total_chars == 0:
            continue

        # Accumulate with a plain loop so float rounding follows a simple
        # left-to-right addition.
        weighted_rank: float = 0.0
        for item in ordered:
            weighted_rank += item.rank * len(item.text)

        corners = [corner for item in ordered for corner in item.rect]
        xs = [x for x, _ in corners]
        ys = [y for _, y in corners]
        left, right = min(xs), max(xs)
        top, bottom = min(ys), max(ys)

        merged.append(OCRFragment(
            order=min(item.order for item in ordered),
            text=" ".join(texts),
            rank=weighted_rank / total_chars,
            rect=Rectangle(
                lt=(left, top),
                rt=(right, top),
                lb=(left, bottom),
                rb=(right, bottom),
            ),
        ))
    return merged
106
-
107
- def _split_fragments_into_groups(fragments: list[OCRFragment]) -> Generator[list[OCRFragment], None, None]:
108
- group: list[OCRFragment] = []
109
- sum_height: float = 0.0
110
- sum_median: float = 0.0
111
- max_deviation_rate = 0.35
112
-
113
- for fragment in sorted(fragments, key=lambda x: x.rect.lt[1] + x.rect.rt[1]):
114
- _, y1, _, y2 = fragment.rect.wrapper
115
- height = y2 - y1
116
- median = (y1 + y2) / 2.0
117
-
118
- if height == 0:
119
- continue
120
-
121
- if len(group) > 0:
122
- next_mean_median = (sum_median + median) / (len(group) + 1)
123
- next_mean_height = (sum_height + height) / (len(group) + 1)
124
-
125
- deviation_rate = abs(median - next_mean_median) / next_mean_height
126
- if deviation_rate > max_deviation_rate:
127
- yield group
128
- group = []
129
- sum_height = 0.0
130
- sum_median = 0.0
131
-
132
- group.append(fragment)
133
- sum_height += height
134
- sum_median += median
135
-
136
- if len(group) > 0:
137
- yield group
138
-
139
# Overlap ratio is deliberately not based on area: most measured shapes are
# thin, long rectangles representing text lines, so an area ratio would be far
# more sensitive to height changes than to length changes. Averaging the width
# ratio and the height ratio keeps both dimensions equally weighted.
def overlap_rate(polygon1: Polygon, polygon2: Polygon) -> float:
    """Return how much of ``polygon2``'s extent is covered by the intersection.

    The result is the mean of two ratios: the bounding-box width of the
    intersection over that of ``polygon2``, and the same for the heights.

    Sizes come from the geometries' ``bounds``, not their exterior rings. The
    intersection of two polygons that merely touch is a line or a point, and
    such geometries have no ``exterior``; ``bounds`` exists on every geometry
    type, so those cases yield a rate rather than an AttributeError.

    Args:
        polygon1: First geometry.
        polygon2: Geometry whose bounding-box size is the denominator.

    Returns:
        ``0.0`` when the geometries do not intersect or ``polygon2`` has zero
        width or height; otherwise the averaged ratio.
    """
    intersection = polygon1.intersection(polygon2)
    if intersection.is_empty:
        return 0.0

    min_x, min_y, max_x, max_y = intersection.bounds
    ref_min_x, ref_min_y, ref_max_x, ref_max_y = polygon2.bounds
    polygon2_width = ref_max_x - ref_min_x
    polygon2_height = ref_max_y - ref_min_y
    if polygon2_width == 0.0 or polygon2_height == 0.0:
        return 0.0

    return (
        (max_x - min_x) / polygon2_width +
        (max_y - min_y) / polygon2_height
    ) / 2.0
156
-
157
- def _polygon_size(polygon: Polygon) -> tuple[float, float]:
158
- x1: float = float("inf")
159
- y1: float = float("inf")
160
- x2: float = float("-inf")
161
- y2: float = float("-inf")
162
- for x, y in polygon.exterior.coords:
163
- x1 = min(x1, x)
164
- y1 = min(y1, y)
165
- x2 = max(x2, x)
166
- y2 = max(y2, y)
167
- return x2 - x1, y2 - y1
@@ -1,104 +0,0 @@
1
- import numpy as np
2
-
3
- from dataclasses import dataclass
4
- from PIL.Image import Image
5
- from math import pi
6
- from .types import OCRFragment, Layout
7
- from .rotation import calculate_rotation, RotationAdjuster
8
- from .rectangle import Rectangle
9
-
10
-
11
- _TINY_ROTATION = 0.005 # below this angle, we consider the text is horizontal
12
-
13
-
14
@dataclass
class _RotationContext:
    """State recorded by ``RawOptimizer`` once the page image has been rotated."""

    # Maps points from the rotated image back to the raw image.
    to_origin: RotationAdjuster
    # Maps points from the raw image into the rotated image.
    to_new: RotationAdjuster
    # Fragment rectangles as they were before being mapped to the rotated image.
    fragment_origin_rectangles: list[Rectangle]
19
-
20
class RawOptimizer:
    """Rotates a page image by the angle derived from its OCR fragments.

    ``receive_raw_fragments`` computes the rotation; when it is not tiny, it
    rotates the image and moves fragment rectangles into the rotated frame.
    ``receive_raw_layouts`` later restores everything to the raw frame unless
    ``adjust_points`` was requested.
    """

    def __init__(
        self,
        raw: Image,
        adjust_points: bool,
    ):
        self._raw: Image = raw
        # The working image is the raw image until a rotation is applied.
        self._image: Image = raw
        self._adjust_points: bool = adjust_points
        self._fragments: list[OCRFragment]
        self._rotation: float = 0.0
        self._rotation_context: _RotationContext | None = None

    @property
    def image(self) -> Image:
        """The working image (rotated if a rotation was applied)."""
        return self._image

    @property
    def adjusted_image(self) -> Image | None:
        """The working image when points are adjusted and it differs from the raw image, else None."""
        if not self._adjust_points:
            return None
        if self._image != self._raw:
            return self._image
        return None

    @property
    def rotation(self) -> float:
        """The rotation computed by the last ``receive_raw_fragments`` call."""
        return self._rotation

    @property
    def image_np(self) -> np.ndarray:
        """The raw (unrotated) image as a numpy array."""
        return np.array(self._raw)

    @staticmethod
    def _adjust_rectangle(adjuster: RotationAdjuster, rect: Rectangle) -> Rectangle:
        """Return ``rect`` with every corner mapped through ``adjuster``."""
        return Rectangle(
            lt=adjuster.adjust(rect.lt),
            rt=adjuster.adjust(rect.rt),
            lb=adjuster.adjust(rect.lb),
            rb=adjuster.adjust(rect.rb),
        )

    def receive_raw_fragments(self, fragments: list[OCRFragment]):
        """Store ``fragments``, compute the rotation and, if significant, rotate.

        When the absolute rotation is below ``_TINY_ROTATION`` nothing else
        happens. Otherwise the raw image is rotated on a white background
        with ``expand=True`` and each fragment's ``rect`` is replaced by its
        position in the rotated image.
        """
        self._fragments = fragments
        self._rotation = calculate_rotation(fragments)
        if abs(self._rotation) < _TINY_ROTATION:
            return

        origin_size = self._raw.size
        self._image = self._raw.rotate(
            angle=self._rotation * 180 / pi,
            fillcolor=(255, 255, 255),
            expand=True,
        )

        origin_rectangles = [fragment.rect for fragment in fragments]
        shared_args = {
            "origin_size": origin_size,
            "new_size": self._image.size,
            "rotation": self._rotation,
        }
        context = _RotationContext(
            fragment_origin_rectangles=origin_rectangles,
            to_origin=RotationAdjuster(to_origin_coordinate=True, **shared_args),
            to_new=RotationAdjuster(to_origin_coordinate=False, **shared_args),
        )
        self._rotation_context = context

        for fragment in fragments:
            fragment.rect = self._adjust_rectangle(context.to_new, fragment.rect)

    def receive_raw_layouts(self, layouts: list[Layout]):
        """Map fragments and layouts back to raw-image coordinates.

        Does nothing when points are meant to stay adjusted or when no
        rotation was applied. Otherwise fragments get their saved original
        rectangles back and each layout rectangle is mapped to the raw frame.
        """
        context = self._rotation_context
        if self._adjust_points or context is None:
            return

        for fragment, origin_rect in zip(self._fragments, context.fragment_origin_rectangles):
            fragment.rect = origin_rect

        for layout in layouts:
            layout.rect = self._adjust_rectangle(context.to_origin, layout.rect)
@@ -1,72 +0,0 @@
1
- from typing import Generator
2
- from dataclasses import dataclass
3
- from math import sqrt
4
- from shapely.geometry import Polygon
5
-
6
-
7
Point = tuple[float, float]


@dataclass
class Rectangle:
    """A quadrilateral given by its four named corners.

    Iteration yields the corners in the order lt, lb, rb, rt.
    """

    lt: Point
    rt: Point
    lb: Point
    rb: Point

    def __iter__(self) -> Generator[Point, None, None]:
        for corner_name in ("lt", "lb", "rb", "rt"):
            yield getattr(self, corner_name)

    @property
    def is_valid(self) -> bool:
        """Whether the polygon built from the corners is valid."""
        return Polygon(self).is_valid

    @property
    def segments(self) -> Generator[tuple[Point, Point], None, None]:
        """The four edges as point pairs: lt-lb, lb-rb, rb-rt, rt-lt."""
        corner_names = ("lt", "lb", "rb", "rt")
        for start_name, end_name in zip(corner_names, corner_names[1:] + corner_names[:1]):
            yield (getattr(self, start_name), getattr(self, end_name))

    @property
    def area(self) -> float:
        """Area of the polygon built from the corners."""
        return Polygon(self).area

    @property
    def size(self) -> tuple[float, float]:
        """(width, height), each the mean length of a pair of opposite edges.

        Height averages the lt-lb and rb-rt edges; width averages the
        lb-rb and rt-lt edges.
        """
        edge_lengths: list[float] = []
        for start, end in self.segments:
            dx = start[0] - end[0]
            dy = start[1] - end[1]
            edge_lengths.append(sqrt(dx * dx + dy * dy))
        height = edge_lengths[0] + edge_lengths[2]
        width = edge_lengths[1] + edge_lengths[3]
        return width / 2, height / 2

    @property
    def wrapper(self) -> tuple[float, float, float, float]:
        """Axis-aligned bounds of the corners as (x1, y1, x2, y2)."""
        corners = list(self)
        xs = [x for x, _ in corners]
        ys = [y for _, y in corners]
        return min(xs), min(ys), max(xs), max(ys)
63
-
64
def intersection_area(rect1: Rectangle, rect2: Rectangle) -> float:
    """Return the area shared by two rectangles.

    Returns:
        ``0.0`` when either rectangle forms an invalid polygon or when they
        do not intersect; otherwise the area of the intersection.
    """
    first = Polygon(rect1)
    second = Polygon(rect2)
    if not (first.is_valid and second.is_valid):
        return 0.0
    shared = first.intersection(second)
    return 0.0 if shared.is_empty else shared.area
@@ -1,158 +0,0 @@
1
- from math import pi, atan2, sqrt, sin, cos
2
- from .types import OCRFragment
3
- from .rectangle import Point, Rectangle
4
-
5
-
6
class RotationAdjuster:
    """Maps points between an image and its rotated copy.

    A point is shifted so the source image's center is the origin, rotated
    about that origin, then shifted so the origin lands on the target image's
    center.
    """

    def __init__(
        self,
        origin_size: tuple[int, int],
        new_size: tuple[int, int],
        rotation: float,
        to_origin_coordinate: bool,
    ):
        """Set up the mapping.

        Args:
            origin_size: Size of the original image.
            new_size: Size of the rotated image.
            rotation: Rotation angle in radians (it is fed to sin/cos).
            to_origin_coordinate: True to map rotated-image points to the
                original image using ``rotation`` as given; False to map the
                other way using the negated angle.
        """
        if to_origin_coordinate:
            source_size, target_size = new_size, origin_size
        else:
            source_size, target_size = origin_size, new_size
            rotation = -rotation

        self._rotation: float = rotation
        self._center_offset: tuple[float, float] = (
            -source_size[0] / 2.0,
            -source_size[1] / 2.0,
        )
        self._new_offset: tuple[float, float] = (
            target_size[0] / 2.0,
            target_size[1] / 2.0,
        )

    def adjust(self, point: Point) -> Point:
        """Return ``point`` mapped into the target image's coordinates."""
        x, y = point
        dx = x + self._center_offset[0]
        dy = y + self._center_offset[1]

        if dx == 0.0 and dy == 0.0:
            # The center itself does not move under rotation.
            rotated_x, rotated_y = dx, dy
        else:
            radius = sqrt(dx * dx + dy * dy)
            angle = atan2(dy, dx) + self._rotation
            rotated_x = radius * cos(angle)
            rotated_y = radius * sin(angle)

        return rotated_x + self._new_offset[0], rotated_y + self._new_offset[1]
49
-
50
# Normalize an angle into [0, pi).
def normal_vertical_rotation(rotation: float) -> float:
    """Return ``rotation`` reduced modulo pi into the half-open range [0, pi).

    Subtracting whole turns and then adding pi once to negative values only
    lands in range for inputs already within [-pi, pi); for example 1.5*pi
    would stay 1.5*pi. Reducing modulo pi gives the same result for inputs in
    [-pi, pi) and a correct result for every other finite input.

    Args:
        rotation: Angle in radians.

    Returns:
        The equivalent angle in [0, pi).
    """
    normalized = rotation % pi
    # For tiny negative inputs the modulo can round up to exactly pi; fold
    # that back so the upper bound stays exclusive.
    if normalized >= pi:
        return 0.0
    return normalized
59
-
60
def calculate_rotation(fragments: list[OCRFragment]):
    """Estimate a single rotation angle from the fragments' rectangles.

    Horizontal and vertical edge angles are collected from every fragment
    whose rectangle yields a result. The return value averages the median
    horizontal angle with the median vertical angle shifted by -pi/2.

    Args:
        fragments: Fragments whose ``rect`` edges are measured.

    Returns:
        The estimated rotation, or ``0.0`` when no horizontal or no vertical
        angles were collected.
    """
    horizontal: list[float] = []
    vertical: list[float] = []

    for fragment in fragments:
        angles = _rotation_with(fragment.rect)
        if angles is None:
            continue
        horizontal += angles[0]
        vertical += angles[1]

    if not horizontal or not vertical:
        return 0.0

    horizontal = _normal_horizontal_rotations(horizontal)
    horizontal_median = _find_median(horizontal)
    vertical_median = _find_median(vertical)

    return (vertical_median - 0.5 * pi + horizontal_median) / 2.0
78
-
79
# Returns (horizontal, vertical) with horizontal meant to lie in
# [-pi/2, pi/2) and vertical in [0, pi).
def calculate_rotation_with_rect(rect: Rectangle) -> tuple[float, float]:
    """Return the mean horizontal and mean vertical edge angle of ``rect``.

    Returns:
        ``(0.0, 0.5 * pi)`` when the rectangle yields no angles (it has a
        zero-length edge); otherwise the two mean angles.
    """
    angles = _rotation_with(rect)
    if angles is None:
        return 0.0, 0.5 * pi

    horizontal, vertical = angles
    horizontal = _normal_horizontal_rotations(horizontal)
    mean_horizontal = _find_mean(horizontal)
    mean_vertical = _find_mean(vertical)
    return mean_horizontal, mean_vertical
91
-
92
- def _rotation_with(rect: Rectangle):
93
- rotations0: list[float] = []
94
- rotations1: list[float] = []
95
-
96
- for i, (p1, p2) in enumerate(rect.segments):
97
- dx = p2[0] - p1[0]
98
- dy = p2[1] - p1[1]
99
- if dx == 0.0 and dy == 0.0:
100
- return None
101
- rotation: float = atan2(dy, dx)
102
- if rotation < 0.0:
103
- rotation += pi
104
- if i % 2 == 0:
105
- rotations0.append(rotation)
106
- else:
107
- rotations1.append(rotation)
108
-
109
- if _is_vertical(rotations0):
110
- return rotations1, rotations0
111
- else:
112
- return rotations0, rotations1
113
-
114
- # [0, pi) --> [-pi/2, pi/2)
115
- def _normal_horizontal_rotations(rotations: list[float]) -> list[float]:
116
- for i, rotation in enumerate(rotations):
117
- if rotation > 0.5 * pi:
118
- rotations[i] = rotation - pi
119
- return rotations
120
-
121
- def _find_median(rotations: list[float]):
122
- rotations.sort()
123
- n = len(rotations)
124
-
125
- if n % 2 == 1:
126
- return rotations[n // 2]
127
- else:
128
- mid1 = rotations[n // 2 - 1]
129
- mid2 = rotations[n // 2]
130
- return (mid1 + mid2) / 2
131
-
132
- def _find_mean(rotations: list[float]) -> float:
133
- if len(rotations) == 0:
134
- return 0.0
135
- return sum(rotations) / len(rotations)
136
-
137
- # rotation is in [0, pi)
138
- def _is_vertical(rotations: list[float]):
139
- horizontal_count: int = 0
140
- vertical_count: int = 0
141
- horizontal_delta: float = 0.0
142
- vertical_delta: float = 0.0
143
-
144
- for rotation in rotations:
145
- if rotation < 0.25 * pi: # [0, pi/4)
146
- horizontal_count += 1
147
- horizontal_delta += rotation
148
- elif rotation < 0.75 * pi: # [pi/4, 3pi/4)
149
- vertical_count += 1
150
- vertical_delta += abs(rotation - 0.5 * pi)
151
- else: # [3pi/4, pi)
152
- horizontal_count += 1
153
- horizontal_delta += pi - rotation
154
-
155
- if vertical_count == horizontal_delta:
156
- return vertical_delta < horizontal_delta
157
- else:
158
- return vertical_count > horizontal_count
@@ -1,49 +0,0 @@
1
- from .pix2s import Pix2Struct, Pix2StructTensorRT
2
- from .internvl import InternVL, InternVL_LMDeploy
3
-
4
- from transformers import AutoConfig
5
-
6
-
7
- __ALL_MODELS__ = {
8
- 'Pix2Struct': Pix2Struct,
9
- 'Pix2StructTensorRT': Pix2StructTensorRT,
10
- 'InternVL': InternVL,
11
- 'InternVL_LMDeploy': InternVL_LMDeploy,
12
- }
13
-
14
-
15
def get_model_name(model_path):
    """Return the model family named by a checkpoint's configuration.

    The configuration is loaded with ``AutoConfig.from_pretrained`` and
    ``trust_remote_code=True``, and its first architecture entry is matched
    by substring.

    Args:
        model_path: Checkpoint identifier or path given to ``AutoConfig``.

    Returns:
        ``'Pix2Struct'`` or ``'InternVL'``.

    Raises:
        ValueError: If the first architecture matches neither family.
    """
    model_config = AutoConfig.from_pretrained(
        model_path,
        trust_remote_code=True,
    )
    architecture = model_config.architectures[0]

    # Order matters: 'Pix2Struct' is checked before 'InternVL'.
    for family in ('Pix2Struct', 'InternVL'):
        if family in architecture:
            return family

    raise ValueError(f"Unsupported model type: {architecture}")
29
-
30
-
31
def build_model(
    model_ckpt='U4R/StructTable-InternVL2-1B',
    cache_dir=None,
    local_files_only=None,
    **kwargs,
):
    """Instantiate the model class that matches a checkpoint.

    The family comes from ``get_model_name``. A truthy ``lmdeploy`` keyword
    selects the LMDeploy variant of InternVL, and a truthy ``tensorrt_path``
    keyword selects the TensorRT variant of Pix2Struct.

    Args:
        model_ckpt: Checkpoint identifier or path.
        cache_dir: Passed through to the model class.
        local_files_only: Passed through to the model class.
        **kwargs: Passed through to the model class unchanged, including the
            selector keywords above.

    Returns:
        The constructed model instance.
    """
    model_name = get_model_name(model_ckpt)
    if model_name == 'InternVL':
        if kwargs.get('lmdeploy', False):
            model_name = 'InternVL_LMDeploy'
    elif model_name == 'Pix2Struct':
        if kwargs.get('tensorrt_path', None):
            model_name = 'Pix2StructTensorRT'

    model_class = __ALL_MODELS__[model_name]
    return model_class(
        model_ckpt,
        cache_dir=cache_dir,
        local_files_only=local_files_only,
        **kwargs
    )
@@ -1,2 +0,0 @@
1
- from .internvl import InternVL
2
- from .internvl_lmdeploy import InternVL_LMDeploy