yomitoku 0.4.1__py3-none-any.whl → 0.7.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. yomitoku/base.py +1 -1
  2. yomitoku/cli/main.py +219 -27
  3. yomitoku/configs/__init__.py +2 -0
  4. yomitoku/configs/cfg_text_detector_dbnet.py +1 -1
  5. yomitoku/configs/cfg_text_recognizer_parseq_small.py +51 -0
  6. yomitoku/data/functions.py +48 -23
  7. yomitoku/document_analyzer.py +243 -41
  8. yomitoku/export/__init__.py +18 -5
  9. yomitoku/export/export_csv.py +71 -2
  10. yomitoku/export/export_html.py +46 -12
  11. yomitoku/export/export_json.py +66 -3
  12. yomitoku/export/export_markdown.py +42 -6
  13. yomitoku/layout_analyzer.py +2 -9
  14. yomitoku/layout_parser.py +58 -4
  15. yomitoku/models/dbnet_plus.py +13 -39
  16. yomitoku/models/layers/activate.py +13 -0
  17. yomitoku/models/layers/rtdetr_backbone.py +18 -17
  18. yomitoku/models/layers/rtdetr_hybrid_encoder.py +19 -20
  19. yomitoku/models/layers/rtdetrv2_decoder.py +14 -1
  20. yomitoku/models/parseq.py +15 -22
  21. yomitoku/ocr.py +24 -27
  22. yomitoku/onnx/.gitkeep +0 -0
  23. yomitoku/postprocessor/dbnet_postporcessor.py +15 -14
  24. yomitoku/postprocessor/parseq_tokenizer.py +1 -3
  25. yomitoku/postprocessor/rtdetr_postprocessor.py +14 -1
  26. yomitoku/table_structure_recognizer.py +82 -9
  27. yomitoku/text_detector.py +57 -7
  28. yomitoku/text_recognizer.py +84 -16
  29. yomitoku/utils/misc.py +21 -14
  30. yomitoku/utils/visualizer.py +15 -8
  31. {yomitoku-0.4.1.dist-info → yomitoku-0.7.4.dist-info}/METADATA +34 -41
  32. yomitoku-0.7.4.dist-info/RECORD +54 -0
  33. {yomitoku-0.4.1.dist-info → yomitoku-0.7.4.dist-info}/WHEEL +1 -1
  34. yomitoku-0.4.1.dist-info/RECORD +0 -52
  35. {yomitoku-0.4.1.dist-info → yomitoku-0.7.4.dist-info}/entry_points.txt +0 -0
yomitoku/ocr.py CHANGED
@@ -16,16 +16,37 @@ class WordPrediction(BaseSchema):
     )
     content: str
     direction: str
-    det_score: float
     rec_score: float
+    det_score: float


 class OCRSchema(BaseSchema):
     words: List[WordPrediction]


+def ocr_aggregate(det_outputs, rec_outputs):
+    words = []
+    for points, det_score, pred, rec_score, direction in zip(
+        det_outputs.points,
+        det_outputs.scores,
+        rec_outputs.contents,
+        rec_outputs.scores,
+        rec_outputs.directions,
+    ):
+        words.append(
+            {
+                "points": points,
+                "content": pred,
+                "direction": direction,
+                "det_score": det_score,
+                "rec_score": rec_score,
+            }
+        )
+    return words
+
+
 class OCR:
-    def __init__(self, configs=None, device="cuda", visualize=False):
+    def __init__(self, configs={}, device="cuda", visualize=False):
         text_detector_kwargs = {
             "device": device,
             "visualize": visualize,
@@ -36,10 +57,6 @@ class OCR:
         }

         if isinstance(configs, dict):
-            assert (
-                "text_detector" in configs or "text_recognizer" in configs
-            ), "Invalid config key. Please check the config keys."
-
             if "text_detector" in configs:
                 text_detector_kwargs.update(configs["text_detector"])
             if "text_recognizer" in configs:
@@ -52,26 +69,6 @@ class OCR:
         self.detector = TextDetector(**text_detector_kwargs)
         self.recognizer = TextRecognizer(**text_recognizer_kwargs)

-    def aggregate(self, det_outputs, rec_outputs):
-        words = []
-        for points, det_score, pred, rec_score, direction in zip(
-            det_outputs.points,
-            det_outputs.scores,
-            rec_outputs.contents,
-            rec_outputs.scores,
-            rec_outputs.directions,
-        ):
-            words.append(
-                {
-                    "points": points,
-                    "content": pred,
-                    "direction": direction,
-                    "det_score": det_score,
-                    "rec_score": rec_score,
-                }
-            )
-        return words
-
     def __call__(self, img):
         """_summary_

@@ -82,6 +79,6 @@ class OCR:
         det_outputs, vis = self.detector(img)
         rec_outputs, vis = self.recognizer(img, det_outputs.points, vis=vis)

-        outputs = {"words": self.aggregate(det_outputs, rec_outputs)}
+        outputs = {"words": ocr_aggregate(det_outputs, rec_outputs)}
         results = OCRSchema(**outputs)
         return results, vis
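Note on the ocr.py changes above: aggregate() moves out of the OCR class and becomes the module-level ocr_aggregate(), the config-key assertion is dropped, and configs defaults to an empty dict. A minimal usage sketch under those assumptions (the sample image path and device="cpu" are illustrative, not part of the diff):

import cv2
from yomitoku.ocr import OCR, ocr_aggregate  # ocr_aggregate is now module-level

# Partial config dicts such as {"text_detector": {"infer_onnx": True}} are merged
# without the removed assertion; unrelated keys are simply ignored.
ocr = OCR(device="cpu", visualize=False)

img = cv2.imread("sample.jpg")  # illustrative input path
results, _ = ocr(img)           # __call__ builds the word list via ocr_aggregate()

for word in results.words:
    print(word.content, word.direction, word.det_score, word.rec_score)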
yomitoku/onnx/.gitkeep ADDED
File without changes
yomitoku/postprocessor/dbnet_postporcessor.py CHANGED
@@ -1,13 +1,12 @@
 import cv2
+import math
 import numpy as np
 import pyclipper
 from shapely.geometry import Polygon


 class DBnetPostProcessor:
-    def __init__(
-        self, min_size, thresh, box_thresh, max_candidates, unclip_ratio
-    ):
+    def __init__(self, min_size, thresh, box_thresh, max_candidates, unclip_ratio):
         self.min_size = min_size
         self.thresh = thresh
         self.box_thresh = box_thresh
@@ -24,9 +23,7 @@ class DBnetPostProcessor:
         pred = preds["binary"][0]
         segmentation = self.binarize(pred)[0]
         height, width = image_size
-        quads, scores = self.boxes_from_bitmap(
-            pred, segmentation, width, height
-        )
+        quads, scores = self.boxes_from_bitmap(pred, segmentation, width, height)
         return quads, scores

     def binarize(self, pred):
@@ -65,9 +62,7 @@ class DBnetPostProcessor:
             if self.box_thresh > score:
                 continue

-            box = self.unclip(points, unclip_ratio=self.unclip_ratio).reshape(
-                -1, 1, 2
-            )
+            box = self.unclip(points, unclip_ratio=self.unclip_ratio).reshape(-1, 1, 2)
             box, sside = self.get_mini_boxes(box)
             if sside < self.min_size + 2:
                 continue
@@ -76,9 +71,7 @@ class DBnetPostProcessor:
             dest_width = dest_width.item()
             dest_height = dest_height.item()

-            box[:, 0] = np.clip(
-                np.round(box[:, 0] / width * dest_width), 0, dest_width
-            )
+            box[:, 0] = np.clip(np.round(box[:, 0] / width * dest_width), 0, dest_width)
             box[:, 1] = np.clip(
                 np.round(box[:, 1] / height * dest_height), 0, dest_height
             )
@@ -88,9 +81,17 @@ class DBnetPostProcessor:

         return boxes, scores

-    def unclip(self, box, unclip_ratio=1.5):
+    def unclip(self, box, unclip_ratio=7):
+        # Small characters tend to get clipped while large characters get excessively wide margins
+        # To address this, the expansion parameter is adjusted dynamically according to the character size
+        # Note: this rule is a heuristic with no theoretical justification
         poly = Polygon(box)
-        distance = poly.area * unclip_ratio / poly.length
+        width = box[:, 0].max() - box[:, 0].min()
+        height = box[:, 1].max() - box[:, 1].min()
+        box_dist = min(width, height)
+        ratio = unclip_ratio / math.sqrt(box_dist)
+
+        distance = poly.area * ratio / poly.length
         offset = pyclipper.PyclipperOffset()
         offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
         expanded = np.array(offset.Execute(distance))
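The unclip change above replaces the fixed expansion ratio with unclip_ratio / sqrt(min(box_width, box_height)), so smaller detections receive proportionally more padding. A standalone sketch of that arithmetic (the box coordinates are made up for illustration):

import math

import numpy as np
from shapely.geometry import Polygon

def expansion_distance(box, unclip_ratio=7):
    # Mirrors the new DBnetPostProcessor.unclip: scale the ratio by the inverse
    # square root of the shorter box side before deriving the offset distance.
    poly = Polygon(box)
    box_dist = min(box[:, 0].max() - box[:, 0].min(),
                   box[:, 1].max() - box[:, 1].min())
    ratio = unclip_ratio / math.sqrt(box_dist)
    return poly.area * ratio / poly.length

small = np.array([[0, 0], [16, 0], [16, 16], [0, 16]])    # small 16 px glyph
large = np.array([[0, 0], [256, 0], [256, 64], [0, 64]])  # long text line

print(expansion_distance(small))  # 7.0  -> larger padding relative to the box size
print(expansion_distance(large))  # 22.4 -> smaller padding relative to the box size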
yomitoku/postprocessor/parseq_tokenizer.py CHANGED
@@ -122,7 +122,5 @@ class ParseqTokenizer(BaseTokenizer):
             eos_idx = len(ids)  # Nothing to truncate.
         # Truncate after EOS
         ids = ids[:eos_idx]
-        probs = probs[
-            : eos_idx + 1
-        ]  # but include prob. for EOS (if it exists)
+        probs = probs[: eos_idx + 1]  # but include prob. for EOS (if it exists)
         return probs, ids
yomitoku/postprocessor/rtdetr_postprocessor.py CHANGED
@@ -1,4 +1,17 @@
-"""Copyright(c) 2023 lyuwenyu. All Rights Reserved."""
+# Copyright 2023 lyuwenyu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+

 import torch
 import torch.nn as nn
yomitoku/table_structure_recognizer.py CHANGED
@@ -1,11 +1,16 @@
 from typing import List, Union

 import cv2
+import os
+import onnx
+import onnxruntime
 import torch
 import torchvision.transforms as T
 from PIL import Image
 from pydantic import conlist

+from .constants import ROOT_DIR
+
 from .base import BaseModelCatalog, BaseModule, BaseSchema
 from .configs import TableStructureRecognizerRTDETRv2Config
 from .layout_parser import filter_contained_rectangles_within_category
@@ -30,11 +35,19 @@ class TableCellSchema(BaseSchema):
     contents: Union[str, None]


+class TableLineSchema(BaseSchema):
+    box: conlist(int, min_length=4, max_length=4)
+    score: float
+
+
 class TableStructureRecognizerSchema(BaseSchema):
     box: conlist(int, min_length=4, max_length=4)
     n_row: int
     n_col: int
+    rows: List[TableLineSchema]
+    cols: List[TableLineSchema]
     cells: List[TableCellSchema]
+    spans: List[TableLineSchema]
     order: int

@@ -109,12 +122,13 @@ class TableStructureRecognizer(BaseModule):
         device="cuda",
         visualize=False,
         from_pretrained=True,
+        infer_onnx=False,
     ):
         super().__init__()
         self.load_model(
             model_name,
             path_cfg,
-            from_pretrained=True,
+            from_pretrained=from_pretrained,
         )
         self.device = device
         self.visualize = visualize
@@ -140,6 +154,45 @@ class TableStructureRecognizer(BaseModule):
             id: category for id, category in enumerate(self._cfg.category)
         }

+        self.infer_onnx = infer_onnx
+        if infer_onnx:
+            name = self._cfg.hf_hub_repo.split("/")[-1]
+            path_onnx = f"{ROOT_DIR}/onnx/{name}.onnx"
+            if not os.path.exists(path_onnx):
+                self.convert_onnx(path_onnx)
+
+            self.model = None
+
+            model = onnx.load(path_onnx)
+            if torch.cuda.is_available() and device == "cuda":
+                self.sess = onnxruntime.InferenceSession(
+                    model.SerializeToString(), providers=["CUDAExecutionProvider"]
+                )
+            else:
+                self.sess = onnxruntime.InferenceSession(model.SerializeToString())
+
+        if self.model is not None:
+            self.model.to(self.device)
+
+    def convert_onnx(self, path_onnx):
+        dynamic_axes = {
+            "input": {0: "batch_size"},
+            "output": {0: "batch_size"},
+        }
+
+        img_size = self._cfg.data.img_size
+        dummy_input = torch.randn(1, 3, *img_size, requires_grad=True)
+
+        torch.onnx.export(
+            self.model,
+            dummy_input,
+            path_onnx,
+            opset_version=16,
+            input_names=["input"],
+            output_names=["pred_logits", "pred_boxes"],
+            dynamic_axes=dynamic_axes,
+        )
+
     def preprocess(self, img, boxes):
         cv_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

@@ -149,7 +202,7 @@ class TableStructureRecognizer(BaseModule):
             table_img = cv_img[y1:y2, x1:x2, :]
             th, hw = table_img.shape[:2]
             table_img = Image.fromarray(table_img)
-            img_tensor = self.transforms(table_img)[None].to(self.device)
+            img_tensor = self.transforms(table_img)[None]
             table_imgs.append(
                 {
                     "tensor": img_tensor,
@@ -190,7 +243,7 @@ class TableStructureRecognizer(BaseModule):
                 category_elements
             )

-        cells, n_row, n_col = self.extract_cell_elements(category_elements)
+        cells, rows, cols, spans = self.extract_cell_elements(category_elements)

         table_x, table_y = data["offset"]
         table_x2 = table_x + data["size"][1]
@@ -199,8 +252,11 @@ class TableStructureRecognizer(BaseModule):

         table = {
             "box": table_box,
-            "n_row": n_row,
-            "n_col": n_col,
+            "n_row": len(rows),
+            "n_col": len(cols),
+            "rows": rows,
+            "cols": cols,
+            "spans": spans,
             "cells": cells,
             "order": 0,
         }
@@ -220,16 +276,33 @@ class TableStructureRecognizer(BaseModule):
         cells = extract_cells(row_boxes, col_boxes)
         cells = filter_contained_cells_within_spancell(cells, span_boxes)

-        return cells, len(row_boxes), len(col_boxes)
+        rows = sorted(elements["row"], key=lambda x: x["box"][1])
+        cols = sorted(elements["col"], key=lambda x: x["box"][0])
+        spans = sorted(elements["span"], key=lambda x: x["box"][1])
+
+        return cells, rows, cols, spans

     def __call__(self, img, table_boxes, vis=None):
         img_tensors = self.preprocess(img, table_boxes)
         outputs = []
         for data in img_tensors:
-            with torch.inference_mode():
-                pred = self.model(data["tensor"])
+            if self.infer_onnx:
+                input = data["tensor"].numpy()
+                results = self.sess.run(None, {"input": input})
+                pred = {
+                    "pred_logits": torch.tensor(results[0]).to(self.device),
+                    "pred_boxes": torch.tensor(results[1]).to(self.device),
+                }
+
+            else:
+                with torch.inference_mode():
+                    data["tensor"] = data["tensor"].to(self.device)
+                    pred = self.model(data["tensor"])
+
             table = self.postprocess(pred, data)
-            outputs.append(table)
+
+            if table.n_row > 0 and table.n_col > 0:
+                outputs.append(table)

         if vis is None and self.visualize:
             vis = img.copy()
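TableStructureRecognizer, TextDetector, and TextRecognizer all gain the same infer_onnx path: export the PyTorch model once into yomitoku/onnx/, then build an onnxruntime session with the CUDA provider only when both the device argument and the hardware allow it. A condensed sketch of that pattern, under stated assumptions (the 640x640 dummy shape and the explicit CPUExecutionProvider fallback are illustrative; the actual classes use the config's img_size and onnxruntime's default providers on CPU):

import os

import onnx
import onnxruntime
import torch

def load_onnx_session(model, path_onnx, device="cuda"):
    # Export once and cache the file; later runs reuse the cached ONNX model.
    if not os.path.exists(path_onnx):
        torch.onnx.export(
            model,
            torch.randn(1, 3, 640, 640),  # dummy input shape, illustrative only
            path_onnx,
            input_names=["input"],
            output_names=["pred_logits", "pred_boxes"],
            dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
        )

    onnx_model = onnx.load(path_onnx)
    # Pick the CUDA provider only when requested and actually available.
    providers = (
        ["CUDAExecutionProvider"]
        if torch.cuda.is_available() and device == "cuda"
        else ["CPUExecutionProvider"]
    )
    return onnxruntime.InferenceSession(onnx_model.SerializeToString(), providers=providers)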
yomitoku/text_detector.py CHANGED
@@ -2,6 +2,7 @@ from typing import List

 import numpy as np
 import torch
+import os
 from pydantic import conlist

 from .base import BaseModelCatalog, BaseModule, BaseSchema
@@ -14,6 +15,10 @@ from .data.functions import (
 from .models import DBNet
 from .postprocessor import DBnetPostProcessor
 from .utils.visualizer import det_visualizer
+from .constants import ROOT_DIR
+
+import onnx
+import onnxruntime


 class TextDetectorModelCatalog(BaseModelCatalog):
@@ -43,21 +48,60 @@ class TextDetector(BaseModule):
         device="cuda",
         visualize=False,
         from_pretrained=True,
+        infer_onnx=False,
     ):
         super().__init__()
         self.load_model(
             model_name,
             path_cfg,
-            from_pretrained=True,
+            from_pretrained=from_pretrained,
         )

         self.device = device
         self.visualize = visualize

         self.model.eval()
-        self.model.to(self.device)
-
         self.post_processor = DBnetPostProcessor(**self._cfg.post_process)
+        self.infer_onnx = infer_onnx
+
+        if infer_onnx:
+            name = self._cfg.hf_hub_repo.split("/")[-1]
+            path_onnx = f"{ROOT_DIR}/onnx/{name}.onnx"
+            if not os.path.exists(path_onnx):
+                self.convert_onnx(path_onnx)
+
+            self.model = None
+
+            model = onnx.load(path_onnx)
+            if torch.cuda.is_available() and device == "cuda":
+                self.sess = onnxruntime.InferenceSession(
+                    model.SerializeToString(), providers=["CUDAExecutionProvider"]
+                )
+            else:
+                self.sess = onnxruntime.InferenceSession(model.SerializeToString())
+
+            self.model = None
+
+        if self.model is not None:
+            self.model.to(self.device)
+
+    def convert_onnx(self, path_onnx):
+        dynamic_axes = {
+            "input": {0: "batch_size", 2: "height", 3: "width"},
+            "output": {0: "batch_size", 2: "height", 3: "width"},
+        }
+
+        dummy_input = torch.randn(1, 3, 256, 256, requires_grad=True)
+
+        torch.onnx.export(
+            self.model,
+            dummy_input,
+            path_onnx,
+            opset_version=14,
+            input_names=["input"],
+            output_names=["output"],
+            dynamic_axes=dynamic_axes,
+        )

     def preprocess(self, img):
         img = img.copy()
@@ -81,9 +125,15 @@ class TextDetector(BaseModule):

         ori_h, ori_w = img.shape[:2]
         tensor = self.preprocess(img)
-        tensor = tensor.to(self.device)
-        with torch.inference_mode():
-            preds = self.model(tensor)
+
+        if self.infer_onnx:
+            input = tensor.numpy()
+            results = self.sess.run(["output"], {"input": input})
+            preds = {"binary": torch.tensor(results[0])}
+        else:
+            with torch.inference_mode():
+                tensor = tensor.to(self.device)
+                preds = self.model(tensor)

         quads, scores = self.postprocess(preds, (ori_h, ori_w))
         outputs = {"points": quads, "scores": scores}
@@ -93,9 +143,9 @@ class TextDetector(BaseModule):
         vis = None
         if self.visualize:
             vis = det_visualizer(
-                preds,
                 img,
                 quads,
+                preds=preds,
                 vis_heatmap=self._cfg.visualize.heatmap,
                 line_color=tuple(self._cfg.visualize.color[::-1]),
             )
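Unlike the recognizer and the table model, the detector exports with height and width marked as dynamic axes, so one cached ONNX file serves arbitrarily sized pages. A rough sketch of what that enables (the dbnet.onnx path is hypothetical and assumes the export has already been produced):

import numpy as np
import onnxruntime

sess = onnxruntime.InferenceSession("dbnet.onnx")  # hypothetical cached export

# Height and width are dynamic, so differently sized inputs can be fed to the
# same session without re-exporting the model.
for h, w in [(256, 256), (512, 320)]:
    dummy = np.random.rand(1, 3, h, w).astype(np.float32)
    (binary,) = sess.run(["output"], {"input": dummy})
    print(binary.shape)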
yomitoku/text_recognizer.py CHANGED
@@ -2,21 +2,28 @@ from typing import List

 import numpy as np
 import torch
+import os
+import unicodedata
 from pydantic import conlist

 from .base import BaseModelCatalog, BaseModule, BaseSchema
-from .configs import TextRecognizerPARSeqConfig
+from .configs import TextRecognizerPARSeqConfig, TextRecognizerPARSeqSmallConfig
 from .data.dataset import ParseqDataset
 from .models import PARSeq
 from .postprocessor import ParseqTokenizer as Tokenizer
 from .utils.misc import load_charset
 from .utils.visualizer import rec_visualizer

+from .constants import ROOT_DIR
+import onnx
+import onnxruntime
+

 class TextRecognizerModelCatalog(BaseModelCatalog):
     def __init__(self):
         super().__init__()
         self.register("parseq", TextRecognizerPARSeqConfig, PARSeq)
+        self.register("parseq-small", TextRecognizerPARSeqSmallConfig, PARSeq)


 class TextRecognizerSchema(BaseSchema):
@@ -42,36 +49,91 @@ class TextRecognizer(BaseModule):
         device="cuda",
         visualize=False,
         from_pretrained=True,
+        infer_onnx=False,
     ):
         super().__init__()
         self.load_model(
             model_name,
             path_cfg,
-            from_pretrained=True,
+            from_pretrained=from_pretrained,
         )
         self.charset = load_charset(self._cfg.charset)
         self.tokenizer = Tokenizer(self.charset)

         self.device = device

+        self.model.tokenizer = self.tokenizer
         self.model.eval()
-        self.model.to(self.device)

         self.visualize = visualize

+        self.infer_onnx = infer_onnx
+
+        if infer_onnx:
+            name = self._cfg.hf_hub_repo.split("/")[-1]
+            path_onnx = f"{ROOT_DIR}/onnx/{name}.onnx"
+            if not os.path.exists(path_onnx):
+                self.convert_onnx(path_onnx)
+
+            self.model = None
+
+            model = onnx.load(path_onnx)
+            if torch.cuda.is_available() and device == "cuda":
+                self.sess = onnxruntime.InferenceSession(
+                    model.SerializeToString(), providers=["CUDAExecutionProvider"]
+                )
+            else:
+                self.sess = onnxruntime.InferenceSession(model.SerializeToString())
+
+        if self.model is not None:
+            self.model.to(self.device)
+
     def preprocess(self, img, polygons):
         dataset = ParseqDataset(self._cfg, img, polygons)
-        dataloader = torch.utils.data.DataLoader(
-            dataset,
-            batch_size=self._cfg.data.batch_size,
-            shuffle=False,
-            num_workers=self._cfg.data.num_workers,
-        )
+        dataloader = self._make_mini_batch(dataset)

         return dataloader

+    def _make_mini_batch(self, dataset):
+        mini_batches = []
+        mini_batch = []
+        for data in dataset:
+            data = torch.unsqueeze(data, 0)
+            mini_batch.append(data)
+
+            if len(mini_batch) == self._cfg.data.batch_size:
+                mini_batches.append(torch.cat(mini_batch, 0))
+                mini_batch = []
+        else:
+            if len(mini_batch) > 0:
+                mini_batches.append(torch.cat(mini_batch, 0))
+
+        return mini_batches
+
+    def convert_onnx(self, path_onnx):
+        img_size = self._cfg.data.img_size
+        input = torch.randn(1, 3, *img_size, requires_grad=True)
+        dynamic_axes = {
+            "input": {0: "batch_size"},
+            "output": {0: "batch_size"},
+        }
+
+        self.model.export_onnx = True
+        torch.onnx.export(
+            self.model,
+            input,
+            path_onnx,
+            opset_version=14,
+            input_names=["input"],
+            output_names=["output"],
+            do_constant_folding=True,
+            dynamic_axes=dynamic_axes,
+        )
+
     def postprocess(self, p, points):
         pred, score = self.tokenizer.decode(p)
+        pred = [unicodedata.normalize("NFKC", x) for x in pred]
+
         directions = []
         for point in points:
             point = np.array(point)
@@ -98,13 +160,19 @@ class TextRecognizer(BaseModule):
         scores = []
         directions = []
         for data in dataloader:
-            data = data.to(self.device)
-            with torch.inference_mode():
-                p = self.model(self.tokenizer, data).softmax(-1)
-                pred, score, direction = self.postprocess(p, points)
-                preds.extend(pred)
-                scores.extend(score)
-                directions.extend(direction)
+            if self.infer_onnx:
+                input = data.numpy()
+                results = self.sess.run(["output"], {"input": input})
+                p = torch.tensor(results[0])
+            else:
+                with torch.inference_mode():
+                    data = data.to(self.device)
+                    p = self.model(data).softmax(-1)
+
+            pred, score, direction = self.postprocess(p, points)
+            preds.extend(pred)
+            scores.extend(score)
+            directions.extend(direction)

         outputs = {
             "contents": preds,
yomitoku/utils/misc.py CHANGED
@@ -1,5 +1,5 @@
 def load_charset(charset_path):
-    with open(charset_path, "r") as f:
+    with open(charset_path, "r", encoding="utf-8") as f:
         charset = f.read()
     return charset

@@ -9,6 +9,24 @@ def filter_by_flag(elements, flags):
     return [element for element, flag in zip(elements, flags) if flag]


+def calc_overlap_ratio(rect_a, rect_b):
+    intersection = calc_intersection(rect_a, rect_b)
+    if intersection is None:
+        return 0, None
+
+    ix1, iy1, ix2, iy2 = intersection
+
+    overlap_width = ix2 - ix1
+    overlap_height = iy2 - iy1
+    bx1, by1, bx2, by2 = rect_b
+
+    b_area = (bx2 - bx1) * (by2 - by1)
+    overlap_area = overlap_width * overlap_height
+
+    overlap_ratio = overlap_area / b_area
+    return overlap_ratio, intersection
+
+
 def is_contained(rect_a, rect_b, threshold=0.8):
     """Given two rectangles A and B, determine whether rectangle B is contained in rectangle A.
     To tolerate slight misalignment, the overlap ratio is computed and True is returned when it exceeds threshold.
@@ -23,20 +41,9 @@ def is_contained(rect_a, rect_b, threshold=0.8):
         bool: True if rectangle B is contained in rectangle A
     """

-    intersection = calc_intersection(rect_a, rect_b)
-    if intersection is None:
-        return False
-
-    ix1, iy1, ix2, iy2 = intersection
-
-    overlap_width = ix2 - ix1
-    overlap_height = iy2 - iy1
-    bx1, by1, bx2, by2 = rect_b
-
-    b_area = (bx2 - bx1) * (by2 - by1)
-    overlap_area = overlap_width * overlap_height
+    overlap_ratio, _ = calc_overlap_ratio(rect_a, rect_b)

-    if overlap_area / b_area > threshold:
+    if overlap_ratio > threshold:
         return True

     return False
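In the refactor above, is_contained delegates to the new calc_overlap_ratio, which normalizes the intersection by rect_b's area, i.e. it measures how much of B lies inside A. A self-contained sketch of the same computation (the rectangles and the inline intersection are illustrative; the library delegates to its own calc_intersection helper):

def calc_overlap_ratio(rect_a, rect_b):
    # Same convention as the diff: rects are (x1, y1, x2, y2) and the overlap
    # area is normalized by rect_b's area, i.e. "how much of B is inside A".
    ax1, ay1, ax2, ay2 = rect_a
    bx1, by1, bx2, by2 = rect_b
    ix1, iy1 = max(ax1, bx1), max(ay1, by1)
    ix2, iy2 = min(ax2, bx2), min(ay2, by2)
    if ix1 >= ix2 or iy1 >= iy2:
        return 0, None
    overlap_area = (ix2 - ix1) * (iy2 - iy1)
    b_area = (bx2 - bx1) * (by2 - by1)
    return overlap_area / b_area, (ix1, iy1, ix2, iy2)

page_block = (0, 0, 100, 100)
text_box = (90, 0, 110, 10)   # only half of this box lies inside the block

ratio, _ = calc_overlap_ratio(page_block, text_box)
print(ratio)        # 0.5
print(ratio > 0.8)  # False -> is_contained() would reject this pair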