python-doctr 0.11.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
Files changed (138)
  1. doctr/__init__.py +0 -1
  2. doctr/datasets/__init__.py +1 -5
  3. doctr/datasets/coco_text.py +139 -0
  4. doctr/datasets/cord.py +2 -1
  5. doctr/datasets/datasets/__init__.py +1 -6
  6. doctr/datasets/datasets/pytorch.py +2 -2
  7. doctr/datasets/funsd.py +2 -2
  8. doctr/datasets/generator/__init__.py +1 -6
  9. doctr/datasets/ic03.py +1 -1
  10. doctr/datasets/ic13.py +2 -1
  11. doctr/datasets/iiit5k.py +4 -1
  12. doctr/datasets/imgur5k.py +9 -2
  13. doctr/datasets/ocr.py +1 -1
  14. doctr/datasets/recognition.py +1 -1
  15. doctr/datasets/svhn.py +1 -1
  16. doctr/datasets/svt.py +2 -2
  17. doctr/datasets/synthtext.py +15 -2
  18. doctr/datasets/utils.py +7 -6
  19. doctr/datasets/vocabs.py +1100 -54
  20. doctr/file_utils.py +2 -92
  21. doctr/io/elements.py +37 -3
  22. doctr/io/image/__init__.py +1 -7
  23. doctr/io/image/pytorch.py +1 -1
  24. doctr/models/_utils.py +4 -4
  25. doctr/models/classification/__init__.py +1 -0
  26. doctr/models/classification/magc_resnet/__init__.py +1 -6
  27. doctr/models/classification/magc_resnet/pytorch.py +3 -4
  28. doctr/models/classification/mobilenet/__init__.py +1 -6
  29. doctr/models/classification/mobilenet/pytorch.py +15 -1
  30. doctr/models/classification/predictor/__init__.py +1 -6
  31. doctr/models/classification/predictor/pytorch.py +2 -2
  32. doctr/models/classification/resnet/__init__.py +1 -6
  33. doctr/models/classification/resnet/pytorch.py +26 -3
  34. doctr/models/classification/textnet/__init__.py +1 -6
  35. doctr/models/classification/textnet/pytorch.py +11 -2
  36. doctr/models/classification/vgg/__init__.py +1 -6
  37. doctr/models/classification/vgg/pytorch.py +16 -1
  38. doctr/models/classification/vip/__init__.py +1 -0
  39. doctr/models/classification/vip/layers/__init__.py +1 -0
  40. doctr/models/classification/vip/layers/pytorch.py +615 -0
  41. doctr/models/classification/vip/pytorch.py +505 -0
  42. doctr/models/classification/vit/__init__.py +1 -6
  43. doctr/models/classification/vit/pytorch.py +12 -3
  44. doctr/models/classification/zoo.py +7 -8
  45. doctr/models/detection/_utils/__init__.py +1 -6
  46. doctr/models/detection/core.py +1 -1
  47. doctr/models/detection/differentiable_binarization/__init__.py +1 -6
  48. doctr/models/detection/differentiable_binarization/base.py +7 -16
  49. doctr/models/detection/differentiable_binarization/pytorch.py +13 -4
  50. doctr/models/detection/fast/__init__.py +1 -6
  51. doctr/models/detection/fast/base.py +6 -17
  52. doctr/models/detection/fast/pytorch.py +17 -8
  53. doctr/models/detection/linknet/__init__.py +1 -6
  54. doctr/models/detection/linknet/base.py +5 -15
  55. doctr/models/detection/linknet/pytorch.py +12 -3
  56. doctr/models/detection/predictor/__init__.py +1 -6
  57. doctr/models/detection/predictor/pytorch.py +1 -1
  58. doctr/models/detection/zoo.py +15 -32
  59. doctr/models/factory/hub.py +9 -22
  60. doctr/models/kie_predictor/__init__.py +1 -6
  61. doctr/models/kie_predictor/pytorch.py +3 -7
  62. doctr/models/modules/layers/__init__.py +1 -6
  63. doctr/models/modules/layers/pytorch.py +52 -4
  64. doctr/models/modules/transformer/__init__.py +1 -6
  65. doctr/models/modules/transformer/pytorch.py +2 -2
  66. doctr/models/modules/vision_transformer/__init__.py +1 -6
  67. doctr/models/predictor/__init__.py +1 -6
  68. doctr/models/predictor/base.py +3 -8
  69. doctr/models/predictor/pytorch.py +3 -6
  70. doctr/models/preprocessor/__init__.py +1 -6
  71. doctr/models/preprocessor/pytorch.py +27 -32
  72. doctr/models/recognition/__init__.py +1 -0
  73. doctr/models/recognition/crnn/__init__.py +1 -6
  74. doctr/models/recognition/crnn/pytorch.py +16 -7
  75. doctr/models/recognition/master/__init__.py +1 -6
  76. doctr/models/recognition/master/pytorch.py +15 -6
  77. doctr/models/recognition/parseq/__init__.py +1 -6
  78. doctr/models/recognition/parseq/pytorch.py +26 -8
  79. doctr/models/recognition/predictor/__init__.py +1 -6
  80. doctr/models/recognition/predictor/_utils.py +100 -47
  81. doctr/models/recognition/predictor/pytorch.py +4 -5
  82. doctr/models/recognition/sar/__init__.py +1 -6
  83. doctr/models/recognition/sar/pytorch.py +13 -4
  84. doctr/models/recognition/utils.py +56 -47
  85. doctr/models/recognition/viptr/__init__.py +1 -0
  86. doctr/models/recognition/viptr/pytorch.py +277 -0
  87. doctr/models/recognition/vitstr/__init__.py +1 -6
  88. doctr/models/recognition/vitstr/pytorch.py +13 -4
  89. doctr/models/recognition/zoo.py +13 -8
  90. doctr/models/utils/__init__.py +1 -6
  91. doctr/models/utils/pytorch.py +29 -19
  92. doctr/transforms/functional/__init__.py +1 -6
  93. doctr/transforms/functional/pytorch.py +4 -4
  94. doctr/transforms/modules/__init__.py +1 -7
  95. doctr/transforms/modules/base.py +26 -92
  96. doctr/transforms/modules/pytorch.py +28 -26
  97. doctr/utils/data.py +1 -1
  98. doctr/utils/geometry.py +7 -11
  99. doctr/utils/visualization.py +1 -1
  100. doctr/version.py +1 -1
  101. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +22 -63
  102. python_doctr-1.0.0.dist-info/RECORD +149 -0
  103. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +1 -1
  104. doctr/datasets/datasets/tensorflow.py +0 -59
  105. doctr/datasets/generator/tensorflow.py +0 -58
  106. doctr/datasets/loader.py +0 -94
  107. doctr/io/image/tensorflow.py +0 -101
  108. doctr/models/classification/magc_resnet/tensorflow.py +0 -196
  109. doctr/models/classification/mobilenet/tensorflow.py +0 -433
  110. doctr/models/classification/predictor/tensorflow.py +0 -60
  111. doctr/models/classification/resnet/tensorflow.py +0 -397
  112. doctr/models/classification/textnet/tensorflow.py +0 -266
  113. doctr/models/classification/vgg/tensorflow.py +0 -116
  114. doctr/models/classification/vit/tensorflow.py +0 -192
  115. doctr/models/detection/_utils/tensorflow.py +0 -34
  116. doctr/models/detection/differentiable_binarization/tensorflow.py +0 -414
  117. doctr/models/detection/fast/tensorflow.py +0 -419
  118. doctr/models/detection/linknet/tensorflow.py +0 -369
  119. doctr/models/detection/predictor/tensorflow.py +0 -70
  120. doctr/models/kie_predictor/tensorflow.py +0 -187
  121. doctr/models/modules/layers/tensorflow.py +0 -171
  122. doctr/models/modules/transformer/tensorflow.py +0 -235
  123. doctr/models/modules/vision_transformer/tensorflow.py +0 -100
  124. doctr/models/predictor/tensorflow.py +0 -155
  125. doctr/models/preprocessor/tensorflow.py +0 -122
  126. doctr/models/recognition/crnn/tensorflow.py +0 -308
  127. doctr/models/recognition/master/tensorflow.py +0 -313
  128. doctr/models/recognition/parseq/tensorflow.py +0 -508
  129. doctr/models/recognition/predictor/tensorflow.py +0 -79
  130. doctr/models/recognition/sar/tensorflow.py +0 -416
  131. doctr/models/recognition/vitstr/tensorflow.py +0 -278
  132. doctr/models/utils/tensorflow.py +0 -182
  133. doctr/transforms/functional/tensorflow.py +0 -254
  134. doctr/transforms/modules/tensorflow.py +0 -562
  135. python_doctr-0.11.0.dist-info/RECORD +0 -173
  136. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info/licenses}/LICENSE +0 -0
  137. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
  138. {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
@@ -179,6 +179,15 @@ class DBNet(_DBNet, nn.Module):
                 m.weight.data.fill_(1.0)
                 m.bias.data.zero_()
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def forward(
         self,
         x: torch.Tensor,
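
The new `from_pretrained` method (added identically to DBNet here and to FAST and LinkNet further down) becomes the single entry point for loading checkpoints, replacing direct calls to `load_pretrained_params`. A minimal usage sketch, with a placeholder checkpoint path (any local path or URL accepted by `doctr.models.utils.load_pretrained_params` works):

    from doctr.models import detection

    model = detection.db_resnet50(pretrained=False)
    model.from_pretrained("path/to/db_resnet50_checkpoint.pt")  # placeholder path
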
@@ -206,7 +215,7 @@ class DBNet(_DBNet, nn.Module):
 
         if target is None or return_preds:
             # Disable for torch.compile compatibility
-            @torch.compiler.disable  # type: ignore[attr-defined]
+            @torch.compiler.disable
             def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
                 return [
                     dict(zip(self.class_names, preds))
@@ -252,7 +261,7 @@ class DBNet(_DBNet, nn.Module):
         prob_map = torch.sigmoid(out_map)
         thresh_map = torch.sigmoid(thresh_map)
 
-        targets = self.build_target(target, out_map.shape[1:], False)  # type: ignore[arg-type]
+        targets = self.build_target(target, out_map.shape[1:])  # type: ignore[arg-type]
 
         seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1])
         seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)
@@ -276,7 +285,7 @@ class DBNet(_DBNet, nn.Module):
             dice_map = torch.softmax(out_map, dim=1)
         else:
             # compute binary map instead
-            dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))  # type: ignore[assignment]
+            dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
         # Class reduced
         inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3))
         cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3))
@@ -328,7 +337,7 @@ def _dbnet(
         _ignore_keys = (
             ignore_keys if kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]) else None
         )
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     return model
 
@@ -1,6 +1 @@
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    from .pytorch import *
-elif is_tf_available():
-    from .tensorflow import *  # type: ignore[assignment]
+from .pytorch import *
@@ -56,9 +56,8 @@ class FASTPostProcessor(DetectionPostProcessor):
             area = (rect[1][0] + 1) * (1 + rect[1][1])
             length = 2 * (rect[1][0] + rect[1][1]) + 2
         else:
-            poly = Polygon(points)
-            area = poly.area
-            length = poly.length
+            area = cv2.contourArea(points)
+            length = cv2.arcLength(points, closed=True)
         distance = area * self.unclip_ratio / length  # compute distance to expand polygon
         offset = pyclipper.PyclipperOffset()
         offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
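
Both post-processors (FAST here, LinkNet below) now derive the polygon area and perimeter with OpenCV instead of Shapely; the unclip-distance formula itself is unchanged. A standalone sketch of the computation, using an illustrative rectangle and unclip ratio:

    import cv2
    import numpy as np

    points = np.array([[0, 0], [100, 0], [100, 30], [0, 30]], dtype=np.int32)
    unclip_ratio = 1.5  # illustrative value
    area = cv2.contourArea(points)               # 3000.0
    length = cv2.arcLength(points, closed=True)  # 260.0
    distance = area * unclip_ratio / length      # offset to expand the polygon, ~17.3
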
@@ -154,14 +153,12 @@ class _FAST(BaseModel):
         self,
         target: list[dict[str, np.ndarray]],
         output_shape: tuple[int, int, int],
-        channels_last: bool = True,
     ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
         """Build the target, and it's mask to be used from loss computation.
 
         Args:
             target: target coming from dataset
             output_shape: shape of the output of the model without batch_size
-            channels_last: whether channels are last or not
 
         Returns:
             the new formatted target, mask and shrunken text kernel
@@ -173,10 +170,8 @@ class _FAST(BaseModel):
 
         h: int
         w: int
-        if channels_last:
-            h, w, num_classes = output_shape
-        else:
-            num_classes, h, w = output_shape
+
+        num_classes, h, w = output_shape
         target_shape = (len(target), num_classes, h, w)
 
         seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
@@ -236,14 +231,8 @@ class _FAST(BaseModel):
                 if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
                     seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
                     continue
-                cv2.fillPoly(shrunken_kernel[idx, class_idx], [shrunken.astype(np.int32)], 1.0)  # type: ignore[call-overload]
+                cv2.fillPoly(shrunken_kernel[idx, class_idx], [shrunken.astype(np.int32)], 1.0)
                 # draw the original polygon on the segmentation target
-                cv2.fillPoly(seg_target[idx, class_idx], [poly.astype(np.int32)], 1.0)  # type: ignore[call-overload]
-
-        # Don't forget to switch back to channel last if Tensorflow is used
-        if channels_last:
-            seg_target = seg_target.transpose((0, 2, 3, 1))
-            seg_mask = seg_mask.transpose((0, 2, 3, 1))
-            shrunken_kernel = shrunken_kernel.transpose((0, 2, 3, 1))
+                cv2.fillPoly(seg_target[idx, class_idx], [poly.astype(np.int32)], 1.0)
 
         return seg_target, seg_mask, shrunken_kernel
@@ -170,6 +170,15 @@ class FAST(_FAST, nn.Module):
                 m.weight.data.fill_(1.0)
                 m.bias.data.zero_()
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def forward(
         self,
         x: torch.Tensor,
@@ -197,7 +206,7 @@ class FAST(_FAST, nn.Module):
 
         if target is None or return_preds:
             # Disable for torch.compile compatibility
-            @torch.compiler.disable  # type: ignore[attr-defined]
+            @torch.compiler.disable
             def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
                 return [
                     dict(zip(self.class_names, preds))
@@ -229,7 +238,7 @@ class FAST(_FAST, nn.Module):
         Returns:
             A loss tensor
         """
-        targets = self.build_target(target, out_map.shape[1:], False)  # type: ignore[arg-type]
+        targets = self.build_target(target, out_map.shape[1:])  # type: ignore[arg-type]
 
         seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1])
         shrunken_kernel = torch.from_numpy(targets[2]).to(out_map.device)
@@ -294,7 +303,7 @@ def reparameterize(model: FAST | nn.Module) -> FAST:
 
     for module in model.modules():
         if hasattr(module, "reparameterize_layer"):
-            module.reparameterize_layer()
+            module.reparameterize_layer()  # type: ignore[operator]
 
     for name, child in model.named_children():
         if isinstance(child, nn.BatchNorm2d):
@@ -302,12 +311,12 @@ def reparameterize(model: FAST | nn.Module) -> FAST:
             if last_conv is None:
                 continue
             conv_w = last_conv.weight
-            conv_b = last_conv.bias if last_conv.bias is not None else torch.zeros_like(child.running_mean)
+            conv_b = last_conv.bias if last_conv.bias is not None else torch.zeros_like(child.running_mean)  # type: ignore[arg-type]
 
-            factor = child.weight / torch.sqrt(child.running_var + child.eps)
+            factor = child.weight / torch.sqrt(child.running_var + child.eps)  # type: ignore
             last_conv.weight = nn.Parameter(conv_w * factor.reshape([last_conv.out_channels, 1, 1, 1]))
-            last_conv.bias = nn.Parameter((conv_b - child.running_mean) * factor + child.bias)
-            model._modules[last_conv_name] = last_conv
+            last_conv.bias = nn.Parameter((conv_b - child.running_mean) * factor + child.bias)  # type: ignore[operator]
+            model._modules[last_conv_name] = last_conv  # type: ignore[index]
             model._modules[name] = nn.Identity()
             last_conv = None
         elif isinstance(child, nn.Conv2d):
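
The conv/BN folding in `reparameterize` uses the standard identities: with BN affine parameters γ, β, running statistics μ, σ² and factor f = γ/√(σ² + ε), the fused convolution has weight w·f and bias (b − μ)·f + β. A quick numerical check of the fold, independent of doctr (layer sizes are arbitrary):

    import torch
    import torch.nn as nn

    conv, bn = nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8)
    bn.eval()  # folding is only valid with frozen running statistics
    x = torch.randn(1, 3, 16, 16)

    factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused = nn.Conv2d(3, 8, 3, padding=1)
    fused.weight = nn.Parameter(conv.weight * factor.reshape(-1, 1, 1, 1))
    fused.bias = nn.Parameter((conv.bias - bn.running_mean) * factor + bn.bias)

    assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)
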
@@ -349,7 +358,7 @@ def _fast(
         _ignore_keys = (
             ignore_keys if kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]) else None
         )
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     return model
 
@@ -1,6 +1 @@
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    from .pytorch import *
-elif is_tf_available():
-    from .tensorflow import *  # type: ignore[assignment]
+from .pytorch import *
@@ -56,9 +56,8 @@ class LinkNetPostProcessor(DetectionPostProcessor):
             area = (rect[1][0] + 1) * (1 + rect[1][1])
             length = 2 * (rect[1][0] + rect[1][1]) + 2
         else:
-            poly = Polygon(points)
-            area = poly.area
-            length = poly.length
+            area = cv2.contourArea(points)
+            length = cv2.arcLength(points, closed=True)
         distance = area * self.unclip_ratio / length  # compute distance to expand polygon
         offset = pyclipper.PyclipperOffset()
         offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
@@ -157,14 +156,12 @@ class _LinkNet(BaseModel):
         self,
         target: list[dict[str, np.ndarray]],
         output_shape: tuple[int, int, int],
-        channels_last: bool = True,
     ) -> tuple[np.ndarray, np.ndarray]:
         """Build the target, and it's mask to be used from loss computation.
 
         Args:
             target: target coming from dataset
             output_shape: shape of the output of the model without batch_size
-            channels_last: whether channels are last or not
 
         Returns:
             the new formatted target and the mask
@@ -176,10 +173,8 @@ class _LinkNet(BaseModel):
 
         h: int
         w: int
-        if channels_last:
-            h, w, num_classes = output_shape
-        else:
-            num_classes, h, w = output_shape
+
+        num_classes, h, w = output_shape
         target_shape = (len(target), num_classes, h, w)
 
         seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
@@ -238,11 +233,6 @@ class _LinkNet(BaseModel):
                 if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
                     seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
                     continue
-                cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0)  # type: ignore[call-overload]
-
-        # Don't forget to switch back to channel last if Tensorflow is used
-        if channels_last:
-            seg_target = seg_target.transpose((0, 2, 3, 1))
-            seg_mask = seg_mask.transpose((0, 2, 3, 1))
+                cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0)
 
         return seg_target, seg_mask
@@ -160,6 +160,15 @@ class LinkNet(nn.Module, _LinkNet):
                 m.weight.data.fill_(1.0)
                 m.bias.data.zero_()
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def forward(
         self,
         x: torch.Tensor,
@@ -184,7 +193,7 @@ class LinkNet(nn.Module, _LinkNet):
 
         if target is None or return_preds:
             # Disable for torch.compile compatibility
-            @torch.compiler.disable  # type: ignore[attr-defined]
+            @torch.compiler.disable
             def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
                 return [
                     dict(zip(self.class_names, preds))
@@ -221,7 +230,7 @@ class LinkNet(nn.Module, _LinkNet):
         Returns:
             A loss tensor
         """
-        _target, _mask = self.build_target(target, out_map.shape[1:], False)  # type: ignore[arg-type]
+        _target, _mask = self.build_target(target, out_map.shape[1:])  # type: ignore[arg-type]
 
         seg_target, seg_mask = torch.from_numpy(_target).to(dtype=out_map.dtype), torch.from_numpy(_mask)
         seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)
@@ -282,7 +291,7 @@ def _linknet(
         _ignore_keys = (
             ignore_keys if kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]) else None
         )
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     return model
 
@@ -1,6 +1 @@
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    from .pytorch import *
-elif is_tf_available():
-    from .tensorflow import *  # type: ignore[assignment]
+from .pytorch import *
@@ -36,7 +36,7 @@ class DetectionPredictor(nn.Module):
     @torch.inference_mode()
     def forward(
         self,
-        pages: list[np.ndarray | torch.Tensor],
+        pages: list[np.ndarray],
         return_maps: bool = False,
         **kwargs: Any,
     ) -> list[dict[str, np.ndarray]] | tuple[list[dict[str, np.ndarray]], list[np.ndarray]]:
@@ -5,7 +5,7 @@
 
 from typing import Any
 
-from doctr.file_utils import is_tf_available, is_torch_available
+from doctr.models.utils import _CompiledModule
 
 from .. import detection
 from ..detection.fast import reparameterize
@@ -16,30 +16,17 @@ __all__ = ["detection_predictor"]
 
 ARCHS: list[str]
 
-
-if is_tf_available():
-    ARCHS = [
-        "db_resnet50",
-        "db_mobilenet_v3_large",
-        "linknet_resnet18",
-        "linknet_resnet34",
-        "linknet_resnet50",
-        "fast_tiny",
-        "fast_small",
-        "fast_base",
-    ]
-elif is_torch_available():
-    ARCHS = [
-        "db_resnet34",
-        "db_resnet50",
-        "db_mobilenet_v3_large",
-        "linknet_resnet18",
-        "linknet_resnet34",
-        "linknet_resnet50",
-        "fast_tiny",
-        "fast_small",
-        "fast_base",
-    ]
+ARCHS = [
+    "db_resnet34",
+    "db_resnet50",
+    "db_mobilenet_v3_large",
+    "linknet_resnet18",
+    "linknet_resnet34",
+    "linknet_resnet50",
+    "fast_tiny",
+    "fast_small",
+    "fast_base",
+]
 
 
 def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True, **kwargs: Any) -> DetectionPredictor:
@@ -56,12 +43,8 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True,
         if isinstance(_model, detection.FAST):
             _model = reparameterize(_model)
     else:
-        allowed_archs = [detection.DBNet, detection.LinkNet, detection.FAST]
-        if is_torch_available():
-            # Adding the type for torch compiled models to the allowed architectures
-            from doctr.models.utils import _CompiledModule
-
-            allowed_archs.append(_CompiledModule)
+        # Adding the type for torch compiled models to the allowed architectures
+        allowed_archs = [detection.DBNet, detection.LinkNet, detection.FAST, _CompiledModule]
 
         if not isinstance(arch, tuple(allowed_archs)):
             raise ValueError(f"unknown architecture: {type(arch)}")
@@ -76,7 +59,7 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True,
     kwargs["std"] = kwargs.get("std", _model.cfg["std"])
     kwargs["batch_size"] = kwargs.get("batch_size", 2)
     predictor = DetectionPredictor(
-        PreProcessor(_model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:], **kwargs),
+        PreProcessor(_model.cfg["input_shape"][1:], **kwargs),
         _model,
     )
     return predictor
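
With the TensorFlow branch gone, `detection_predictor` always builds a PyTorch pipeline and the preprocessor always reads the channels-first `input_shape`. Typical usage is unchanged; a minimal sketch with a synthetic page (pretrained weights are downloaded on first use):

    import numpy as np
    from doctr.models import detection_predictor

    predictor = detection_predictor("db_resnet50", pretrained=True)
    page = (np.random.rand(1024, 768, 3) * 255).astype(np.uint8)  # stand-in H x W x C page
    boxes = predictor([page])
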
@@ -13,6 +13,7 @@ import textwrap
 from pathlib import Path
 from typing import Any
 
+import torch
 from huggingface_hub import (
     HfApi,
     Repository,
@@ -23,10 +24,6 @@ from huggingface_hub import (
 )
 
 from doctr import models
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    import torch
 
 __all__ = ["login_to_hub", "push_to_hf_hub", "from_hub", "_save_model_and_config_for_hf_hub"]
 
@@ -61,19 +58,14 @@ def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task
     """Save model and config to disk for pushing to huggingface hub
 
     Args:
-        model: TF or PyTorch model to be saved
+        model: PyTorch model to be saved
         save_dir: directory to save model and config
         arch: architecture name
         task: task name
     """
     save_directory = Path(save_dir)
-
-    if is_torch_available():
-        weights_path = save_directory / "pytorch_model.bin"
-        torch.save(model.state_dict(), weights_path)
-    elif is_tf_available():
-        weights_path = save_directory / "tf_model.weights.h5"
-        model.save_weights(str(weights_path))
+    weights_path = save_directory / "pytorch_model.bin"
+    torch.save(model.state_dict(), weights_path)
 
     config_path = save_directory / "config.json"
 
@@ -96,7 +88,7 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None:  #
     >>> push_to_hf_hub(model, 'my-model', 'recognition', arch='crnn_mobilenet_v3_small')
 
     Args:
-        model: TF or PyTorch model to be saved
+        model: PyTorch model to be saved
         model_name: name of the model which is also the repository name
         task: task name
         **kwargs: keyword arguments for push_to_hf_hub
@@ -120,7 +112,7 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None:  #
     <img src="https://doctr-static.mindee.com/models?id=v0.3.1/Logo_doctr.gif&src=0" width="60%">
     </p>
 
-    **Optical Character Recognition made seamless & accessible to anyone, powered by TensorFlow 2 & PyTorch**
+    **Optical Character Recognition made seamless & accessible to anyone, powered by PyTorch**
 
     ## Task: {task}
 
@@ -214,13 +206,8 @@ def from_hub(repo_id: str, **kwargs: Any):
 
     # update model cfg
     model.cfg = cfg
-
-    # Load checkpoint
-    if is_torch_available():
-        state_dict = torch.load(hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs), map_location="cpu")
-        model.load_state_dict(state_dict)
-    else:  # tf
-        weights = hf_hub_download(repo_id, filename="tf_model.weights.h5", **kwargs)
-        model.load_weights(weights)
+    # load the weights
+    weights = hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs)
+    model.from_pretrained(weights)
 
     return model
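
`from_hub` now always fetches `pytorch_model.bin` and loads it through the model's new `from_pretrained` method. A usage sketch with an illustrative repository id:

    from doctr.models import from_hub

    model = from_hub("my-org/my-doctr-model")  # hypothetical repo id
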
@@ -1,6 +1 @@
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    from .pytorch import *
-elif is_tf_available():
-    from .tensorflow import *  # type: ignore[assignment]
+from .pytorch import *
@@ -68,14 +68,14 @@ class KIEPredictor(nn.Module, _KIEPredictor):
     @torch.inference_mode()
     def forward(
         self,
-        pages: list[np.ndarray | torch.Tensor],
+        pages: list[np.ndarray],
        **kwargs: Any,
     ) -> Document:
         # Dimension check
         if any(page.ndim != 3 for page in pages):
             raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
 
-        origin_page_shapes = [page.shape[:2] if isinstance(page, np.ndarray) else page.shape[-2:] for page in pages]
+        origin_page_shapes = [page.shape[:2] for page in pages]
 
         # Localize text elements
         loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)
@@ -113,9 +113,6 @@ class KIEPredictor(nn.Module, _KIEPredictor):
             dict_loc_preds[class_name] = _loc_preds
             objectness_scores[class_name] = _scores
 
-        # Check whether crop mode should be switched to channels first
-        channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray)
-
         # Apply hooks to loc_preds if any
         for hook in self.hooks:
             dict_loc_preds = hook(dict_loc_preds)
@@ -126,7 +123,6 @@ class KIEPredictor(nn.Module, _KIEPredictor):
             crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
                 pages,
                 dict_loc_preds[class_name],
-                channels_last=channels_last,
                 assume_straight_pages=self.assume_straight_pages,
                 assume_horizontal=self._page_orientation_disabled,
             )
@@ -173,7 +169,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
             boxes_per_page,
             objectness_scores_per_page,
             text_preds_per_page,
-            origin_page_shapes,  # type: ignore[arg-type]
+            origin_page_shapes,
             crop_orientations_per_page,
             orientations,
             languages_dict,
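
Since tensor input was dropped along with the channels-last bookkeeping, pages must now be NumPy arrays in H x W x C layout. A minimal call sketch with a synthetic page:

    import numpy as np
    from doctr.models import kie_predictor

    predictor = kie_predictor(pretrained=True)
    page = (np.random.rand(1024, 768, 3) * 255).astype(np.uint8)  # stand-in page image
    document = predictor([page])
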
@@ -1,6 +1 @@
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    from .pytorch import *
-elif is_tf_available():
-    from .tensorflow import *
+from .pytorch import *
@@ -8,7 +8,55 @@ import numpy as np
 import torch
 import torch.nn as nn
 
-__all__ = ["FASTConvLayer"]
+__all__ = ["FASTConvLayer", "DropPath", "AdaptiveAvgPool2d"]
+
+
+class DropPath(nn.Module):
+    """
+    DropPath (Drop Connect) layer. This is a stochastic version of the identity layer.
+    """
+
+    # Borrowed from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
+    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+        self.scale_by_keep = scale_by_keep
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.drop_prob == 0.0 or not self.training:
+            return x
+        keep_prob = 1 - self.drop_prob
+        shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with different dimensions
+        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+        if keep_prob > 0.0 and self.scale_by_keep:
+            random_tensor.div_(keep_prob)
+        return x * random_tensor
+
+
+class AdaptiveAvgPool2d(nn.Module):
+    """
+    Custom AdaptiveAvgPool2d implementation which is ONNX and `torch.compile` compatible.
+
+    """
+
+    def __init__(self, output_size):
+        super().__init__()
+        self.output_size = output_size
+
+    def forward(self, x: torch.Tensor):
+        H_out, W_out = self.output_size
+        N, C, H, W = x.shape
+
+        out = torch.empty((N, C, H_out, W_out), device=x.device, dtype=x.dtype)
+        for oh in range(H_out):
+            start_h = (oh * H) // H_out
+            end_h = ((oh + 1) * H + H_out - 1) // H_out  # ceil((oh+1)*H / H_out)
+            for ow in range(W_out):
+                start_w = (ow * W) // W_out
+                end_w = ((ow + 1) * W + W_out - 1) // W_out  # ceil((ow+1)*W / W_out)
+                # average over the window
+                out[:, :, oh, ow] = x[:, :, start_h:end_h, start_w:end_w].mean(dim=(-2, -1))
+        return out
 
 
 class FASTConvLayer(nn.Module):
@@ -103,16 +151,16 @@ class FASTConvLayer(nn.Module):
         id_tensor = torch.from_numpy(kernel_value).to(identity.weight.device)
         self.id_tensor = self._pad_to_mxn_tensor(id_tensor)
         kernel = self.id_tensor
-        std = (identity.running_var + identity.eps).sqrt()  # type: ignore
+        std = (identity.running_var + identity.eps).sqrt()
         t = (identity.weight / std).reshape(-1, 1, 1, 1)
-        return kernel * t, identity.bias - identity.running_mean * identity.weight / std
+        return kernel * t, identity.bias - identity.running_mean * identity.weight / std  # type: ignore[operator]
 
     def _fuse_bn_tensor(self, conv: nn.Conv2d, bn: nn.BatchNorm2d) -> tuple[torch.Tensor, torch.Tensor]:
         kernel = conv.weight
         kernel = self._pad_to_mxn_tensor(kernel)
         std = (bn.running_var + bn.eps).sqrt()  # type: ignore
         t = (bn.weight / std).reshape(-1, 1, 1, 1)
-        return kernel * t, bn.bias - bn.running_mean * bn.weight / std
+        return kernel * t, bn.bias - bn.running_mean * bn.weight / std  # type: ignore[operator]
 
     def _get_equivalent_kernel_bias(self) -> tuple[torch.Tensor, torch.Tensor]:
         kernel_mxn, bias_mxn = self._fuse_bn_tensor(self.conv, self.bn)
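
Both new layers are self-contained; a quick sanity check (a sketch, not part of the package's test suite) that DropPath is the identity at inference time and that the custom pooling matches `torch.nn.AdaptiveAvgPool2d`:

    import torch
    from doctr.models.modules.layers import AdaptiveAvgPool2d, DropPath

    x = torch.randn(2, 16, 14, 14)
    assert DropPath(0.1).eval()(x).equal(x)  # stochastic depth disabled outside training

    # window bounds use the same floor/ceil rule as torch's adaptive pooling
    assert torch.allclose(AdaptiveAvgPool2d((7, 7))(x), torch.nn.AdaptiveAvgPool2d((7, 7))(x), atol=1e-6)
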
@@ -1,6 +1 @@
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    from .pytorch import *
-elif is_tf_available():
-    from .tensorflow import *
+from .pytorch import *
@@ -50,8 +50,8 @@ def scaled_dot_product_attention(
     scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
     if mask is not None:
         # NOTE: to ensure the ONNX compatibility, masked_fill works only with int equal condition
-        scores = scores.masked_fill(mask == 0, float("-inf"))  # type: ignore[attr-defined]
-    p_attn = torch.softmax(scores, dim=-1)  # type: ignore[call-overload]
+        scores = scores.masked_fill(mask == 0, float("-inf"))
+    p_attn = torch.softmax(scores, dim=-1)
     return torch.matmul(p_attn, value), p_attn
 
 
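For reference, the function computes softmax(Q·Kᵀ/√d)·V, filling positions where `mask == 0` before the softmax; the integer-mask comparison is what keeps it ONNX-exportable. A minimal call sketch with arbitrary shapes:

    import torch
    from doctr.models.modules.transformer.pytorch import scaled_dot_product_attention

    q = k = v = torch.randn(1, 4, 10, 32)             # (batch, heads, seq, head_dim)
    mask = torch.ones(1, 1, 10, 10, dtype=torch.int)  # 1 = attend, 0 = masked
    context, attn = scaled_dot_product_attention(q, k, v, mask)
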
@@ -1,6 +1 @@
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    from .pytorch import *
-elif is_tf_available():
-    from .tensorflow import *
+from .pytorch import *
@@ -1,6 +1 @@
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    from .pytorch import *
-elif is_tf_available():
-    from .tensorflow import *  # type: ignore[assignment]
+from .pytorch import *