python-doctr 0.11.0__py3-none-any.whl → 0.12.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (82)
  1. doctr/datasets/__init__.py +1 -0
  2. doctr/datasets/coco_text.py +139 -0
  3. doctr/datasets/cord.py +2 -1
  4. doctr/datasets/funsd.py +2 -2
  5. doctr/datasets/ic03.py +1 -1
  6. doctr/datasets/ic13.py +2 -1
  7. doctr/datasets/iiit5k.py +4 -1
  8. doctr/datasets/imgur5k.py +9 -2
  9. doctr/datasets/loader.py +1 -1
  10. doctr/datasets/ocr.py +1 -1
  11. doctr/datasets/recognition.py +1 -1
  12. doctr/datasets/svhn.py +1 -1
  13. doctr/datasets/svt.py +2 -2
  14. doctr/datasets/synthtext.py +15 -2
  15. doctr/datasets/utils.py +7 -6
  16. doctr/datasets/vocabs.py +1102 -54
  17. doctr/file_utils.py +9 -0
  18. doctr/io/elements.py +37 -3
  19. doctr/models/_utils.py +1 -1
  20. doctr/models/classification/__init__.py +1 -0
  21. doctr/models/classification/magc_resnet/pytorch.py +1 -2
  22. doctr/models/classification/magc_resnet/tensorflow.py +3 -3
  23. doctr/models/classification/mobilenet/pytorch.py +15 -1
  24. doctr/models/classification/mobilenet/tensorflow.py +11 -2
  25. doctr/models/classification/predictor/pytorch.py +1 -1
  26. doctr/models/classification/resnet/pytorch.py +26 -3
  27. doctr/models/classification/resnet/tensorflow.py +25 -4
  28. doctr/models/classification/textnet/pytorch.py +10 -1
  29. doctr/models/classification/textnet/tensorflow.py +11 -2
  30. doctr/models/classification/vgg/pytorch.py +16 -1
  31. doctr/models/classification/vgg/tensorflow.py +11 -2
  32. doctr/models/classification/vip/__init__.py +4 -0
  33. doctr/models/classification/vip/layers/__init__.py +4 -0
  34. doctr/models/classification/vip/layers/pytorch.py +615 -0
  35. doctr/models/classification/vip/pytorch.py +505 -0
  36. doctr/models/classification/vit/pytorch.py +10 -1
  37. doctr/models/classification/vit/tensorflow.py +9 -0
  38. doctr/models/classification/zoo.py +4 -0
  39. doctr/models/detection/differentiable_binarization/base.py +3 -4
  40. doctr/models/detection/differentiable_binarization/pytorch.py +10 -1
  41. doctr/models/detection/differentiable_binarization/tensorflow.py +11 -4
  42. doctr/models/detection/fast/base.py +2 -3
  43. doctr/models/detection/fast/pytorch.py +13 -4
  44. doctr/models/detection/fast/tensorflow.py +10 -2
  45. doctr/models/detection/linknet/base.py +2 -3
  46. doctr/models/detection/linknet/pytorch.py +10 -1
  47. doctr/models/detection/linknet/tensorflow.py +10 -2
  48. doctr/models/factory/hub.py +3 -3
  49. doctr/models/kie_predictor/pytorch.py +1 -1
  50. doctr/models/kie_predictor/tensorflow.py +1 -1
  51. doctr/models/modules/layers/pytorch.py +49 -1
  52. doctr/models/predictor/pytorch.py +1 -1
  53. doctr/models/predictor/tensorflow.py +1 -1
  54. doctr/models/recognition/__init__.py +1 -0
  55. doctr/models/recognition/crnn/pytorch.py +10 -1
  56. doctr/models/recognition/crnn/tensorflow.py +10 -1
  57. doctr/models/recognition/master/pytorch.py +10 -1
  58. doctr/models/recognition/master/tensorflow.py +10 -3
  59. doctr/models/recognition/parseq/pytorch.py +23 -5
  60. doctr/models/recognition/parseq/tensorflow.py +13 -5
  61. doctr/models/recognition/predictor/_utils.py +107 -45
  62. doctr/models/recognition/predictor/pytorch.py +3 -3
  63. doctr/models/recognition/predictor/tensorflow.py +3 -3
  64. doctr/models/recognition/sar/pytorch.py +10 -1
  65. doctr/models/recognition/sar/tensorflow.py +10 -3
  66. doctr/models/recognition/utils.py +56 -47
  67. doctr/models/recognition/viptr/__init__.py +4 -0
  68. doctr/models/recognition/viptr/pytorch.py +277 -0
  69. doctr/models/recognition/vitstr/pytorch.py +10 -1
  70. doctr/models/recognition/vitstr/tensorflow.py +10 -3
  71. doctr/models/recognition/zoo.py +5 -0
  72. doctr/models/utils/pytorch.py +28 -18
  73. doctr/models/utils/tensorflow.py +15 -8
  74. doctr/utils/data.py +1 -1
  75. doctr/utils/geometry.py +1 -1
  76. doctr/version.py +1 -1
  77. {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/METADATA +19 -3
  78. {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/RECORD +82 -75
  79. {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/WHEEL +1 -1
  80. {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info/licenses}/LICENSE +0 -0
  81. {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/top_level.txt +0 -0
  82. {python_doctr-0.11.0.dist-info → python_doctr-0.12.0.dist-info}/zip-safe +0 -0
doctr/models/detection/fast/tensorflow.py
@@ -153,6 +153,15 @@ class FAST(_FAST, Model, NestedObject):
         # Pooling layer as erosion reversal as described in the paper
         self.pooling = layers.MaxPooling2D(pool_size=pooling_size // 2 + 1, strides=1, padding="same")
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def compute_loss(
         self,
         out_map: tf.Tensor,
@@ -332,8 +341,7 @@ def _fast(
     # Load pretrained parameters
     if pretrained:
        # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
-        load_pretrained_params(
-            model,
+        model.from_pretrained(
            _cfg["url"],
            skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
        )
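
Every architecture in this release gains the same `from_pretrained` hook, and the private builders delegate to it instead of calling `load_pretrained_params` on the instance. A minimal usage sketch, assuming the 0.12 API shown above (the checkpoint path is a placeholder):

    from doctr.models import detection

    # Build the graph without fetching the hub weights
    model = detection.fast_base(pretrained=False)
    # Load a locally fine-tuned checkpoint through the new uniform entry point
    model.from_pretrained("path/to/my_checkpoint")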
doctr/models/detection/linknet/base.py
@@ -56,9 +56,8 @@ class LinkNetPostProcessor(DetectionPostProcessor):
             area = (rect[1][0] + 1) * (1 + rect[1][1])
             length = 2 * (rect[1][0] + rect[1][1]) + 2
         else:
-            poly = Polygon(points)
-            area = poly.area
-            length = poly.length
+            area = cv2.contourArea(points)
+            length = cv2.arcLength(points, closed=True)
         distance = area * self.unclip_ratio / length  # compute distance to expand polygon
         offset = pyclipper.PyclipperOffset()
         offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
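
The post-processor now measures contours with OpenCV instead of building a shapely `Polygon` per candidate, removing an object construction from the post-processing hot path. A quick sanity sketch of the equivalence on a hand-checkable rectangle (not code from the package):

    import cv2
    import numpy as np

    # 100x40 axis-aligned rectangle: area 4000, perimeter 280
    points = np.array([[0, 0], [100, 0], [100, 40], [0, 40]], dtype=np.float32)
    assert cv2.contourArea(points) == 4000.0
    assert cv2.arcLength(points, closed=True) == 280.0

For the simple, non-self-intersecting contours the detector produces, the two backends agree on area and perimeter.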
doctr/models/detection/linknet/pytorch.py
@@ -160,6 +160,15 @@ class LinkNet(nn.Module, _LinkNet):
                 m.weight.data.fill_(1.0)
                 m.bias.data.zero_()
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def forward(
         self,
         x: torch.Tensor,
@@ -282,7 +291,7 @@ def _linknet(
         _ignore_keys = (
             ignore_keys if kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]) else None
         )
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     return model
 
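On the PyTorch side the hook threads `ignore_keys` through, which is what lets fine-tuning drop the pretrained classification head when the class set changes. A hedged sketch of that entry point, assuming the builder keeps accepting `class_names` as in 0.11:

    from doctr.models import detection

    # Mismatching class names => the pretrained head keys are ignored on load
    model = detection.linknet_resnet18(pretrained=True, class_names=["words", "stamps"])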
doctr/models/detection/linknet/tensorflow.py
@@ -163,6 +163,15 @@ class LinkNet(_LinkNet, Model):
             assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
         )
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def compute_loss(
         self,
         out_map: tf.Tensor,
@@ -282,8 +291,7 @@ def _linknet(
     # Load pretrained parameters
     if pretrained:
        # The given class_names differs from the pretrained model => skip the mismatching layers for fine tuning
-        load_pretrained_params(
-            model,
+        model.from_pretrained(
            _cfg["url"],
            skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
        )
doctr/models/factory/hub.py
@@ -217,10 +217,10 @@ def from_hub(repo_id: str, **kwargs: Any):
 
     # Load checkpoint
     if is_torch_available():
-        state_dict = torch.load(hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs), map_location="cpu")
-        model.load_state_dict(state_dict)
+        weights = hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs)
     else:  # tf
         weights = hf_hub_download(repo_id, filename="tf_model.weights.h5", **kwargs)
-        model.load_weights(weights)
+
+    model.from_pretrained(weights)
 
     return model
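
Both branches now funnel into `model.from_pretrained`, so hub loading no longer needs framework-specific weight handling. A sketch with a placeholder repo id:

    from doctr.models.factory.hub import from_hub

    # Resolves pytorch_model.bin or tf_model.weights.h5 depending on the
    # active backend, then defers to model.from_pretrained
    model = from_hub("mindee/my-finetuned-model")  # placeholder repo id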
doctr/models/kie_predictor/pytorch.py
@@ -173,7 +173,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
             boxes_per_page,
             objectness_scores_per_page,
             text_preds_per_page,
-            origin_page_shapes,  # type: ignore[arg-type]
+            origin_page_shapes,
             crop_orientations_per_page,
             orientations,
             languages_dict,
doctr/models/kie_predictor/tensorflow.py
@@ -171,7 +171,7 @@ class KIEPredictor(NestedObject, _KIEPredictor):
             boxes_per_page,
             objectness_scores_per_page,
             text_preds_per_page,
-            origin_page_shapes,  # type: ignore[arg-type]
+            origin_page_shapes,
             crop_orientations_per_page,
             orientations,
             languages_dict,
doctr/models/modules/layers/pytorch.py
@@ -8,7 +8,55 @@ import numpy as np
 import torch
 import torch.nn as nn
 
-__all__ = ["FASTConvLayer"]
+__all__ = ["FASTConvLayer", "DropPath", "AdaptiveAvgPool2d"]
+
+
+class DropPath(nn.Module):
+    """
+    DropPath (Drop Connect) layer. This is a stochastic version of the identity layer.
+    """
+
+    # Borrowed from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
+    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+        self.scale_by_keep = scale_by_keep
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.drop_prob == 0.0 or not self.training:
+            return x
+        keep_prob = 1 - self.drop_prob
+        shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with different dimensions
+        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+        if keep_prob > 0.0 and self.scale_by_keep:
+            random_tensor.div_(keep_prob)
+        return x * random_tensor
+
+
+class AdaptiveAvgPool2d(nn.Module):
+    """
+    Custom AdaptiveAvgPool2d implementation which is ONNX and `torch.compile` compatible.
+    """
+
+    def __init__(self, output_size):
+        super().__init__()
+        self.output_size = output_size
+
+    def forward(self, x: torch.Tensor):
+        H_out, W_out = self.output_size
+        N, C, H, W = x.shape
+
+        out = torch.empty((N, C, H_out, W_out), device=x.device, dtype=x.dtype)
+        for oh in range(H_out):
+            start_h = (oh * H) // H_out
+            end_h = ((oh + 1) * H + H_out - 1) // H_out  # ceil((oh+1)*H / H_out)
+            for ow in range(W_out):
+                start_w = (ow * W) // W_out
+                end_w = ((ow + 1) * W + W_out - 1) // W_out  # ceil((ow+1)*W / W_out)
+                # average over the window
+                out[:, :, oh, ow] = x[:, :, start_h:end_h, start_w:end_w].mean(dim=(-2, -1))
+        return out
 
 
 class FASTConvLayer(nn.Module):
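
The loop-based pooling trades speed for traceability: each output cell is a plain slice-and-mean with no data-dependent control flow, which is what keeps ONNX export and `torch.compile` happy. Its floor/ceil window bounds match the ones `nn.AdaptiveAvgPool2d` uses, so a parity check along these lines should pass:

    import torch
    from torch import nn
    from doctr.models.modules.layers.pytorch import AdaptiveAvgPool2d

    x = torch.randn(2, 8, 13, 17)  # deliberately non-divisible spatial dims
    ref = nn.AdaptiveAvgPool2d((4, 4))(x)
    out = AdaptiveAvgPool2d((4, 4))(x)
    assert torch.allclose(ref, out, atol=1e-6)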
doctr/models/predictor/pytorch.py
@@ -150,7 +150,7 @@ class OCRPredictor(nn.Module, _OCRPredictor):
             boxes,
             objectness_scores,
             text_preds,
-            origin_page_shapes,  # type: ignore[arg-type]
+            origin_page_shapes,
             crop_orientations,
             orientations,
             languages_dict,
doctr/models/predictor/tensorflow.py
@@ -147,7 +147,7 @@ class OCRPredictor(NestedObject, _OCRPredictor):
             boxes,
             objectness_scores,
             text_preds,
-            origin_page_shapes,  # type: ignore[arg-type]
+            origin_page_shapes,
             crop_orientations,
             orientations,
             languages_dict,
doctr/models/recognition/__init__.py
@@ -3,4 +3,5 @@ from .master import *
 from .sar import *
 from .vitstr import *
 from .parseq import *
+from .viptr import *
 from .zoo import *
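
This wires the new VIPTR recognizer (see `doctr/models/recognition/viptr/pytorch.py` in the file list; PyTorch only) into the public namespace, and `doctr/models/recognition/zoo.py` registers it alongside the existing architectures. Assuming it follows the usual builder pattern, usage would look like this; the constructor name `viptr_tiny` is inferred from the release, not confirmed by this diff:

    from doctr.models import recognition

    # Hypothetical builder name for the new VIPTR recognizer
    model = recognition.viptr_tiny(pretrained=True)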
doctr/models/recognition/crnn/pytorch.py
@@ -155,6 +155,15 @@ class CRNN(RecognitionModel, nn.Module):
                 m.weight.data.fill_(1.0)
                 m.bias.data.zero_()
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def compute_loss(
         self,
         model_output: torch.Tensor,
@@ -254,7 +263,7 @@ def _crnn(
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
-        load_pretrained_params(model, _cfg["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(_cfg["url"], ignore_keys=_ignore_keys)
 
     return model
 
doctr/models/recognition/crnn/tensorflow.py
@@ -154,6 +154,15 @@ class CRNN(RecognitionModel, Model):
         self.beam_width = beam_width
         self.top_paths = top_paths
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def compute_loss(
         self,
         model_output: tf.Tensor,
@@ -243,7 +252,7 @@ def _crnn(
     # Load pretrained parameters
     if pretrained:
        # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
-        load_pretrained_params(model, _cfg["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
+        model.from_pretrained(_cfg["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
 
     return model
 
doctr/models/recognition/master/pytorch.py
@@ -151,6 +151,15 @@ class MASTER(_MASTER, nn.Module):
         ce_loss = cce.sum(1) / seq_len.to(dtype=model_output.dtype)
         return ce_loss.mean()
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def forward(
         self,
         x: torch.Tensor,
@@ -301,7 +310,7 @@ def _master(
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     return model
 
doctr/models/recognition/master/tensorflow.py
@@ -87,6 +87,15 @@ class MASTER(_MASTER, Model):
         self.linear = layers.Dense(self.vocab_size + 3, kernel_initializer=tf.initializers.he_uniform())
         self.postprocessor = MASTERPostProcessor(vocab=self.vocab)
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     @tf.function
     def make_source_and_target_mask(self, source: tf.Tensor, target: tf.Tensor) -> tuple[tf.Tensor, tf.Tensor]:
         # [1, 1, 1, ..., 0, 0, 0] -> 0 is masked
@@ -287,9 +296,7 @@ def _master(arch: str, pretrained: bool, backbone_fn, pretrained_backbone: bool
     # Load pretrained parameters
     if pretrained:
        # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
-        load_pretrained_params(
-            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
-        )
+        model.from_pretrained(default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
 
     return model
 
doctr/models/recognition/parseq/pytorch.py
@@ -76,8 +76,6 @@ class PARSeqDecoder(nn.Module):
         self.cross_attention = MultiHeadAttention(num_heads, d_model, dropout=dropout)
         self.position_feed_forward = PositionwiseFeedForward(d_model, ffd * ffd_ratio, dropout, nn.GELU())
 
-        self.attention_norm = nn.LayerNorm(d_model, eps=1e-5)
-        self.cross_attention_norm = nn.LayerNorm(d_model, eps=1e-5)
         self.query_norm = nn.LayerNorm(d_model, eps=1e-5)
         self.content_norm = nn.LayerNorm(d_model, eps=1e-5)
         self.feed_forward_norm = nn.LayerNorm(d_model, eps=1e-5)
@@ -173,6 +171,26 @@ class PARSeq(_PARSeq, nn.Module):
             nn.init.constant_(m.weight, 1)
             nn.init.constant_(m.bias, 0)
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        # NOTE: This is required to make the model backward compatible with already trained models docTR version <0.11.1
+        # ref.: https://github.com/mindee/doctr/issues/1911
+        if kwargs.get("ignore_keys") is None:
+            kwargs["ignore_keys"] = []
+
+        kwargs["ignore_keys"].extend([
+            "decoder.attention_norm.weight",
+            "decoder.attention_norm.bias",
+            "decoder.cross_attention_norm.weight",
+            "decoder.cross_attention_norm.bias",
+        ])
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def generate_permutations(self, seqlen: torch.Tensor) -> torch.Tensor:
         # Generates permutations of the target sequence.
         # Borrowed from https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
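
Because the two unused decoder norms were deleted, checkpoints trained before this release still contain `decoder.attention_norm.*` and `decoder.cross_attention_norm.*` tensors; the override above appends them to `ignore_keys` so old weights keep loading instead of raising on unexpected keys. A hedged sketch of the path this unblocks (the checkpoint name is a placeholder):

    from doctr.models import recognition

    model = recognition.parseq(pretrained=False)
    # A checkpoint produced by an older docTR release loads cleanly:
    # the stale norm tensors are skipped automatically
    model.from_pretrained("parseq_trained_with_docTR_0.11.pt")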
@@ -210,7 +228,7 @@ class PARSeq(_PARSeq, nn.Module):
 
         sos_idx = torch.zeros(len(final_perms), 1, device=seqlen.device)
         eos_idx = torch.full((len(final_perms), 1), max_num_chars + 1, device=seqlen.device)
-        combined = torch.cat([sos_idx, final_perms + 1, eos_idx], dim=1).int()  # type: ignore[list-item]
+        combined = torch.cat([sos_idx, final_perms + 1, eos_idx], dim=1).int()
         if len(combined) > 1:
             combined[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1, device=seqlen.device)
         return combined
@@ -349,7 +367,7 @@ class PARSeq(_PARSeq, nn.Module):
             # remove the [EOS] tokens for the succeeding perms
             if i == 1:
                 gt_out = torch.where(gt_out == self.vocab_size, self.vocab_size + 2, gt_out)
-                n = (gt_out != self.vocab_size + 2).sum().item()
+                n = (gt_out != self.vocab_size + 2).sum().item()  # type: ignore[attr-defined]
 
         loss /= loss_numel
 
@@ -448,7 +466,7 @@ def _parseq(
         # The number of classes is not the same as the number of classes in the pretrained model =>
         # remove the last layer weights
         _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     return model
 
doctr/models/recognition/parseq/tensorflow.py
@@ -76,8 +76,6 @@ class PARSeqDecoder(layers.Layer):
             d_model, ffd * ffd_ratio, dropout, layers.Activation(tf.nn.gelu)
         )
 
-        self.attention_norm = layers.LayerNormalization(epsilon=1e-5)
-        self.cross_attention_norm = layers.LayerNormalization(epsilon=1e-5)
         self.query_norm = layers.LayerNormalization(epsilon=1e-5)
         self.content_norm = layers.LayerNormalization(epsilon=1e-5)
         self.feed_forward_norm = layers.LayerNormalization(epsilon=1e-5)
@@ -165,6 +163,18 @@ class PARSeq(_PARSeq, Model):
 
         self.postprocessor = PARSeqPostProcessor(vocab=self.vocab)
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        # NOTE: This is required to make the model backward compatible with already trained models docTR version <0.11.1
+        # ref.: https://github.com/mindee/doctr/issues/1911
+        kwargs["skip_mismatch"] = True
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def generate_permutations(self, seqlen: tf.Tensor) -> tf.Tensor:
         # Generates permutations of the target sequence.
         # Translated from https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
@@ -474,9 +484,7 @@ def _parseq(
     # Load pretrained parameters
     if pretrained:
        # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
-        load_pretrained_params(
-            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
-        )
+        model.from_pretrained(default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
 
     return model
 
doctr/models/recognition/predictor/_utils.py
@@ -4,6 +4,8 @@
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
 
 
+import math
+
 import numpy as np
 
 from ..utils import merge_multi_strings
@@ -15,69 +17,129 @@ def split_crops(
     crops: list[np.ndarray],
     max_ratio: float,
     target_ratio: int,
-    dilation: float,
+    split_overlap_ratio: float,
     channels_last: bool = True,
-) -> tuple[list[np.ndarray], list[int | tuple[int, int]], bool]:
-    """Chunk crops horizontally to match a given aspect ratio
+) -> tuple[list[np.ndarray], list[int | tuple[int, int, float]], bool]:
+    """
+    Split crops horizontally if they exceed a given aspect ratio.
 
     Args:
-        crops: list of numpy array of shape (H, W, 3) if channels_last or (3, H, W) otherwise
-        max_ratio: the maximum aspect ratio that won't trigger the chunk
-        target_ratio: when crops are chunked, they will be chunked to match this aspect ratio
-        dilation: the width dilation of final chunks (to provide some overlaps)
-        channels_last: whether the numpy array has dimensions in channels last order
+        crops: List of image crops (H, W, C) if channels_last else (C, H, W).
+        max_ratio: Aspect ratio threshold above which crops are split.
+        target_ratio: Target aspect ratio after splitting (e.g., 4 for 128x32).
+        split_overlap_ratio: Desired overlap between splits (as a fraction of split width).
+        channels_last: Whether the crops are in channels-last format.
 
     Returns:
-        a tuple with the new crops, their mapping, and a boolean specifying whether any remap is required
+        A tuple containing:
+        - The new list of crops (possibly with splits),
+        - A mapping indicating how to reassemble predictions,
+        - A boolean indicating whether remapping is required.
     """
-    _remap_required = False
-    crop_map: list[int | tuple[int, int]] = []
+    if split_overlap_ratio <= 0.0 or split_overlap_ratio >= 1.0:
+        raise ValueError(f"Valid range for split_overlap_ratio is (0.0, 1.0), but is: {split_overlap_ratio}")
+
+    remap_required = False
     new_crops: list[np.ndarray] = []
+    crop_map: list[int | tuple[int, int, float]] = []
+
     for crop in crops:
         h, w = crop.shape[:2] if channels_last else crop.shape[-2:]
         aspect_ratio = w / h
+
         if aspect_ratio > max_ratio:
-            # Determine the number of crops, reference aspect ratio = 4 = 128 / 32
-            num_subcrops = int(aspect_ratio // target_ratio)
-            # Find the new widths, additional dilation factor to overlap crops
-            width = dilation * w / num_subcrops
-            centers = [(w / num_subcrops) * (1 / 2 + idx) for idx in range(num_subcrops)]
-            # Get the crops
-            if channels_last:
-                _crops = [
-                    crop[:, max(0, int(round(center - width / 2))) : min(w - 1, int(round(center + width / 2))), :]
-                    for center in centers
-                ]
+            split_width = max(1, math.ceil(h * target_ratio))
+            overlap_width = max(0, math.floor(split_width * split_overlap_ratio))
+
+            splits, last_overlap = _split_horizontally(crop, split_width, overlap_width, channels_last)
+
+            # Remove any empty splits
+            splits = [s for s in splits if all(dim > 0 for dim in s.shape)]
+            if splits:
+                crop_map.append((len(new_crops), len(new_crops) + len(splits), last_overlap))
+                new_crops.extend(splits)
+                remap_required = True
             else:
-                _crops = [
-                    crop[:, :, max(0, int(round(center - width / 2))) : min(w - 1, int(round(center + width / 2)))]
-                    for center in centers
-                ]
-            # Avoid sending zero-sized crops
-            _crops = [crop for crop in _crops if all(s > 0 for s in crop.shape)]
-            # Record the slice of crops
-            crop_map.append((len(new_crops), len(new_crops) + len(_crops)))
-            new_crops.extend(_crops)
-            # At least one crop will require merging
-            _remap_required = True
+                # Fallback: treat it as a single crop
+                crop_map.append(len(new_crops))
+                new_crops.append(crop)
         else:
             crop_map.append(len(new_crops))
             new_crops.append(crop)
 
-    return new_crops, crop_map, _remap_required
+    return new_crops, crop_map, remap_required
+
+
+def _split_horizontally(
+    image: np.ndarray, split_width: int, overlap_width: int, channels_last: bool
+) -> tuple[list[np.ndarray], float]:
+    """
+    Horizontally split a single image with overlapping regions.
+
+    Args:
+        image: The image to split (H, W, C) if channels_last else (C, H, W).
+        split_width: Width of each split.
+        overlap_width: Width of the overlapping region.
+        channels_last: Whether the image is in channels-last format.
+
+    Returns:
+        - A list of horizontal image slices.
+        - The actual overlap ratio of the last split.
+    """
+    image_width = image.shape[1] if channels_last else image.shape[-1]
+    if image_width <= split_width:
+        return [image], 0.0
+
+    # Compute start columns for each split
+    step = split_width - overlap_width
+    starts = list(range(0, image_width - split_width + 1, step))
+
+    # Ensure the last patch reaches the end of the image
+    if starts[-1] + split_width < image_width:
+        starts.append(image_width - split_width)
+
+    splits = []
+    for start_col in starts:
+        end_col = start_col + split_width
+        if channels_last:
+            split = image[:, start_col:end_col, :]
+        else:
+            split = image[:, :, start_col:end_col]
+        splits.append(split)
+
+    # Calculate the last overlap ratio, if only one split no overlap
+    last_overlap = 0
+    if len(starts) > 1:
+        last_overlap = (starts[-2] + split_width) - starts[-1]
+    last_overlap_ratio = last_overlap / split_width if split_width else 0.0
+
+    return splits, last_overlap_ratio
 
 
 def remap_preds(
-    preds: list[tuple[str, float]], crop_map: list[int | tuple[int, int]], dilation: float
+    preds: list[tuple[str, float]],
+    crop_map: list[int | tuple[int, int, float]],
+    overlap_ratio: float,
 ) -> list[tuple[str, float]]:
-    remapped_out = []
-    for _idx in crop_map:
-        # Crop hasn't been split
-        if isinstance(_idx, int):
-            remapped_out.append(preds[_idx])
+    """
+    Reconstruct predictions from possibly split crops.
+
+    Args:
+        preds: List of (text, confidence) tuples from each crop.
+        crop_map: Map returned by `split_crops`.
+        overlap_ratio: Overlap ratio used during splitting.
+
+    Returns:
+        List of merged (text, confidence) tuples corresponding to original crops.
+    """
+    remapped = []
+    for item in crop_map:
+        if isinstance(item, int):
+            remapped.append(preds[item])
         else:
-            # unzip
-            vals, probs = zip(*preds[_idx[0] : _idx[1]])
-            # Merge the string values
-            remapped_out.append((merge_multi_strings(vals, dilation), min(probs)))  # type: ignore[arg-type]
-    return remapped_out
+            start_idx, end_idx, last_overlap = item
+            text_parts, confidences = zip(*preds[start_idx:end_idx])
+            merged_text = merge_multi_strings(list(text_parts), overlap_ratio, last_overlap)
+            merged_conf = sum(confidences) / len(confidences)  # average confidence
+            remapped.append((merged_text, merged_conf))
    return remapped
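
The rewrite replaces the dilation-around-centers scheme with fixed-width windows and a fixed stride, and the crop map now carries the exact overlap of the final window so `merge_multi_strings` can trim it precisely. A worked example under the signature above:

    import numpy as np
    from doctr.models.recognition.predictor._utils import split_crops

    # One 32x512 crop: aspect ratio 16 exceeds max_ratio 8, so it is split
    crops = [np.zeros((32, 512, 3), dtype=np.uint8)]
    new_crops, crop_map, remapped = split_crops(
        crops, max_ratio=8, target_ratio=6, split_overlap_ratio=0.5, channels_last=True
    )
    # split_width = 32 * 6 = 192, overlap = 96, stride = 96
    # -> windows start at [0, 96, 192, 288], plus a final one at 320 to reach the edge
    assert remapped and len(new_crops) == 5
    assert crop_map[0][:2] == (0, 5)  # predictions 0..4 merge back into one string

Note that the confidence of a merged crop is now the average over its splits rather than the minimum.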
doctr/models/recognition/predictor/pytorch.py
@@ -38,7 +38,7 @@ class RecognitionPredictor(nn.Module):
         self.model = model.eval()
         self.split_wide_crops = split_wide_crops
         self.critical_ar = 8  # Critical aspect ratio
-        self.dil_factor = 1.4  # Dilation factor to overlap the crops
+        self.overlap_ratio = 0.5  # Ratio of overlap between neighboring crops
         self.target_ar = 6  # Target aspect ratio
 
     @torch.inference_mode()
@@ -60,7 +60,7 @@ class RecognitionPredictor(nn.Module):
                 crops,  # type: ignore[arg-type]
                 self.critical_ar,
                 self.target_ar,
-                self.dil_factor,
+                self.overlap_ratio,
                 isinstance(crops[0], np.ndarray),
             )
             if remapped:
@@ -81,6 +81,6 @@ class RecognitionPredictor(nn.Module):
 
         # Remap crops
         if self.split_wide_crops and remapped:
-            out = remap_preds(out, crop_map, self.dil_factor)
+            out = remap_preds(out, crop_map, self.overlap_ratio)
 
         return out
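
Code that tuned the old attribute must migrate, since `dil_factor` is gone. A sketch (the 0.4 value is arbitrary):

    from doctr.models import recognition_predictor

    predictor = recognition_predictor(pretrained=True)
    # Formerly: predictor.dil_factor = 1.4
    predictor.overlap_ratio = 0.4  # overlap as a fraction of each split's width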
doctr/models/recognition/predictor/tensorflow.py
@@ -39,7 +39,7 @@ class RecognitionPredictor(NestedObject):
         self.model = model
         self.split_wide_crops = split_wide_crops
         self.critical_ar = 8  # Critical aspect ratio
-        self.dil_factor = 1.4  # Dilation factor to overlap the crops
+        self.overlap_ratio = 0.5  # Ratio of overlap between neighboring crops
         self.target_ar = 6  # Target aspect ratio
 
     def __call__(
@@ -56,7 +56,7 @@ class RecognitionPredictor(NestedObject):
         # Split crops that are too wide
         remapped = False
         if self.split_wide_crops:
-            new_crops, crop_map, remapped = split_crops(crops, self.critical_ar, self.target_ar, self.dil_factor)
+            new_crops, crop_map, remapped = split_crops(crops, self.critical_ar, self.target_ar, self.overlap_ratio)
             if remapped:
                 crops = new_crops
 
@@ -74,6 +74,6 @@ class RecognitionPredictor(NestedObject):
 
         # Remap crops
         if self.split_wide_crops and remapped:
-            out = remap_preds(out, crop_map, self.dil_factor)
+            out = remap_preds(out, crop_map, self.overlap_ratio)
 
         return out