doc-page-extractor 0.2.0__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. doc_page_extractor/__init__.py +5 -15
  2. doc_page_extractor/check_env.py +40 -0
  3. doc_page_extractor/extractor.py +88 -215
  4. doc_page_extractor/model.py +97 -0
  5. doc_page_extractor/parser.py +51 -0
  6. doc_page_extractor/plot.py +52 -79
  7. doc_page_extractor/redacter.py +111 -0
  8. doc_page_extractor-1.0.2.dist-info/METADATA +120 -0
  9. doc_page_extractor-1.0.2.dist-info/RECORD +11 -0
  10. {doc_page_extractor-0.2.0.dist-info → doc_page_extractor-1.0.2.dist-info}/WHEEL +1 -2
  11. doc_page_extractor-1.0.2.dist-info/licenses/LICENSE +21 -0
  12. doc_page_extractor/clipper.py +0 -119
  13. doc_page_extractor/downloader.py +0 -16
  14. doc_page_extractor/latex.py +0 -31
  15. doc_page_extractor/layout_order.py +0 -237
  16. doc_page_extractor/layoutreader.py +0 -126
  17. doc_page_extractor/models.py +0 -92
  18. doc_page_extractor/ocr.py +0 -200
  19. doc_page_extractor/ocr_corrector.py +0 -126
  20. doc_page_extractor/onnxocr/__init__.py +0 -1
  21. doc_page_extractor/onnxocr/cls_postprocess.py +0 -26
  22. doc_page_extractor/onnxocr/db_postprocess.py +0 -246
  23. doc_page_extractor/onnxocr/imaug.py +0 -32
  24. doc_page_extractor/onnxocr/operators.py +0 -187
  25. doc_page_extractor/onnxocr/predict_base.py +0 -57
  26. doc_page_extractor/onnxocr/predict_cls.py +0 -109
  27. doc_page_extractor/onnxocr/predict_det.py +0 -139
  28. doc_page_extractor/onnxocr/predict_rec.py +0 -344
  29. doc_page_extractor/onnxocr/predict_system.py +0 -97
  30. doc_page_extractor/onnxocr/rec_postprocess.py +0 -896
  31. doc_page_extractor/onnxocr/utils.py +0 -71
  32. doc_page_extractor/overlap.py +0 -167
  33. doc_page_extractor/raw_optimizer.py +0 -104
  34. doc_page_extractor/rectangle.py +0 -72
  35. doc_page_extractor/rotation.py +0 -158
  36. doc_page_extractor/struct_eqtable/__init__.py +0 -49
  37. doc_page_extractor/struct_eqtable/internvl/__init__.py +0 -2
  38. doc_page_extractor/struct_eqtable/internvl/conversation.py +0 -394
  39. doc_page_extractor/struct_eqtable/internvl/internvl.py +0 -198
  40. doc_page_extractor/struct_eqtable/internvl/internvl_lmdeploy.py +0 -81
  41. doc_page_extractor/struct_eqtable/pix2s/__init__.py +0 -3
  42. doc_page_extractor/struct_eqtable/pix2s/pix2s.py +0 -76
  43. doc_page_extractor/struct_eqtable/pix2s/pix2s_trt.py +0 -1047
  44. doc_page_extractor/table.py +0 -70
  45. doc_page_extractor/types.py +0 -91
  46. doc_page_extractor/utils.py +0 -32
  47. doc_page_extractor-0.2.0.dist-info/METADATA +0 -85
  48. doc_page_extractor-0.2.0.dist-info/RECORD +0 -45
  49. doc_page_extractor-0.2.0.dist-info/licenses/LICENSE +0 -661
  50. doc_page_extractor-0.2.0.dist-info/top_level.txt +0 -2
  51. tests/__init__.py +0 -0
  52. tests/test_history_bus.py +0 -55
