yomitoku 0.4.1__py3-none-any.whl → 0.7.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. yomitoku/base.py +1 -1
  2. yomitoku/cli/main.py +219 -27
  3. yomitoku/configs/__init__.py +2 -0
  4. yomitoku/configs/cfg_text_detector_dbnet.py +1 -1
  5. yomitoku/configs/cfg_text_recognizer_parseq_small.py +51 -0
  6. yomitoku/data/functions.py +48 -23
  7. yomitoku/document_analyzer.py +243 -41
  8. yomitoku/export/__init__.py +18 -5
  9. yomitoku/export/export_csv.py +71 -2
  10. yomitoku/export/export_html.py +46 -12
  11. yomitoku/export/export_json.py +66 -3
  12. yomitoku/export/export_markdown.py +42 -6
  13. yomitoku/layout_analyzer.py +2 -9
  14. yomitoku/layout_parser.py +58 -4
  15. yomitoku/models/dbnet_plus.py +13 -39
  16. yomitoku/models/layers/activate.py +13 -0
  17. yomitoku/models/layers/rtdetr_backbone.py +18 -17
  18. yomitoku/models/layers/rtdetr_hybrid_encoder.py +19 -20
  19. yomitoku/models/layers/rtdetrv2_decoder.py +14 -1
  20. yomitoku/models/parseq.py +15 -22
  21. yomitoku/ocr.py +24 -27
  22. yomitoku/onnx/.gitkeep +0 -0
  23. yomitoku/postprocessor/dbnet_postporcessor.py +15 -14
  24. yomitoku/postprocessor/parseq_tokenizer.py +1 -3
  25. yomitoku/postprocessor/rtdetr_postprocessor.py +14 -1
  26. yomitoku/table_structure_recognizer.py +82 -9
  27. yomitoku/text_detector.py +57 -7
  28. yomitoku/text_recognizer.py +84 -16
  29. yomitoku/utils/misc.py +21 -14
  30. yomitoku/utils/visualizer.py +15 -8
  31. {yomitoku-0.4.1.dist-info → yomitoku-0.7.4.dist-info}/METADATA +34 -41
  32. yomitoku-0.7.4.dist-info/RECORD +54 -0
  33. {yomitoku-0.4.1.dist-info → yomitoku-0.7.4.dist-info}/WHEEL +1 -1
  34. yomitoku-0.4.1.dist-info/RECORD +0 -52
  35. {yomitoku-0.4.1.dist-info → yomitoku-0.7.4.dist-info}/entry_points.txt +0 -0
yomitoku/export/export_markdown.py CHANGED
@@ -1,10 +1,11 @@
+import os
 import re
+
 import cv2
-import os


 def escape_markdown_special_chars(text):
-    special_chars = r"([`*_{}[\]()#+.!|-])"
+    special_chars = r"([`*{}[\]()#+!~|-])"
     return re.sub(special_chars, r"\\\1", text)


@@ -75,6 +76,8 @@ def figure_to_md(
     width=200,
     figure_dir="figures",
 ):
+    assert img is not None, "img is required for saving figures"
+
     elements = []
     for i, figure in enumerate(figures):
         x1, y1, x2, y2 = map(int, figure.box)
@@ -108,11 +111,11 @@ def figure_to_md(
     return elements


-def export_markdown(
+def convert_markdown(
     inputs,
-    out_path: str,
+    out_path,
+    ignore_line_break=False,
     img=None,
-    ignore_line_break: bool = False,
     export_figure_letter=False,
     export_figure=True,
     figure_width=200,
@@ -140,6 +143,39 @@ def export_markdown(

     elements = sorted(elements, key=lambda x: x["order"])
     markdown = "\n".join([element["md"] for element in elements])
+    return markdown, elements

-    with open(out_path, "w", encoding="utf-8") as f:
+
+def export_markdown(
+    inputs,
+    out_path: str,
+    ignore_line_break: bool = False,
+    img=None,
+    export_figure_letter=False,
+    export_figure=True,
+    figure_width=200,
+    figure_dir="figures",
+    encoding: str = "utf-8",
+):
+    markdown, elements = convert_markdown(
+        inputs,
+        out_path,
+        ignore_line_break,
+        img,
+        export_figure_letter,
+        export_figure,
+        figure_width,
+        figure_dir,
+    )
+
+    save_markdown(markdown, out_path, encoding)
+    return markdown
+
+
+def save_markdown(
+    markdown,
+    out_path,
+    encoding,
+):
+    with open(out_path, "w", encoding=encoding, errors="ignore") as f:
         f.write(markdown)
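A minimal usage sketch of the reworked exporter (hypothetical, not taken from the diff): convert_markdown now builds the markdown string and element list without writing the .md file, export_markdown gains an encoding keyword, and save_markdown performs the actual write. The names `results` (an analyzer output) and `img` (the source image) are assumptions.

    from yomitoku.export.export_markdown import convert_markdown, export_markdown

    # Build the markdown in memory; the .md file is not written by this call.
    markdown, elements = convert_markdown(results, "out.md", img=img)

    # One-shot export: convert, then write with the new encoding argument
    # (unencodable characters are dropped because of errors="ignore").
    export_markdown(results, "out.md", img=img, encoding="utf-8")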
yomitoku/layout_analyzer.py CHANGED
@@ -15,7 +15,7 @@ class LayoutAnalyzerSchema(BaseSchema):


 class LayoutAnalyzer:
-    def __init__(self, configs=None, device="cuda", visualize=False):
+    def __init__(self, configs={}, device="cuda", visualize=False):
         layout_parser_kwargs = {
             "device": device,
             "visualize": visualize,
@@ -26,11 +26,6 @@ class LayoutAnalyzer:
         }

         if isinstance(configs, dict):
-            assert (
-                "layout_parser" in configs
-                or "table_structure_recognizer" in configs
-            ), "Invalid config key. Please check the config keys."
-
             if "layout_parser" in configs:
                 layout_parser_kwargs.update(configs["layout_parser"])

@@ -53,9 +48,7 @@ class LayoutAnalyzer:
     def __call__(self, img):
         layout_results, vis = self.layout_parser(img)
         table_boxes = [table.box for table in layout_results.tables]
-        table_results, vis = self.table_structure_recognizer(
-            img, table_boxes, vis=vis
-        )
+        table_results, vis = self.table_structure_recognizer(img, table_boxes, vis=vis)

         results = LayoutAnalyzerSchema(
             paragraphs=layout_results.paragraphs,
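A short sketch (hypothetical) of what the relaxed constructor accepts: with the key assertion removed, a configs dict that carries only one of the two sections, or an empty dict, no longer raises.

    from yomitoku.layout_analyzer import LayoutAnalyzer

    # Only the "layout_parser" section is supplied; in 0.4.1 a configs dict with
    # neither "layout_parser" nor "table_structure_recognizer" raised an
    # AssertionError, while unknown or missing sections are now simply ignored.
    analyzer = LayoutAnalyzer(
        configs={"layout_parser": {"device": "cpu"}},
        device="cpu",
        visualize=False,
    )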
yomitoku/layout_parser.py CHANGED
@@ -1,11 +1,16 @@
 from typing import List, Union

 import cv2
+import os
+import onnx
+import onnxruntime
 import torch
 import torchvision.transforms as T
 from PIL import Image
 from pydantic import conlist

+from .constants import ROOT_DIR
+
 from .base import BaseModelCatalog, BaseModule, BaseSchema
 from .configs import LayoutParserRTDETRv2Config
 from .models import RTDETRv2
@@ -91,6 +96,7 @@ class LayoutParser(BaseModule):
         device="cuda",
         visualize=False,
         from_pretrained=True,
+        infer_onnx=False,
     ):
         super().__init__()
         self.load_model(model_name, path_cfg, from_pretrained)
@@ -98,7 +104,6 @@ class LayoutParser(BaseModule):
         self.visualize = visualize

         self.model.eval()
-        self.model.to(self.device)

         self.postprocessor = RTDETRPostProcessor(
             num_classes=self._cfg.RTDETRTransformerv2.num_classes,
@@ -119,11 +124,49 @@ class LayoutParser(BaseModule):
         }

         self.role = self._cfg.role
+        self.infer_onnx = infer_onnx
+        if infer_onnx:
+            name = self._cfg.hf_hub_repo.split("/")[-1]
+            path_onnx = f"{ROOT_DIR}/onnx/{name}.onnx"
+            if not os.path.exists(path_onnx):
+                self.convert_onnx(path_onnx)
+
+            self.model = None
+
+            model = onnx.load(path_onnx)
+            if torch.cuda.is_available() and device == "cuda":
+                self.sess = onnxruntime.InferenceSession(
+                    model.SerializeToString(), providers=["CUDAExecutionProvider"]
+                )
+            else:
+                self.sess = onnxruntime.InferenceSession(model.SerializeToString())
+
+        if self.model is not None:
+            self.model.to(self.device)
+
+    def convert_onnx(self, path_onnx):
+        dynamic_axes = {
+            "input": {0: "batch_size"},
+            "output": {0: "batch_size"},
+        }
+
+        img_size = self._cfg.data.img_size
+        dummy_input = torch.randn(1, 3, *img_size, requires_grad=True)
+
+        torch.onnx.export(
+            self.model,
+            dummy_input,
+            path_onnx,
+            opset_version=16,
+            input_names=["input"],
+            output_names=["pred_logits", "pred_boxes"],
+            dynamic_axes=dynamic_axes,
+        )

     def preprocess(self, img):
         cv_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
         img = Image.fromarray(cv_img)
-        img_tensor = self.transforms(img)[None].to(self.device)
+        img_tensor = self.transforms(img)[None]
         return img_tensor

     def postprocess(self, preds, image_size):
@@ -175,8 +218,19 @@ class LayoutParser(BaseModule):
         ori_h, ori_w = img.shape[:2]
         img_tensor = self.preprocess(img)

-        with torch.inference_mode():
-            preds = self.model(img_tensor)
+        if self.infer_onnx:
+            input = img_tensor.numpy()
+            results = self.sess.run(None, {"input": input})
+            preds = {
+                "pred_logits": torch.tensor(results[0]).to(self.device),
+                "pred_boxes": torch.tensor(results[1]).to(self.device),
+            }
+
+        else:
+            with torch.inference_mode():
+                img_tensor = img_tensor.to(self.device)
+                preds = self.model(img_tensor)
+
         results = self.postprocess(preds, (ori_h, ori_w))

         vis = None
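A hypothetical sketch of the new ONNX path, based only on the code above: with infer_onnx=True the constructor exports the PyTorch weights once to {ROOT_DIR}/onnx/<repo-name>.onnx via convert_onnx() and then serves predictions through onnxruntime, picking the CUDA execution provider only when CUDA is available and requested. The input image and the call pattern are assumptions.

    import cv2
    from yomitoku.layout_parser import LayoutParser

    # CPU inference through onnxruntime; the .onnx file is created on first use.
    parser = LayoutParser(device="cpu", visualize=False, infer_onnx=True)

    img = cv2.imread("page.png")  # hypothetical input (BGR image as read by OpenCV)
    results, vis = parser(img)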
yomitoku/models/dbnet_plus.py CHANGED
@@ -20,9 +20,7 @@ class BackboneBase(nn.Module):
             "layer4": "layer4",
         }

-        self.body = IntermediateLayerGetter(
-            backbone, return_layers=return_layers
-        )
+        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)

     def forward(self, tensor):
         xs = self.body(tensor)
@@ -57,18 +55,10 @@ class DBNetDecoder(nn.Module):
         self.training = True
         self.input_proj = nn.ModuleDict(
             {
-                "layer1": nn.Conv2d(
-                    in_channels[0], self.d_model, 1, bias=False
-                ),
-                "layer2": nn.Conv2d(
-                    in_channels[1], self.d_model, 1, bias=False
-                ),
-                "layer3": nn.Conv2d(
-                    in_channels[2], self.d_model, 1, bias=False
-                ),
-                "layer4": nn.Conv2d(
-                    in_channels[3], self.d_model, 1, bias=False
-                ),
+                "layer1": nn.Conv2d(in_channels[0], self.d_model, 1, bias=False),
+                "layer2": nn.Conv2d(in_channels[1], self.d_model, 1, bias=False),
+                "layer3": nn.Conv2d(in_channels[2], self.d_model, 1, bias=False),
+                "layer4": nn.Conv2d(in_channels[3], self.d_model, 1, bias=False),
             }
         )

@@ -89,9 +79,7 @@ class DBNetDecoder(nn.Module):
                     padding=1,
                     bias=False,
                 ),
-                nn.Upsample(
-                    scale_factor=2, mode="bilinear", align_corners=False
-                ),
+                nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
             ),
             "layer3": nn.Sequential(
                 nn.Conv2d(
@@ -101,9 +89,7 @@ class DBNetDecoder(nn.Module):
                     padding=1,
                     bias=False,
                 ),
-                nn.Upsample(
-                    scale_factor=4, mode="bilinear", align_corners=False
-                ),
+                nn.Upsample(scale_factor=4, mode="bilinear", align_corners=False),
             ),
             "layer4": nn.Sequential(
                 nn.Conv2d(
@@ -113,17 +99,13 @@ class DBNetDecoder(nn.Module):
                     padding=1,
                     bias=False,
                 ),
-                nn.Upsample(
-                    scale_factor=4, mode="bilinear", align_corners=False
-                ),
+                nn.Upsample(scale_factor=4, mode="bilinear", align_corners=False),
             ),
         }
         )

         self.binarize = nn.Sequential(
-            nn.Conv2d(
-                self.d_model, self.d_model // 4, 3, padding=1, bias=False
-            ),
+            nn.Conv2d(self.d_model, self.d_model // 4, 3, padding=1, bias=False),
             nn.BatchNorm2d(self.d_model // 4),
             nn.ReLU(inplace=True),
             nn.ConvTranspose2d(self.d_model // 4, self.d_model // 4, 2, 2),
@@ -166,16 +148,12 @@ class DBNetDecoder(nn.Module):
                 m.weight.data.fill_(1.0)
                 m.bias.data.fill_(1e-4)

-    def _init_thresh(
-        self, inner_channels, serial=False, smooth=False, bias=False
-    ):
+    def _init_thresh(self, inner_channels, serial=False, smooth=False, bias=False):
         in_channels = inner_channels
         if serial:
             in_channels += 1
         self.thresh = nn.Sequential(
-            nn.Conv2d(
-                in_channels, inner_channels // 4, 3, padding=1, bias=bias
-            ),
+            nn.Conv2d(in_channels, inner_channels // 4, 3, padding=1, bias=bias),
             nn.BatchNorm2d(inner_channels // 4),
             nn.ReLU(inplace=True),
             self._init_upsample(
@@ -186,16 +164,12 @@ class DBNetDecoder(nn.Module):
             ),
             nn.BatchNorm2d(inner_channels // 4),
             nn.ReLU(inplace=True),
-            self._init_upsample(
-                inner_channels // 4, 1, smooth=smooth, bias=bias
-            ),
+            self._init_upsample(inner_channels // 4, 1, smooth=smooth, bias=bias),
             nn.Sigmoid(),
         )
         return self.thresh

-    def _init_upsample(
-        self, in_channels, out_channels, smooth=False, bias=False
-    ):
+    def _init_upsample(self, in_channels, out_channels, smooth=False, bias=False):
         if smooth:
             inter_out_channels = out_channels
             if out_channels == 1:
yomitoku/models/layers/activate.py CHANGED
@@ -1,3 +1,16 @@
+# Copyright(c) 2023 lyuwenyu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import torch.nn as nn


yomitoku/models/layers/rtdetr_backbone.py CHANGED
@@ -1,5 +1,16 @@
-"""Copyright(c) 2023 lyuwenyu. All Rights Reserved.
-"""
+# Copyright 2023 lyuwenyu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

 from collections import OrderedDict

@@ -48,9 +59,7 @@ class ConvNormLayer(nn.Module):
 class BasicBlock(nn.Module):
     expansion = 1

-    def __init__(
-        self, ch_in, ch_out, stride, shortcut, act="relu", variant="b"
-    ):
+    def __init__(self, ch_in, ch_out, stride, shortcut, act="relu", variant="b"):
         super().__init__()

         self.shortcut = shortcut
@@ -89,9 +98,7 @@ class BasicBlock(nn.Module):
 class BottleNeck(nn.Module):
     expansion = 4

-    def __init__(
-        self, ch_in, ch_out, stride, shortcut, act="relu", variant="b"
-    ):
+    def __init__(self, ch_in, ch_out, stride, shortcut, act="relu", variant="b"):
         super().__init__()

         if variant == "a":
@@ -114,17 +121,13 @@ class BottleNeck(nn.Module):
                             ("pool", nn.AvgPool2d(2, 2, 0, ceil_mode=True)),
                             (
                                 "conv",
-                                ConvNormLayer(
-                                    ch_in, ch_out * self.expansion, 1, 1
-                                ),
+                                ConvNormLayer(ch_in, ch_out * self.expansion, 1, 1),
                             ),
                         ]
                     )
                 )
             else:
-                self.short = ConvNormLayer(
-                    ch_in, ch_out * self.expansion, 1, stride
-                )
+                self.short = ConvNormLayer(ch_in, ch_out * self.expansion, 1, stride)

         self.act = nn.Identity() if act is None else get_activation(act)

@@ -145,9 +148,7 @@


 class Blocks(nn.Module):
-    def __init__(
-        self, block, ch_in, ch_out, count, stage_num, act="relu", variant="b"
-    ):
+    def __init__(self, block, ch_in, ch_out, count, stage_num, act="relu", variant="b"):
         super().__init__()

         self.blocks = nn.ModuleList()
yomitoku/models/layers/rtdetr_hybrid_encoder.py CHANGED
@@ -1,5 +1,16 @@
-"""Copyright(c) 2023 lyuwenyu. All Rights Reserved.
-"""
+# Copyright 2023 lyuwenyu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

 import copy
 from collections import OrderedDict
@@ -241,9 +252,7 @@ class HybridEncoder(nn.Module):
         for in_channel in in_channels:
             if version == "v1":
                 proj = nn.Sequential(
-                    nn.Conv2d(
-                        in_channel, hidden_dim, kernel_size=1, bias=False
-                    ),
+                    nn.Conv2d(in_channel, hidden_dim, kernel_size=1, bias=False),
                     nn.BatchNorm2d(hidden_dim),
                 )
             elif version == "v2":
@@ -279,9 +288,7 @@ class HybridEncoder(nn.Module):

         self.encoder = nn.ModuleList(
             [
-                TransformerEncoder(
-                    copy.deepcopy(encoder_layer), num_encoder_layers
-                )
+                TransformerEncoder(copy.deepcopy(encoder_layer), num_encoder_layers)
                 for _ in range(len(use_encoder_idx))
             ]
         )
@@ -336,9 +343,7 @@ class HybridEncoder(nn.Module):
             # self.register_buffer(f'pos_embed{idx}', pos_embed)

     @staticmethod
-    def build_2d_sincos_position_embedding(
-        w, h, embed_dim=256, temperature=10000.0
-    ):
+    def build_2d_sincos_position_embedding(w, h, embed_dim=256, temperature=10000.0):
         """ """
         grid_w = torch.arange(int(w), dtype=torch.float32)
         grid_h = torch.arange(int(h), dtype=torch.float32)
@@ -376,9 +381,7 @@ class HybridEncoder(nn.Module):
                         src_flatten.device
                     )

-                memory: torch.Tensor = self.encoder[i](
-                    src_flatten, pos_embed=pos_embed
-                )
+                memory: torch.Tensor = self.encoder[i](src_flatten, pos_embed=pos_embed)
                 proj_feats[enc_ind] = (
                     memory.permute(0, 2, 1)
                     .reshape(-1, self.hidden_dim, h, w)
@@ -390,13 +393,9 @@ class HybridEncoder(nn.Module):
         for idx in range(len(self.in_channels) - 1, 0, -1):
             feat_heigh = inner_outs[0]
             feat_low = proj_feats[idx - 1]
-            feat_heigh = self.lateral_convs[len(self.in_channels) - 1 - idx](
-                feat_heigh
-            )
+            feat_heigh = self.lateral_convs[len(self.in_channels) - 1 - idx](feat_heigh)
             inner_outs[0] = feat_heigh
-            upsample_feat = F.interpolate(
-                feat_heigh, scale_factor=2.0, mode="nearest"
-            )
+            upsample_feat = F.interpolate(feat_heigh, scale_factor=2.0, mode="nearest")
             inner_out = self.fpn_blocks[len(self.in_channels) - 1 - idx](
                 torch.concat([upsample_feat, feat_low], dim=1)
             )
yomitoku/models/layers/rtdetrv2_decoder.py CHANGED
@@ -1,4 +1,17 @@
-"""Copyright(c) 2023 lyuwenyu. All Rights Reserved."""
+# Scene Text Recognition Model Hub
+# Copyright 2023 lyuwenyu
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

 import copy
 import functools
yomitoku/models/parseq.py CHANGED
@@ -22,13 +22,10 @@ from huggingface_hub import PyTorchModelHubMixin
 from timm.models.helpers import named_apply
 from torch import Tensor

-from ..postprocessor import ParseqTokenizer as Tokenizer
 from .layers.parseq_transformer import Decoder, Encoder, TokenEmbedding


-def init_weights(
-    module: nn.Module, name: str = "", exclude: Sequence[str] = ()
-):
+def init_weights(module: nn.Module, name: str = "", exclude: Sequence[str] = ()):
     """Initialize the weights using the typical initialization schemes used in SOTA models."""
     if any(map(name.startswith, exclude)):
         return
@@ -41,9 +38,7 @@ def init_weights(
         if module.padding_idx is not None:
             module.weight.data[module.padding_idx].zero_()
     elif isinstance(module, nn.Conv2d):
-        nn.init.kaiming_normal_(
-            module.weight, mode="fan_out", nonlinearity="relu"
-        )
+        nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
         if module.bias is not None:
             nn.init.zeros_(module.bias)
     elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
@@ -86,6 +81,8 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
         named_apply(partial(init_weights, exclude=["encoder"]), self)
         nn.init.trunc_normal_(self.pos_queries, std=0.02)

+        self.export_onnx = False
+
     @property
     def _device(self) -> torch.device:
         return next(self.head.parameters(recurse=False)).device
@@ -93,9 +90,7 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
     @torch.jit.ignore
     def no_weight_decay(self):
         param_names = {"text_embed.embedding.weight", "pos_queries"}
-        enc_param_names = {
-            "encoder." + n for n in self.encoder.no_weight_decay()
-        }
+        enc_param_names = {"encoder." + n for n in self.encoder.no_weight_decay()}
         return param_names.union(enc_param_names)

     def encode(self, img: torch.Tensor):
@@ -129,7 +124,6 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):

     def forward(
         self,
-        tokenizer: Tokenizer,
         images: Tensor,
         max_length: Optional[int] = None,
     ) -> Tensor:
@@ -149,20 +143,18 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):

         # Special case for the forward permutation. Faster than using `generate_attn_masks()`
         tgt_mask = query_mask = torch.triu(
-            torch.ones(
-                (num_steps, num_steps), dtype=torch.bool, device=self._device
-            ),
+            torch.ones((num_steps, num_steps), dtype=torch.bool, device=self._device),
             1,
         )

         if self.decode_ar:
             tgt_in = torch.full(
                 (bs, num_steps),
-                tokenizer.pad_id,
+                self.tokenizer.pad_id,
                 dtype=torch.long,
                 device=self._device,
             )
-            tgt_in[:, 0] = tokenizer.bos_id
+            tgt_in[:, 0] = self.tokenizer.bos_id

             logits = []
             for i in range(num_steps):
@@ -186,8 +178,9 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
                     tgt_in[:, j] = p_i.squeeze().argmax(-1)
                     # Efficient batch decoding: If all output words have at least one EOS token, end decoding.
                     if (
-                        testing
-                        and (tgt_in == tokenizer.eos_id).any(dim=-1).all()
+                        not self.export_onnx
+                        and testing
+                        and (tgt_in == self.tokenizer.eos_id).any(dim=-1).all()
                     ):
                         break

@@ -196,7 +189,7 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
             # No prior context, so input is just <bos>. We query all positions.
             tgt_in = torch.full(
                 (bs, 1),
-                tokenizer.bos_id,
+                self.tokenizer.bos_id,
                 dtype=torch.long,
                 device=self._device,
             )
@@ -211,7 +204,7 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
                 torch.ones(
                     num_steps,
                     num_steps,
-                    dtype=torch.bool,
+                    dtype=torch.int64,
                     device=self._device,
                 ),
                 2,
@@ -219,7 +212,7 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
             ] = 0
             bos = torch.full(
                 (bs, 1),
-                tokenizer.bos_id,
+                self.tokenizer.bos_id,
                 dtype=torch.long,
                 device=self._device,
             )
@@ -227,7 +220,7 @@ class PARSeq(nn.Module, PyTorchModelHubMixin):
             # Prior context is the previous output.
             tgt_in = torch.cat([bos, logits[:, :-1].argmax(-1)], dim=1)
             # Mask tokens beyond the first EOS token.
-            tgt_padding_mask = (tgt_in == tokenizer.eos_id).int().cumsum(
+            tgt_padding_mask = (tgt_in == self.tokenizer.eos_id).int().cumsum(
                 -1
             ) > 0
             tgt_out = self.decode(
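A hypothetical sketch of the changed PARSeq calling convention implied by this hunk: the tokenizer is no longer passed to forward() but read from self.tokenizer, which the caller is assumed to attach beforehand, and the new export_onnx flag keeps the autoregressive loop from breaking early on EOS so the graph stays traceable for ONNX export.

    import torch

    # `model` is a loaded PARSeq instance, `images` a preprocessed batch, and
    # `tokenizer` a ParseqTokenizer; how they are obtained is elided here.
    model.tokenizer = tokenizer   # attached by the caller instead of passed per call
    model.export_onnx = False     # True disables the early EOS break in the decode loop

    with torch.inference_mode():
        logits = model(images)    # the 0.4.1 signature was model(tokenizer, images)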