yomitoku 0.5.2__py3-none-any.whl → 0.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
yomitoku/cli/main.py CHANGED
@@ -104,6 +104,12 @@ def main():
         default="results",
         help="output directory",
     )
+    parser.add_argument(
+        "-l",
+        "--lite",
+        action="store_true",
+        help="if set, use lite model",
+    )
     parser.add_argument(
         "-d",
         "--device",
@@ -197,6 +203,15 @@ def main():
         },
     }
 
+    if args.lite:
+        configs["ocr"]["text_recognizer"]["model_name"] = "parseq-small"
+        configs["ocr"]["text_detector"]["infer_onnx"] = True
+
+        # Note: for modules other than the text detector, PyTorch inference is faster than ONNX inference, so ONNX is not used for them.
+        # configs["ocr"]["text_recognizer"]["infer_onnx"] = True
+        # configs["layout_analyzer"]["table_structure_recognizer"]["infer_onnx"] = True
+        # configs["layout_analyzer"]["layout_parser"]["infer_onnx"] = True
+
     analyzer = DocumentAnalyzer(
         configs=configs,
         visualize=args.vis,
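The `--lite` branch above is only a config rewrite applied before `DocumentAnalyzer` is constructed, so the same setup can be reproduced from Python. A minimal sketch, assuming `DocumentAnalyzer` is importable from the package root and accepts a partial `configs` dict plus a `device` keyword; the key paths are exactly the ones the CLI sets above:

```python
from yomitoku import DocumentAnalyzer  # top-level import path assumed

# Same overrides as the CLI's --lite branch.
configs = {
    "ocr": {
        "text_recognizer": {"model_name": "parseq-small"},  # lighter PARSeq variant
        "text_detector": {"infer_onnx": True},  # ONNX inference for the detector only
    },
}

analyzer = DocumentAnalyzer(configs=configs, visualize=False, device="cpu")
```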
yomitoku/configs/__init__.py CHANGED
@@ -4,10 +4,12 @@ from .cfg_table_structure_recognizer_rtdtrv2 import (
 )
 from .cfg_text_detector_dbnet import TextDetectorDBNetConfig
 from .cfg_text_recognizer_parseq import TextRecognizerPARSeqConfig
+from .cfg_text_recognizer_parseq_small import TextRecognizerPARSeqSmallConfig
 
 __all__ = [
     "TextDetectorDBNetConfig",
     "TextRecognizerPARSeqConfig",
     "LayoutParserRTDETRv2Config",
     "TableStructureRecognizerRTDETRv2Config",
+    "TextRecognizerPARSeqSmallConfig",
 ]
yomitoku/configs/cfg_text_recognizer_parseq_small.py ADDED
@@ -0,0 +1,51 @@
+from dataclasses import dataclass, field
+from typing import List
+
+from ..constants import ROOT_DIR
+
+
+@dataclass
+class Data:
+    num_workers: int = 4
+    batch_size: int = 128
+    img_size: List[int] = field(default_factory=lambda: [32, 800])
+
+
+@dataclass
+class Encoder:
+    patch_size: List[int] = field(default_factory=lambda: [16, 16])
+    num_heads: int = 8
+    embed_dim: int = 384
+    mlp_ratio: int = 4
+    depth: int = 9
+
+
+@dataclass
+class Decoder:
+    embed_dim: int = 384
+    num_heads: int = 8
+    mlp_ratio: int = 4
+    depth: int = 1
+
+
+@dataclass
+class Visualize:
+    font: str = str(ROOT_DIR + "/resource/MPLUS1p-Medium.ttf")
+    color: List[int] = field(default_factory=lambda: [0, 0, 255])  # RGB
+    font_size: int = 18
+
+
+@dataclass
+class TextRecognizerPARSeqSmallConfig:
+    hf_hub_repo: str = "KotaroKinoshita/yomitoku-text-recognizer-parseq-small-open-beta"
+    charset: str = str(ROOT_DIR + "/resource/charset.txt")
+    num_tokens: int = 7312
+    max_label_length: int = 100
+    decode_ar: int = 1
+    refine_iters: int = 1
+
+    data: Data = field(default_factory=Data)
+    encoder: Encoder = field(default_factory=Encoder)
+    decoder: Decoder = field(default_factory=Decoder)
+
+    visualize: Visualize = field(default_factory=Visualize)
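A quick usage sketch of the new config dataclass, assuming the module is importable as packaged; the printed values are simply the defaults defined above:

```python
from yomitoku.configs.cfg_text_recognizer_parseq_small import (
    Data,
    TextRecognizerPARSeqSmallConfig,
)

cfg = TextRecognizerPARSeqSmallConfig()
print(cfg.encoder.depth, cfg.decoder.depth)  # 9 1
print(cfg.data.img_size)                     # [32, 800]

# Each section can be overridden independently without touching the other defaults.
cfg_small_batch = TextRecognizerPARSeqSmallConfig(data=Data(batch_size=16))
print(cfg_small_batch.data.batch_size)       # 16
```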
yomitoku/layout_parser.py CHANGED
@@ -1,11 +1,16 @@
 from typing import List, Union
 
 import cv2
+import os
+import onnx
+import onnxruntime
 import torch
 import torchvision.transforms as T
 from PIL import Image
 from pydantic import conlist
 
+from .constants import ROOT_DIR
+
 from .base import BaseModelCatalog, BaseModule, BaseSchema
 from .configs import LayoutParserRTDETRv2Config
 from .models import RTDETRv2
@@ -91,6 +96,7 @@ class LayoutParser(BaseModule):
         device="cuda",
         visualize=False,
         from_pretrained=True,
+        infer_onnx=False,
     ):
         super().__init__()
         self.load_model(model_name, path_cfg, from_pretrained)
@@ -119,11 +125,44 @@ class LayoutParser(BaseModule):
         }
 
         self.role = self._cfg.role
+        self.infer_onnx = infer_onnx
+        if infer_onnx:
+            name = self._cfg.hf_hub_repo.split("/")[-1]
+            path_onnx = f"{ROOT_DIR}/onnx/{name}.onnx"
+            if not os.path.exists(path_onnx):
+                self.convert_onnx(path_onnx)
+
+            model = onnx.load(path_onnx)
+            if torch.cuda.is_available() and device == "cuda":
+                self.sess = onnxruntime.InferenceSession(
+                    model.SerializeToString(), providers=["CUDAExecutionProvider"]
+                )
+            else:
+                self.sess = onnxruntime.InferenceSession(model.SerializeToString())
+
+    def convert_onnx(self, path_onnx):
+        dynamic_axes = {
+            "input": {0: "batch_size"},
+            "output": {0: "batch_size"},
+        }
+
+        img_size = self._cfg.data.img_size
+        dummy_input = torch.randn(1, 3, *img_size, requires_grad=True)
+
+        torch.onnx.export(
+            self.model,
+            dummy_input,
+            path_onnx,
+            opset_version=16,
+            input_names=["input"],
+            output_names=["pred_logits", "pred_boxes"],
+            dynamic_axes=dynamic_axes,
+        )
 
     def preprocess(self, img):
         cv_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
         img = Image.fromarray(cv_img)
-        img_tensor = self.transforms(img)[None].to(self.device)
+        img_tensor = self.transforms(img)[None]
         return img_tensor
 
     def postprocess(self, preds, image_size):
@@ -175,8 +214,19 @@ class LayoutParser(BaseModule):
         ori_h, ori_w = img.shape[:2]
         img_tensor = self.preprocess(img)
 
-        with torch.inference_mode():
-            preds = self.model(img_tensor)
+        if self.infer_onnx:
+            input = img_tensor.numpy()
+            results = self.sess.run(None, {"input": input})
+            preds = {
+                "pred_logits": torch.tensor(results[0]).to(self.device),
+                "pred_boxes": torch.tensor(results[1]).to(self.device),
+            }
+
+        else:
+            with torch.inference_mode():
+                img_tensor = img_tensor.to(self.device)
+                preds = self.model(img_tensor)
+
         results = self.postprocess(preds, (ori_h, ori_w))
 
         vis = None
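`LayoutParser`, and below it `TableStructureRecognizer`, `TextDetector`, and `TextRecognizer`, all gain the same mechanism: export the loaded PyTorch weights to `{ROOT_DIR}/onnx/<hub-repo-name>.onnx` the first time `infer_onnx=True` is requested, then run inference through onnxruntime, preferring `CUDAExecutionProvider` when CUDA is available. A self-contained sketch of that export-once-then-reuse pattern with a toy model (the model, path, and shapes are illustrative, not yomitoku's):

```python
import os

import numpy as np
import onnxruntime
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Stand-in for the real detector/recognizer."""

    def forward(self, x):
        return x.mean(dim=(2, 3))


path_onnx = "/tmp/tiny_net.onnx"
model = TinyNet().eval()

if not os.path.exists(path_onnx):
    # Export once with a dynamic batch axis, mirroring convert_onnx() above.
    torch.onnx.export(
        model,
        torch.randn(1, 3, 32, 32),
        path_onnx,
        opset_version=16,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    )

providers = (
    ["CUDAExecutionProvider"] if torch.cuda.is_available() else ["CPUExecutionProvider"]
)
sess = onnxruntime.InferenceSession(path_onnx, providers=providers)

batch = np.random.randn(4, 3, 32, 32).astype(np.float32)
(out,) = sess.run(["output"], {"input": batch})
print(out.shape)  # (4, 3)
```

The modules above load the exported graph with `onnx.load` and hand `model.SerializeToString()` to the session rather than the file path; onnxruntime accepts either form.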
@@ -1,3 +1,16 @@
+# Copyright(c) 2023 lyuwenyu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import torch.nn as nn
 
 
@@ -1,4 +1,16 @@
-"""Copyright(c) 2023 lyuwenyu. All Rights Reserved."""
+# Copyright 2023 lyuwenyu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 from collections import OrderedDict
 
@@ -1,4 +1,16 @@
-"""Copyright(c) 2023 lyuwenyu. All Rights Reserved."""
+# Copyright 2023 lyuwenyu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 import copy
 from collections import OrderedDict
@@ -1,4 +1,17 @@
-"""Copyright(c) 2023 lyuwenyu. All Rights Reserved."""
+# Scene Text Recognition Model Hub
+# Copyright 2023 lyuwenyu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 import copy
 import functools
yomitoku/models/parseq.py CHANGED
@@ -22,7 +22,6 @@ from huggingface_hub import PyTorchModelHubMixin
 from timm.models.helpers import named_apply
 from torch import Tensor
 
-from ..postprocessor import ParseqTokenizer as Tokenizer
 from .layers.parseq_transformer import Decoder, Encoder, TokenEmbedding
 
 
@@ -123,7 +122,6 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
 
     def forward(
         self,
-        tokenizer: Tokenizer,
         images: Tensor,
         max_length: Optional[int] = None,
     ) -> Tensor:
@@ -150,11 +148,11 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
         if self.decode_ar:
             tgt_in = torch.full(
                 (bs, num_steps),
-                tokenizer.pad_id,
+                self.tokenizer.pad_id,
                 dtype=torch.long,
                 device=self._device,
             )
-            tgt_in[:, 0] = tokenizer.bos_id
+            tgt_in[:, 0] = self.tokenizer.bos_id
 
             logits = []
             for i in range(num_steps):
@@ -177,7 +175,7 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
                     # greedy decode. add the next token index to the target input
                     tgt_in[:, j] = p_i.squeeze().argmax(-1)
                     # Efficient batch decoding: If all output words have at least one EOS token, end decoding.
-                    if testing and (tgt_in == tokenizer.eos_id).any(dim=-1).all():
+                    if testing and (tgt_in == self.tokenizer.eos_id).any(dim=-1).all():
                         break
 
             logits = torch.cat(logits, dim=1)
@@ -185,7 +183,7 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
             # No prior context, so input is just <bos>. We query all positions.
             tgt_in = torch.full(
                 (bs, 1),
-                tokenizer.bos_id,
+                self.tokenizer.bos_id,
                 dtype=torch.long,
                 device=self._device,
             )
@@ -200,7 +198,7 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
                 torch.ones(
                     num_steps,
                     num_steps,
-                    dtype=torch.bool,
+                    dtype=torch.int64,
                     device=self._device,
                 ),
                 2,
@@ -208,7 +206,7 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
             ] = 0
             bos = torch.full(
                 (bs, 1),
-                tokenizer.bos_id,
+                self.tokenizer.bos_id,
                 dtype=torch.long,
                 device=self._device,
             )
@@ -216,7 +214,9 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
             # Prior context is the previous output.
             tgt_in = torch.cat([bos, logits[:, :-1].argmax(-1)], dim=1)
             # Mask tokens beyond the first EOS token.
-            tgt_padding_mask = (tgt_in == tokenizer.eos_id).int().cumsum(-1) > 0
+            tgt_padding_mask = (tgt_in == self.tokenizer.eos_id).int().cumsum(
+                -1
+            ) > 0
             tgt_out = self.decode(
                 tgt_in,
                 memory,
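Dropping the `tokenizer` argument from `forward` and reading `self.tokenizer.*` instead fits the ONNX work elsewhere in this release: `torch.onnx.export` traces `forward` with tensor inputs only, so non-tensor collaborators have to hang off the module (the `TextRecognizer` hunks below attach it with `self.model.tokenizer = self.tokenizer`). A minimal sketch of the idea, with illustrative names rather than yomitoku's actual classes:

```python
import torch
import torch.nn as nn


class DummyTokenizer:
    pad_id, bos_id, eos_id = 0, 1, 2


class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.tokenizer = None  # attached after construction, as TextRecognizer does
        self.proj = nn.Linear(8, 8)

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        logits = self.proj(images)
        # Special-token ids come from self.tokenizer, not from a forward() argument,
        # so the traced graph only ever sees tensor inputs.
        return logits + float(self.tokenizer.bos_id)


m = Demo().eval()
m.tokenizer = DummyTokenizer()
torch.onnx.export(
    m,
    torch.randn(2, 8),
    "/tmp/demo.onnx",
    input_names=["input"],
    output_names=["output"],
)
```

The `dtype=torch.bool` to `torch.int64` switch in the attention-mask construction above looks like it serves the same export path, though the diff itself does not say so.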
yomitoku/onnx/.gitkeep ADDED
File without changes
@@ -1,4 +1,17 @@
-"""Copyright(c) 2023 lyuwenyu. All Rights Reserved."""
+# Copyright 2023 lyuwenyu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 
 import torch
 import torch.nn as nn
yomitoku/table_structure_recognizer.py CHANGED
@@ -1,11 +1,16 @@
 from typing import List, Union
 
 import cv2
+import os
+import onnx
+import onnxruntime
 import torch
 import torchvision.transforms as T
 from PIL import Image
 from pydantic import conlist
 
+from .constants import ROOT_DIR
+
 from .base import BaseModelCatalog, BaseModule, BaseSchema
 from .configs import TableStructureRecognizerRTDETRv2Config
 from .layout_parser import filter_contained_rectangles_within_category
@@ -109,12 +114,13 @@ class TableStructureRecognizer(BaseModule):
         device="cuda",
         visualize=False,
         from_pretrained=True,
+        infer_onnx=False,
     ):
         super().__init__()
         self.load_model(
             model_name,
             path_cfg,
-            from_pretrained=True,
+            from_pretrained=from_pretrained,
         )
         self.device = device
         self.visualize = visualize
@@ -127,6 +133,8 @@ class TableStructureRecognizer(BaseModule):
             num_top_queries=self._cfg.RTDETRTransformerv2.num_queries,
         )
 
+        self.save_config("table_structure_recognitizer.yaml")
+
         self.transforms = T.Compose(
             [
                 T.Resize(self._cfg.data.img_size),
@@ -140,6 +148,40 @@ class TableStructureRecognizer(BaseModule):
             id: category for id, category in enumerate(self._cfg.category)
         }
 
+        self.infer_onnx = infer_onnx
+        if infer_onnx:
+            name = self._cfg.hf_hub_repo.split("/")[-1]
+            path_onnx = f"{ROOT_DIR}/onnx/{name}.onnx"
+            if not os.path.exists(path_onnx):
+                self.convert_onnx(path_onnx)
+
+            model = onnx.load(path_onnx)
+            if torch.cuda.is_available() and device == "cuda":
+                self.sess = onnxruntime.InferenceSession(
+                    model.SerializeToString(), providers=["CUDAExecutionProvider"]
+                )
+            else:
+                self.sess = onnxruntime.InferenceSession(model.SerializeToString())
+
+    def convert_onnx(self, path_onnx):
+        dynamic_axes = {
+            "input": {0: "batch_size"},
+            "output": {0: "batch_size"},
+        }
+
+        img_size = self._cfg.data.img_size
+        dummy_input = torch.randn(1, 3, *img_size, requires_grad=True)
+
+        torch.onnx.export(
+            self.model,
+            dummy_input,
+            path_onnx,
+            opset_version=16,
+            input_names=["input"],
+            output_names=["pred_logits", "pred_boxes"],
+            dynamic_axes=dynamic_axes,
+        )
+
     def preprocess(self, img, boxes):
         cv_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
 
@@ -149,7 +191,7 @@ class TableStructureRecognizer(BaseModule):
             table_img = cv_img[y1:y2, x1:x2, :]
             th, hw = table_img.shape[:2]
             table_img = Image.fromarray(table_img)
-            img_tensor = self.transforms(table_img)[None].to(self.device)
+            img_tensor = self.transforms(table_img)[None]
             table_imgs.append(
                 {
                     "tensor": img_tensor,
@@ -226,8 +268,19 @@ class TableStructureRecognizer(BaseModule):
         img_tensors = self.preprocess(img, table_boxes)
         outputs = []
         for data in img_tensors:
-            with torch.inference_mode():
-                pred = self.model(data["tensor"])
+            if self.infer_onnx:
+                input = data["tensor"].numpy()
+                results = self.sess.run(None, {"input": input})
+                pred = {
+                    "pred_logits": torch.tensor(results[0]).to(self.device),
+                    "pred_boxes": torch.tensor(results[1]).to(self.device),
+                }
+
+            else:
+                with torch.inference_mode():
+                    data["tensor"] = data["tensor"].to(self.device)
+                    pred = self.model(data["tensor"])
+
             table = self.postprocess(pred, data)
             outputs.append(table)
 
yomitoku/text_detector.py CHANGED
@@ -2,6 +2,7 @@ from typing import List
 
 import numpy as np
 import torch
+import os
 from pydantic import conlist
 
 from .base import BaseModelCatalog, BaseModule, BaseSchema
@@ -14,6 +15,10 @@ from .data.functions import (
 from .models import DBNet
 from .postprocessor import DBnetPostProcessor
 from .utils.visualizer import det_visualizer
+from .constants import ROOT_DIR
+
+import onnx
+import onnxruntime
 
 
 class TextDetectorModelCatalog(BaseModelCatalog):
@@ -43,12 +48,13 @@ class TextDetector(BaseModule):
         device="cuda",
         visualize=False,
         from_pretrained=True,
+        infer_onnx=False,
     ):
         super().__init__()
         self.load_model(
             model_name,
             path_cfg,
-            from_pretrained=True,
+            from_pretrained=from_pretrained,
         )
 
         self.device = device
@@ -58,6 +64,39 @@ class TextDetector(BaseModule):
         self.model.to(self.device)
 
         self.post_processor = DBnetPostProcessor(**self._cfg.post_process)
+        self.infer_onnx = infer_onnx
+
+        if infer_onnx:
+            name = self._cfg.hf_hub_repo.split("/")[-1]
+            path_onnx = f"{ROOT_DIR}/onnx/{name}.onnx"
+            if not os.path.exists(path_onnx):
+                self.convert_onnx(path_onnx)
+
+            model = onnx.load(path_onnx)
+            if torch.cuda.is_available() and device == "cuda":
+                self.sess = onnxruntime.InferenceSession(
+                    model.SerializeToString(), providers=["CUDAExecutionProvider"]
+                )
+            else:
+                self.sess = onnxruntime.InferenceSession(model.SerializeToString())
+
+    def convert_onnx(self, path_onnx):
+        dynamic_axes = {
+            "input": {0: "batch_size", 2: "height", 3: "width"},
+            "output": {0: "batch_size", 2: "height", 3: "width"},
+        }
+
+        dummy_input = torch.randn(1, 3, 256, 256, requires_grad=True)
+
+        torch.onnx.export(
+            self.model,
+            dummy_input,
+            path_onnx,
+            opset_version=14,
+            input_names=["input"],
+            output_names=["output"],
+            dynamic_axes=dynamic_axes,
+        )
 
     def preprocess(self, img):
         img = img.copy()
@@ -81,9 +120,15 @@ class TextDetector(BaseModule):
 
         ori_h, ori_w = img.shape[:2]
         tensor = self.preprocess(img)
-        tensor = tensor.to(self.device)
-        with torch.inference_mode():
-            preds = self.model(tensor)
+
+        if self.infer_onnx:
+            input = tensor.numpy()
+            results = self.sess.run(["output"], {"input": input})
+            preds = {"binary": torch.tensor(results[0])}
+        else:
+            with torch.inference_mode():
+                tensor = tensor.to(self.device)
+                preds = self.model(tensor)
 
         quads, scores = self.postprocess(preds, (ori_h, ori_w))
         outputs = {"points": quads, "scores": scores}
yomitoku/text_recognizer.py CHANGED
@@ -2,22 +2,28 @@ from typing import List
 
 import numpy as np
 import torch
+import os
 import unicodedata
 from pydantic import conlist
 
 from .base import BaseModelCatalog, BaseModule, BaseSchema
-from .configs import TextRecognizerPARSeqConfig
+from .configs import TextRecognizerPARSeqConfig, TextRecognizerPARSeqSmallConfig
 from .data.dataset import ParseqDataset
 from .models import PARSeq
 from .postprocessor import ParseqTokenizer as Tokenizer
 from .utils.misc import load_charset
 from .utils.visualizer import rec_visualizer
 
+from .constants import ROOT_DIR
+import onnx
+import onnxruntime
+
 
 class TextRecognizerModelCatalog(BaseModelCatalog):
     def __init__(self):
         super().__init__()
         self.register("parseq", TextRecognizerPARSeqConfig, PARSeq)
+        self.register("parseq-small", TextRecognizerPARSeqSmallConfig, PARSeq)
 
 
 class TextRecognizerSchema(BaseSchema):
@@ -43,23 +49,41 @@ class TextRecognizer(BaseModule):
         device="cuda",
         visualize=False,
         from_pretrained=True,
+        infer_onnx=False,
     ):
         super().__init__()
         self.load_model(
             model_name,
             path_cfg,
-            from_pretrained=True,
+            from_pretrained=from_pretrained,
        )
         self.charset = load_charset(self._cfg.charset)
         self.tokenizer = Tokenizer(self.charset)
 
         self.device = device
 
+        self.model.tokenizer = self.tokenizer
         self.model.eval()
         self.model.to(self.device)
 
         self.visualize = visualize
 
+        self.infer_onnx = infer_onnx
+
+        if infer_onnx:
+            name = self._cfg.hf_hub_repo.split("/")[-1]
+            path_onnx = f"{ROOT_DIR}/onnx/{name}.onnx"
+            if not os.path.exists(path_onnx):
+                self.convert_onnx(path_onnx)
+
+            model = onnx.load(path_onnx)
+            if torch.cuda.is_available() and device == "cuda":
+                self.sess = onnxruntime.InferenceSession(
+                    model.SerializeToString(), providers=["CUDAExecutionProvider"]
+                )
+            else:
+                self.sess = onnxruntime.InferenceSession(model.SerializeToString())
+
     def preprocess(self, img, polygons):
         dataset = ParseqDataset(self._cfg, img, polygons)
         dataloader = torch.utils.data.DataLoader(
@@ -71,6 +95,25 @@ class TextRecognizer(BaseModule):
 
         return dataloader
 
+    def convert_onnx(self, path_onnx):
+        img_size = self._cfg.data.img_size
+        input = torch.randn(1, 3, *img_size, requires_grad=True)
+        dynamic_axes = {
+            "input": {0: "batch_size"},
+            "output": {0: "batch_size"},
+        }
+
+        torch.onnx.export(
+            self.model,
+            input,
+            path_onnx,
+            opset_version=14,
+            input_names=["input"],
+            output_names=["output"],
+            do_constant_folding=True,
+            dynamic_axes=dynamic_axes,
+        )
+
     def postprocess(self, p, points):
         pred, score = self.tokenizer.decode(p)
         pred = [unicodedata.normalize("NFKC", x) for x in pred]
@@ -101,13 +144,19 @@ class TextRecognizer(BaseModule):
         scores = []
         directions = []
         for data in dataloader:
-            data = data.to(self.device)
-            with torch.inference_mode():
-                p = self.model(self.tokenizer, data).softmax(-1)
-            pred, score, direction = self.postprocess(p, points)
-            preds.extend(pred)
-            scores.extend(score)
-            directions.extend(direction)
+            if self.infer_onnx:
+                input = data.numpy()
+                results = self.sess.run(["output"], {"input": input})
+                p = torch.tensor(results[0])
+            else:
+                with torch.inference_mode():
+                    data = data.to(self.device)
+                    p = self.model(data).softmax(-1)
+
+            pred, score, direction = self.postprocess(p, points)
+            preds.extend(pred)
+            scores.extend(score)
+            directions.extend(direction)
 
         outputs = {
             "contents": preds,
@@ -1,14 +1,17 @@
 Metadata-Version: 2.3
 Name: yomitoku
-Version: 0.5.2
+Version: 0.6.0
 Summary: Yomitoku is an AI-powered document image analysis package designed specifically for the Japanese language.
 Author-email: Kotaro Kinoshita <kotaro.kinoshita@mlism.com>
 License: CC BY-NC-SA 4.0
 Keywords: Deep Learning,Japanese,OCR
-Requires-Python: <3.13,>=3.9
+Requires-Python: <3.13,>=3.10
 Requires-Dist: huggingface-hub>=0.26.1
 Requires-Dist: lxml>=5.3.0
 Requires-Dist: omegaconf>=2.3.0
+Requires-Dist: onnx>=1.17.0
+Requires-Dist: onnxruntime-gpu>=1.20.1
+Requires-Dist: onnxruntime>=1.20.1
 Requires-Dist: opencv-python>=4.10.0.84
 Requires-Dist: pyclipper>=1.3.0.post6
 Requires-Dist: pydantic>=2.9.2
@@ -23,7 +26,7 @@ Description-Content-Type: text/markdown
 
 <img src="static/logo/horizontal.png" width="800px">
 
-![Python](https://img.shields.io/badge/Python-3.9|3.10|3.11|3.12-F9DC3E.svg?logo=python&logoColor=&style=flat)
+![Python](https://img.shields.io/badge/Python-3.10|3.11|3.12-F9DC3E.svg?logo=python&logoColor=&style=flat)
 ![Pytorch](https://img.shields.io/badge/Pytorch-2.5-EE4C2C.svg?logo=Pytorch&style=fla)
 ![CUDA](https://img.shields.io/badge/CUDA->=11.8-76B900.svg?logo=NVIDIA&style=fla)
 ![OS](https://img.shields.io/badge/OS-Linux|Mac|Win-1793D1.svg?&style=fla)
@@ -69,19 +72,20 @@ Markdown でエクスポートした結果は関してはリポジトリ内の[s
 pip install yomitoku
 ```
 
-- pytorch はご自身の CUDAのバージョンにあったものをインストールしてください。デフォルトではCUDA12.4以上に対応したものがインストールされます。
-- pytorch は2.5以上のバージョンに対応しています。その関係でCUDA11.8以上のバージョンが必要になります。対応できない場合は、リポジトリ内のDockerfileを利用してください。
+- pytorch はご自身の CUDA のバージョンにあったものをインストールしてください。デフォルトでは CUDA12.4 以上に対応したものがインストールされます。
+- pytorch は 2.5 以上のバージョンに対応しています。その関係で CUDA11.8 以上のバージョンが必要になります。対応できない場合は、リポジトリ内の Dockerfile を利用してください。
 
 ## 🚀 実行方法
 
 ```
-yomitoku ${path_data} -f md -o results -v --figure
+yomitoku ${path_data} -f md -o results -v --figure --lite
 ```
 
 - `${path_data}` 解析対象の画像が含まれたディレクトリか画像ファイルのパスを直接して指定してください。ディレクトリを対象とした場合はディレクトリのサブディレクトリ内の画像も含めて処理を実行します。
 - `-f`, `--format` 出力形式のファイルフォーマットを指定します。(json, csv, html, md をサポート)
 - `-o`, `--outdir` 出力先のディレクトリ名を指定します。存在しない場合は新規で作成されます。
 - `-v`, `--vis` を指定すると解析結果を可視化した画像を出力します。
+- `-l`, `--lite` runs inference with the lightweight models. It is faster than the default models, but accuracy may drop slightly.
 - `-d`, `--device` モデルを実行するためのデバイスを指定します。gpu が利用できない場合は cpu で推論が実行されます。(デフォルト: cuda)
 - `--ignore_line_break` 画像の改行位置を無視して、段落内の文章を連結して返します。(デフォルト:画像通りの改行位置位置で改行します。)
 - `--figure_letter` 検出した図表に含まれる文字も出力ファイルにエクスポートします。
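For example, running the lightweight models on CPU (the input file name is illustrative):

```
yomitoku image.jpg -f md -o results --lite -d cpu
```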
@@ -94,6 +98,7 @@ yomitoku --help
 ```
 
 **NOTE**
+
 - GPU での実行を推奨します。CPU を用いての推論向けに最適化されておらず、処理時間が長くなります。
 - 活字のみ識別をサポートしております。手書き文字に関しては、読み取れる場合もありますが、公式にはサポートしておりません。
 - Yomitoku は文書 OCR 向けに最適化されており、情景 OCR(看板など紙以外にプリントされた文字の読み取り)向けには最適化されていません。
@@ -107,6 +112,6 @@
 
 本リポジトリ内に格納されているソースコードおよび本プロジェクトに関連する HuggingFaceHub 上のモデルの重みファイルのライセンスは CC BY-NC-SA 4.0 に従います。
 非商用での個人利用、研究目的での利用はご自由にお使いください。
-商用目的での利用に関しては、別途、商用ライセンスを提供しますので、開発者にお問い合わせください。
+商用目的での利用に関しては、別途、商用ライセンスを提供しますので、https://www.mlism.com/ にお問い合わせください。
 
 YomiToku © 2024 by Kotaro Kinoshita is licensed under CC BY-NC-SA 4.0. To view a copy of this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
@@ -3,19 +3,20 @@ yomitoku/base.py,sha256=lzR_V8t87aRasmFdFwD-8KAeSahSTI3AZaEn6g8sOv8,3871
 yomitoku/constants.py,sha256=zlW5QRc_u_F3C2RAgBFWyHJZexBnJT5N15GC-9d3iLo,686
 yomitoku/document_analyzer.py,sha256=HIg-nVzDhJIP-h-tn4uU86KakgHdlAhosEqK_i-SWe4,9906
 yomitoku/layout_analyzer.py,sha256=QTeRcVd8aySz8u6dg2ikET77ar3sqlukRLBwYfTyMPM,2033
-yomitoku/layout_parser.py,sha256=V2jCNHE61jNp8ytYdKwPV34V5qEK7y-7-Mq7-AkoQhU,5898
+yomitoku/layout_parser.py,sha256=Yni1C_7j4fzHcdmBNNGRZPc23W_6J6HwPPQVjYvaztM,7539
 yomitoku/ocr.py,sha256=Rcojw0aGA6yDF2RjqfK23_rMw-xm61KGd8JmTCTOOVU,2516
 yomitoku/reading_order.py,sha256=OfhOS9ttPDoPSuHrIRKyOzG19GGeRufbuSKDqhsohh4,6404
-yomitoku/table_structure_recognizer.py,sha256=CouRzfdO_toZKUQbzQqocKdMcgA3Pr7glkZuqD5itpg,7280
-yomitoku/text_detector.py,sha256=okp0xuq4lXgEDcfgCzeJcrj8hfSI4NvAgorsNwi_NYI,2682
-yomitoku/text_recognizer.py,sha256=RHdq1M3-e3C1RECgbaoqPngtxicG3izAma12juD2ICQ,3789
+yomitoku/table_structure_recognizer.py,sha256=Wf_Ehmf6V27iVLmw2o9i7kJnbwEOhuExI-ljIO3a8NE,9043
+yomitoku/text_detector.py,sha256=fbwKelsVfwCt5YL4h-WEf4qkniv5cXmyaLR6oSYz0eA,4167
+yomitoku/text_recognizer.py,sha256=Iu-IzwaziNjmrTeSw9aoN9BDTHkNOzsZhViCv45yiN8,5422
 yomitoku/cli/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-yomitoku/cli/main.py,sha256=MBD0S4sXgquJ8P2egkZjJcglXvCke5Uw46C28SDtr8g,6252
-yomitoku/configs/__init__.py,sha256=KBhb9S7xt22HZaIcoWSgZHfscXXj9YlimOwLH5z9CRo,454
+yomitoku/cli/main.py,sha256=qDB_YNK7abstIr9tYLiJjNU3xLSCd5x1UNDKqwUi2Rk,6885
+yomitoku/configs/__init__.py,sha256=e1Alss5QJLZSNfD6zLEG6xu5vDQDw-4Jayiqq8bq52s,571
 yomitoku/configs/cfg_layout_parser_rtdtrv2.py,sha256=8PRxB2Ar9UF7-DLtbgSokhrzdXb0veWI6Wc-X8qigRw,2329
 yomitoku/configs/cfg_table_structure_recognizer_rtdtrv2.py,sha256=o70GMHD8k-zeBeJtuhPS8x7vVB-ffucnJXeSyn-0AXo,2116
 yomitoku/configs/cfg_text_detector_dbnet.py,sha256=U9k48PON7haoOaytiELhbZRpv9RMiUm6nnfHmdxIa9Q,1153
 yomitoku/configs/cfg_text_recognizer_parseq.py,sha256=hpFs3nKqh4XdU3BZMTultegtLEGahEsCaZdjfKC_MO8,1247
+yomitoku/configs/cfg_text_recognizer_parseq_small.py,sha256=uCm_VC_G79IbZpOiK8fgYzAJ4b98H5pf328wyQomtfo,1259
 yomitoku/data/__init__.py,sha256=KAofFc9rk9ZdTKBjemu9RM8Vj9XnKbWC2MPZ2RWtOdE,82
 yomitoku/data/dataset.py,sha256=-I4f-FDtgsPnJ2MnXB7FtwihMW3koDaSI1OEoqKneIg,1014
 yomitoku/data/functions.py,sha256=eOyxo8S6EoAf1xGSPLWQFb9-t5Rg52NggD9MFIrOSpY,7506
@@ -26,19 +27,20 @@ yomitoku/export/export_json.py,sha256=1ChvCAHfCmMQvCfcAb1p3fSpr4elNAs3xBSIbpfn3b
 yomitoku/export/export_markdown.py,sha256=mCcsXUWBLrYc1NcRSBFfBT28d6eCddAF1oHp0qdBEnE,3986
 yomitoku/models/__init__.py,sha256=Enxq9sjJWusZuxecTori8IQa8NEYKaiiptDluHX1avg,144
 yomitoku/models/dbnet_plus.py,sha256=jeWJZm0ihbxoJeAXBFK7uVIwoosx2IUNk7Ut5wRH0vA,7998
-yomitoku/models/parseq.py,sha256=7QT-q5_oWqXTDXobRk1R6Lpap_AxdC4AzkSsOgXjOwM,8611
+yomitoku/models/parseq.py,sha256=-DQMQuON2jwtb4Ib2V0O19un9w-WG4rXS0SiscydrXU,8593
 yomitoku/models/rtdetr.py,sha256=oJsr8RHz3frslhLfXdVJve47lUsrmqLjfdTrZ41tlQ0,687
 yomitoku/models/layers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-yomitoku/models/layers/activate.py,sha256=HUw0q-76RNjZF-o9O3fowfJcw0t1H5o0pbyioGdqUvU,668
+yomitoku/models/layers/activate.py,sha256=S54GPssZBMloM2oFAXeDVMmBBZOWyjwU98Niq758txE,1244
 yomitoku/models/layers/dbnet_feature_attention.py,sha256=Vpp_PiLVuI7Zs30TTg4RNRn16KTb81ewonADpUHd4aE,6060
 yomitoku/models/layers/parseq_transformer.py,sha256=33eroJf8rmgIptP-NpZLJMhG7XOTwV4rXsq674VrKnU,6704
-yomitoku/models/layers/rtdetr_backbone.py,sha256=QjfLW-3qn2My3Jbg6yLORX8A-D2sph9J9u3r5nNnDLo,9386
-yomitoku/models/layers/rtdetr_hybrid_encoder.py,sha256=D3dK37k7_0jPqV39-6Se8kBzF_SyZttNlbLleyNFiJU,13607
-yomitoku/models/layers/rtdetrv2_decoder.py,sha256=5bVYPLFYCy3PcjyHTPFHNLWqg3bctrk-dKVG4kayhaw,27517
+yomitoku/models/layers/rtdetr_backbone.py,sha256=VOWFW7XFfJl4cvPaupqqP4-I-YHdwlVltQEgliD69As,9904
+yomitoku/models/layers/rtdetr_hybrid_encoder.py,sha256=ZnpEzJLzHgu_hrx7YK6myXZ4F1CDHRM501RbAPQdzdQ,14125
+yomitoku/models/layers/rtdetrv2_decoder.py,sha256=ggUwTdWpBfyYHnZuLx8vyH8n0XfZkQFtxgpY-1YI2sI,28070
+yomitoku/onnx/.gitkeep,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 yomitoku/postprocessor/__init__.py,sha256=W4vUuqBaFtH5dlSBIYgyaCroGLMjpV6RrNGIBQ8NFVw,243
 yomitoku/postprocessor/dbnet_postporcessor.py,sha256=o_y8b5REd2dFEdIpRcr6o-XBfOCHo9rBYGwokP_uhTc,4948
 yomitoku/postprocessor/parseq_tokenizer.py,sha256=e89_g_bc4Au3SchuxoJfJNATJTxFmVYetzXyAzPWm28,4315
-yomitoku/postprocessor/rtdetr_postprocessor.py,sha256=f52wfRKrxqSXy_LeidKDR9XAta_qPjto-oYEdO0XL8A,3386
+yomitoku/postprocessor/rtdetr_postprocessor.py,sha256=TCv1t1zCxg2rSirsLm4sXlaltGubH-roVdEqnUoRs-8,3905
 yomitoku/resource/MPLUS1p-Medium.ttf,sha256=KLL1KkCumIBkgQtx1n4SffdaFuCNffThktEAbkB1OU8,1758908
 yomitoku/resource/charset.txt,sha256=sU91kSi-9Wk4733bCXy4j_UDmvcsj96sHOq1ppUJlOY,21672
 yomitoku/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -46,7 +48,7 @@ yomitoku/utils/graph.py,sha256=LKNB8ZhSQwOZMfeAimPMF5UCVVr2ZaUWoGDkz8z-uGU,456
 yomitoku/utils/logger.py,sha256=uOmtQDr0A0JD7wyFshedL08BiNrQorHnpktRXba8bjU,424
 yomitoku/utils/misc.py,sha256=2Eyy7-9K_h4Mal1zGXq6OlxubfNzhS0mEYwn_xt7xl8,2497
 yomitoku/utils/visualizer.py,sha256=2pSmbhUPylzVVJ0bXtGDoNmMdArAByab4Py7Xavvs_A,5230
-yomitoku-0.5.2.dist-info/METADATA,sha256=qG0aq8sJb6iD-i0WvZL__YclRytpBdzyPzu6HNqtgIM,7819
-yomitoku-0.5.2.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
-yomitoku-0.5.2.dist-info/entry_points.txt,sha256=nFV3S11zgBNW0Qq_D0XQNg2R4lNXU_9XUFr6rdJoyF8,52
-yomitoku-0.5.2.dist-info/RECORD,,
+yomitoku-0.6.0.dist-info/METADATA,sha256=XDmMBtDx9MjXPuzcARwOwJXRN7PMCsQDwc38jDSwX5g,8134
+yomitoku-0.6.0.dist-info/WHEEL,sha256=C2FUgwZgiLbznR-k0b_5k3Ai_1aASOXDss3lzCUsUug,87
+yomitoku-0.6.0.dist-info/entry_points.txt,sha256=nFV3S11zgBNW0Qq_D0XQNg2R4lNXU_9XUFr6rdJoyF8,52
+yomitoku-0.6.0.dist-info/RECORD,,
@@ -1,4 +1,4 @@
 Wheel-Version: 1.0
-Generator: hatchling 1.25.0
+Generator: hatchling 1.26.3
 Root-Is-Purelib: true
 Tag: py3-none-any