doc_page_extractor/layoutreader.py DELETED
@@ -1,126 +0,0 @@
- # Copy from https://github.com/ppaanngggg/layoutreader/blob/main/v3/helpers.py
- from collections import defaultdict
- from typing import List, Dict
-
- import torch
- from transformers import LayoutLMv3ForTokenClassification
-
- MAX_LEN = 510
- CLS_TOKEN_ID = 0
- UNK_TOKEN_ID = 3
- EOS_TOKEN_ID = 2
-
-
- class DataCollator:
-     def __call__(self, features: List[dict]) -> Dict[str, torch.Tensor]:
-         bbox = []
-         labels = []
-         input_ids = []
-         attention_mask = []
-
-         # clip bbox and labels to max length, build input_ids and attention_mask
-         for feature in features:
-             _bbox = feature["source_boxes"]
-             if len(_bbox) > MAX_LEN:
-                 _bbox = _bbox[:MAX_LEN]
-             _labels = feature["target_index"]
-             if len(_labels) > MAX_LEN:
-                 _labels = _labels[:MAX_LEN]
-             _input_ids = [UNK_TOKEN_ID] * len(_bbox)
-             _attention_mask = [1] * len(_bbox)
-             assert len(_bbox) == len(_labels) == len(_input_ids) == len(_attention_mask)
-             bbox.append(_bbox)
-             labels.append(_labels)
-             input_ids.append(_input_ids)
-             attention_mask.append(_attention_mask)
-
-         # add CLS and EOS tokens
-         for i in range(len(bbox)):
-             bbox[i] = [[0, 0, 0, 0]] + bbox[i] + [[0, 0, 0, 0]]
-             labels[i] = [-100] + labels[i] + [-100]
-             input_ids[i] = [CLS_TOKEN_ID] + input_ids[i] + [EOS_TOKEN_ID]
-             attention_mask[i] = [1] + attention_mask[i] + [1]
-
-         # padding to max length
-         max_len = max(len(x) for x in bbox)
-         for i in range(len(bbox)):
-             bbox[i] = bbox[i] + [[0, 0, 0, 0]] * (max_len - len(bbox[i]))
-             labels[i] = labels[i] + [-100] * (max_len - len(labels[i]))
-             input_ids[i] = input_ids[i] + [EOS_TOKEN_ID] * (max_len - len(input_ids[i]))
-             attention_mask[i] = attention_mask[i] + [0] * (
-                 max_len - len(attention_mask[i])
-             )
-
-         ret = {
-             "bbox": torch.tensor(bbox),
-             "attention_mask": torch.tensor(attention_mask),
-             "labels": torch.tensor(labels),
-             "input_ids": torch.tensor(input_ids),
-         }
-         # set label > MAX_LEN to -100, because original labels may be > MAX_LEN
-         ret["labels"][ret["labels"] > MAX_LEN] = -100
-         # set label > 0 to label-1, because original labels are 1-indexed
-         ret["labels"][ret["labels"] > 0] -= 1
-         return ret
-
-
- def boxes2inputs(boxes: List[List[int]]) -> Dict[str, torch.Tensor]:
-     bbox = [[0, 0, 0, 0]] + boxes + [[0, 0, 0, 0]]
-     input_ids = [CLS_TOKEN_ID] + [UNK_TOKEN_ID] * len(boxes) + [EOS_TOKEN_ID]
-     attention_mask = [1] + [1] * len(boxes) + [1]
-     return {
-         "bbox": torch.tensor([bbox]),
-         "attention_mask": torch.tensor([attention_mask]),
-         "input_ids": torch.tensor([input_ids]),
-     }
-
-
- def prepare_inputs(
-     inputs: Dict[str, torch.Tensor], model: LayoutLMv3ForTokenClassification
- ) -> Dict[str, torch.Tensor]:
-     ret = {}
-     for k, v in inputs.items():
-         v = v.to(model.device)
-         if torch.is_floating_point(v):
-             v = v.to(model.dtype)
-         ret[k] = v
-     return ret
-
-
- def parse_logits(logits: torch.Tensor, length: int) -> List[int]:
-     """
-     parse logits to orders
-
-     :param logits: logits from model
-     :param length: input length
-     :return: orders
-     """
-     logits = logits[1 : length + 1, :length]
-     orders = logits.argsort(descending=False).tolist()
-     ret = [o.pop() for o in orders]
-     while True:
-         order_to_idxes = defaultdict(list)
-         for idx, order in enumerate(ret):
-             order_to_idxes[order].append(idx)
-         # filter idxes len > 1
-         order_to_idxes = {k: v for k, v in order_to_idxes.items() if len(v) > 1}
-         if not order_to_idxes:
-             break
-         # filter
-         for order, idxes in order_to_idxes.items():
-             # find original logits of idxes
-             idxes_to_logit = {}
-             for idx in idxes:
-                 idxes_to_logit[idx] = logits[idx, order]
-             idxes_to_logit = sorted(
-                 idxes_to_logit.items(), key=lambda x: x[1], reverse=True
-             )
-             # keep the highest logit as order, set others to next candidate
-             for idx, _ in idxes_to_logit[1:]:
-                 ret[idx] = orders[idx].pop()
-
-     return ret
-
-
- def check_duplicate(a: List[int]) -> bool:
-     return len(a) != len(set(a))
doc_page_extractor/models.py DELETED
@@ -1,92 +0,0 @@
- import os
-
- from logging import Logger
- from huggingface_hub import hf_hub_download, snapshot_download, try_to_load_from_cache
- from .types import ModelsDownloader
-
- class HuggingfaceModelsDownloader(ModelsDownloader):
-   def __init__(
-     self,
-     logger: Logger,
-     model_dir_path: str | None
-   ):
-     self._logger = logger
-     self._model_dir_path: str | None = model_dir_path
-
-   def onnx_ocr(self) -> str:
-     repo_path = try_to_load_from_cache(
-       repo_id="moskize/OnnxOCR",
-       filename="README.md",
-       cache_dir=self._model_dir_path
-     )
-     if isinstance(repo_path, str):
-       return os.path.dirname(repo_path)
-     else:
-       self._logger.info("Downloading OCR model...")
-       return snapshot_download(
-         cache_dir=self._model_dir_path,
-         repo_id="moskize/OnnxOCR",
-       )
-
-   def yolo(self) -> str:
-     yolo_file_path = try_to_load_from_cache(
-       repo_id="opendatalab/PDF-Extract-Kit-1.0",
-       filename="models/Layout/YOLO/doclayout_yolo_ft.pt",
-       cache_dir=self._model_dir_path
-     )
-     if isinstance(yolo_file_path, str):
-       return yolo_file_path
-     else:
-       self._logger.info("Downloading YOLO model...")
-       return hf_hub_download(
-         cache_dir=self._model_dir_path,
-         repo_id="opendatalab/PDF-Extract-Kit-1.0",
-         filename="models/Layout/YOLO/doclayout_yolo_ft.pt",
-       )
-
-   def layoutreader(self) -> str:
-     repo_path = try_to_load_from_cache(
-       repo_id="hantian/layoutreader",
-       filename="model.safetensors",
-       cache_dir=self._model_dir_path
-     )
-     if isinstance(repo_path, str):
-       return os.path.dirname(repo_path)
-     else:
-       self._logger.info("Downloading LayoutReader model...")
-       return snapshot_download(
-         cache_dir=self._model_dir_path,
-         repo_id="hantian/layoutreader",
-       )
-
-   def struct_eqtable(self) -> str:
-     repo_path = try_to_load_from_cache(
-       repo_id="U4R/StructTable-InternVL2-1B",
-       filename="model.safetensors",
-       cache_dir=self._model_dir_path
-     )
-     if isinstance(repo_path, str):
-       return os.path.dirname(repo_path)
-     else:
-       self._logger.info("Downloading StructEqTable model...")
-       return snapshot_download(
-         cache_dir=self._model_dir_path,
-         repo_id="U4R/StructTable-InternVL2-1B",
-       )
-
-   def latex(self):
-     repo_path = try_to_load_from_cache(
-       repo_id="lukbl/LaTeX-OCR",
-       filename="checkpoints/weights.pth",
-       repo_type="space",
-       cache_dir=self._model_dir_path
-     )
-     if isinstance(repo_path, str):
-       return os.path.dirname(os.path.dirname(repo_path))
-     else:
-       self._logger.info("Downloading LaTeX model...")
-       return snapshot_download(
-         cache_dir=self._model_dir_path,
-         repo_type="space",
-         repo_id="lukbl/LaTeX-OCR",
-       )
doc_page_extractor/ocr.py DELETED
@@ -1,200 +0,0 @@
- import numpy as np
- import cv2
- import os
-
- from typing import Literal, Generator
- from dataclasses import dataclass
- from .onnxocr import TextSystem
- from .types import GetModelDir, OCRFragment
- from .rectangle import Rectangle
- from .utils import is_space_text
-
-
- _MODELS = (
-   ("ppocrv4", "rec", "rec.onnx"),
-   ("ppocrv4", "cls", "cls.onnx"),
-   ("ppocrv4", "det", "det.onnx"),
-   ("ch_ppocr_server_v2.0", "ppocr_keys_v1.txt"),
- )
-
- @dataclass
- class _OONXParams:
-   use_angle_cls: bool
-   use_gpu: bool
-   rec_image_shape: tuple[int, int, int]
-   cls_image_shape: tuple[int, int, int]
-   cls_batch_num: int
-   cls_thresh: float
-   label_list: list[str]
-
-   det_algorithm: str
-   det_limit_side_len: int
-   det_limit_type: str
-   det_db_thresh: float
-   det_db_box_thresh: float
-   det_db_unclip_ratio: float
-   use_dilation: bool
-   det_db_score_mode: str
-   det_box_type: str
-   rec_batch_num: int
-   drop_score: float
-   save_crop_res: bool
-   rec_algorithm: str
-   use_space_char: bool
-   rec_model_dir: str
-   cls_model_dir: str
-   det_model_dir: str
-   rec_char_dict_path: str
-
-
-
-
- class OCR:
-   def __init__(
-     self,
-     device: Literal["cpu", "cuda"],
-     get_model_dir: GetModelDir,
-   ):
-     self._device: Literal["cpu", "cuda"] = device
-     self._get_model_dir: GetModelDir = get_model_dir
-     self._text_system: TextSystem | None = None
-
-   def search_fragments(self, image: np.ndarray) -> Generator[OCRFragment, None, None]:
-     for box, res in self._ocr(image):
-       text, rank = res
-       if is_space_text(text):
-         continue
-
-       rect = Rectangle(
-         lt=(box[0][0], box[0][1]),
-         rt=(box[1][0], box[1][1]),
-         rb=(box[2][0], box[2][1]),
-         lb=(box[3][0], box[3][1]),
-       )
-       if not rect.is_valid or rect.area == 0.0:
-         continue
-
-       yield OCRFragment(
-         order=0,
-         text=text,
-         rank=rank,
-         rect=rect,
-       )
-
-   def _ocr(self, image: np.ndarray) -> Generator[tuple[list[list[float]], tuple[str, float]], None, None]:
-     text_system = self._get_text_system()
-     image = self._preprocess_image(image)
-     dt_boxes, rec_res = text_system(image)
-
-     for box, res in zip(dt_boxes, rec_res):
-       yield box.tolist(), res
-
-   def make_model_paths(self) -> list[str]:
-     model_paths = []
-     model_dir = self._get_model_dir()
-     for model_path in _MODELS:
-       file_name = os.path.join(*model_path)
-       model_paths.append(os.path.join(model_dir, file_name))
-     return model_paths
-
-   def _get_text_system(self) -> TextSystem:
-     if self._text_system is None:
-       model_paths = self.make_model_paths()
-       self._text_system = TextSystem(_OONXParams(
-         use_angle_cls=True,
-         use_gpu=(self._device != "cpu"),
-         rec_image_shape=(3, 48, 320),
-         cls_image_shape=(3, 48, 192),
-         cls_batch_num=6,
-         cls_thresh=0.9,
-         label_list=["0", "180"],
-         det_algorithm="DB",
-         det_limit_side_len=960,
-         det_limit_type="max",
-         det_db_thresh=0.3,
-         det_db_box_thresh=0.6,
-         det_db_unclip_ratio=1.5,
-         use_dilation=False,
-         det_db_score_mode="fast",
-         det_box_type="quad",
-         rec_batch_num=6,
-         drop_score=0.5,
-         save_crop_res=False,
-         rec_algorithm="SVTR_LCNet",
-         use_space_char=True,
-         rec_model_dir=model_paths[0],
-         cls_model_dir=model_paths[1],
-         det_model_dir=model_paths[2],
-         rec_char_dict_path=model_paths[3],
-       ))
-
-     return self._text_system
-
-   def _preprocess_image(self, image: np.ndarray) -> np.ndarray:
-     image = self._alpha_to_color(image, (255, 255, 255))
-     # image = cv2.bitwise_not(image) # inv
-     # image = self._binarize_img(image) # bin
-     image = cv2.normalize(
-       src=image,
-       dst=np.zeros((image.shape[0], image.shape[1])),
-       alpha=0,
-       beta=255,
-       norm_type=cv2.NORM_MINMAX,
-     )
-     if cv2.cuda.getCudaEnabledDeviceCount() > 0:
-       gpu_frame = cv2.cuda.GpuMat()
-       gpu_frame.upload(image)
-       image = cv2.cuda.fastNlMeansDenoisingColored(
-         src=gpu_frame,
-         dst=None,
-         h_luminance=10,
-         photo_render=10,
-         search_window=15,
-         block_size=7,
-       )
-       image = gpu_frame.download()
-     elif cv2.ocl.haveOpenCL():
-       cv2.ocl.setUseOpenCL(True)
-       gpu_frame = cv2.UMat(image)
-       image = cv2.fastNlMeansDenoisingColored(
-         src=gpu_frame,
-         dst=None,
-         h=10,
-         hColor=10,
-         templateWindowSize=7,
-         searchWindowSize=15,
-       )
-       image = image.get()
-     else:
-       image = cv2.fastNlMeansDenoisingColored(
-         src=image,
-         dst=None,
-         h=10,
-         hColor=10,
-         templateWindowSize=7,
-         searchWindowSize=15,
-       )
-
-     # image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) # image to gray
-     return image
-
-   def _alpha_to_color(self, image: np.ndarray, alpha_color: tuple[float, float, float]) -> np.ndarray:
-     if len(image.shape) == 3 and image.shape[2] == 4:
-       B, G, R, A = cv2.split(image)
-       alpha = A / 255
-
-       R = (alpha_color[0] * (1 - alpha) + R * alpha).astype(np.uint8)
-       G = (alpha_color[1] * (1 - alpha) + G * alpha).astype(np.uint8)
-       B = (alpha_color[2] * (1 - alpha) + B * alpha).astype(np.uint8)
-
-       image = cv2.merge((B, G, R))
-
-     return image
-
-   def _binarize_img(self, image: np.ndarray):
-     if len(image.shape) == 3 and image.shape[2] == 3:
-       gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) # conversion to grayscale image
-       # use cv2 threshold binarization
-       _, gray = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
-       image = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
-     return image
doc_page_extractor/ocr_corrector.py DELETED
@@ -1,126 +0,0 @@
- import numpy as np
-
- from typing import Iterable
- from shapely.geometry import Polygon
- from PIL.Image import new, Image, Resampling
- from .types import Layout, OCRFragment
- from .ocr import OCR
- from .overlap import overlap_rate
- from .rectangle import Point, Rectangle
-
-
- _MIN_RATE = 0.5
-
- def correct_fragments(ocr: OCR, source: Image, layout: Layout):
-   x1, y1, x2, y2 = layout.rect.wrapper
-   image: Image = source.crop((
-     round(x1), round(y1),
-     round(x2), round(y2),
-   ))
-   image, dx, dy, scale = _adjust_image(image)
-   image_np = np.array(image)
-   ocr_fragments = list(ocr.search_fragments(image_np))
-   corrected_fragments: list[OCRFragment] = []
-
-   for fragment in ocr_fragments:
-     _apply_fragment(fragment.rect, layout, dx, dy, scale)
-
-   matched_fragments, not_matched_fragments = _match_fragments(
-     zone_rect=layout.rect,
-     fragments1=layout.fragments,
-     fragments2=ocr_fragments,
-   )
-   for fragment1, fragment2 in matched_fragments:
-     if fragment1.rank > fragment2.rank:
-       corrected_fragments.append(fragment1)
-     else:
-       corrected_fragments.append(fragment2)
-
-   corrected_fragments.extend(not_matched_fragments)
-   layout.fragments = corrected_fragments
-
- def _adjust_image(image: Image) -> tuple[Image, int, int, float]:
-   # after testing, adding white borders to images can reduce
-   # the possibility of some text not being recognized
-   border_size: int = 50
-   adjusted_size: int = 1024 - 2 * border_size
-   width, height = image.size
-   core_width = float(max(adjusted_size, width))
-   core_height = float(max(adjusted_size, height))
-
-   scale_x = core_width / width
-   scale_y = core_height / height
-   scale = min(scale_x, scale_y)
-   adjusted_width = width * scale
-   adjusted_height = height * scale
-
-   dx = (core_width - adjusted_width) / 2.0
-   dy = (core_height - adjusted_height) / 2.0
-   dx = round(dx) + border_size
-   dy = round(dy) + border_size
-
-   if scale != 1.0:
-     width = round(width * scale)
-     height = round(height * scale)
-     image = image.resize((width, height), Resampling.BICUBIC)
-
-   width = round(core_width) + 2 * border_size
-   height = round(core_height) + 2 * border_size
-   new_image = new("RGB", (width, height), (255, 255, 255))
-   new_image.paste(image, (dx, dy))
-
-   return new_image, dx, dy, scale
-
- def _apply_fragment(rect: Rectangle, layout: Layout, dx: int, dy: int, scale: float):
-   rect.lt = _apply_point(rect.lt, layout, dx, dy, scale)
-   rect.lb = _apply_point(rect.lb, layout, dx, dy, scale)
-   rect.rb = _apply_point(rect.rb, layout, dx, dy, scale)
-   rect.rt = _apply_point(rect.rt, layout, dx, dy, scale)
-
- def _apply_point(point: Point, layout: Layout, dx: int, dy: int, scale: float) -> Point:
-   x, y = point
-   x = (x - dx) / scale + layout.rect.lt[0]
-   y = (y - dy) / scale + layout.rect.lt[1]
-   return x, y
-
- def _match_fragments(
-   zone_rect: Rectangle,
-   fragments1: Iterable[OCRFragment],
-   fragments2: Iterable[OCRFragment],
- ) -> tuple[list[tuple[OCRFragment, OCRFragment]], list[OCRFragment]]:
-
-   zone_polygon = Polygon(zone_rect)
-   fragments2: list[OCRFragment] = list(fragments2)
-   matched_fragments: list[tuple[OCRFragment, OCRFragment]] = []
-   not_matched_fragments: list[OCRFragment] = []
-
-   for fragment1 in fragments1:
-     polygon1 = Polygon(fragment1.rect)
-     polygon1 = zone_polygon.intersection(polygon1)
-     if polygon1.is_empty:
-       continue
-
-     beast_j = -1
-     beast_rate = 0.0
-
-     for j, fragment2 in enumerate(fragments2):
-       polygon2 = Polygon(fragment2.rect)
-       rate = overlap_rate(polygon1, polygon2)
-       if rate < _MIN_RATE:
-         continue
-
-       if rate > beast_rate:
-         beast_j = j
-         beast_rate = rate
-
-     if beast_j != -1:
-       matched_fragments.append((
-         fragment1,
-         fragments2[beast_j],
-       ))
-       del fragments2[beast_j]
-     else:
-       not_matched_fragments.append(fragment1)
-
-   not_matched_fragments.extend(fragments2)
-   return matched_fragments, not_matched_fragments
doc_page_extractor/onnxocr/__init__.py DELETED
@@ -1 +0,0 @@
- from .predict_system import TextSystem
doc_page_extractor/onnxocr/cls_postprocess.py DELETED
@@ -1,26 +0,0 @@
- class ClsPostProcess(object):
-     """ Convert between text-label and text-index """
-
-     def __init__(self, label_list=None, key=None, **kwargs):
-         super(ClsPostProcess, self).__init__()
-         self.label_list = label_list
-         self.key = key
-
-     def __call__(self, preds, label=None, *args, **kwargs):
-         if self.key is not None:
-             preds = preds[self.key]
-
-         label_list = self.label_list
-         if label_list is None:
-             label_list = {idx: idx for idx in range(preds.shape[-1])}
-
-         # if isinstance(preds, paddle.Tensor):
-         #     preds = preds.numpy()
-
-         pred_idxs = preds.argmax(axis=1)
-         decode_out = [(label_list[idx], preds[i, idx])
-                       for i, idx in enumerate(pred_idxs)]
-         if label is None:
-             return decode_out
-         label = [(label_list[idx], 1.0) for idx in label]
-         return decode_out, label