python-doctr 0.11.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/__init__.py +0 -1
- doctr/datasets/__init__.py +1 -5
- doctr/datasets/coco_text.py +139 -0
- doctr/datasets/cord.py +2 -1
- doctr/datasets/datasets/__init__.py +1 -6
- doctr/datasets/datasets/pytorch.py +2 -2
- doctr/datasets/funsd.py +2 -2
- doctr/datasets/generator/__init__.py +1 -6
- doctr/datasets/ic03.py +1 -1
- doctr/datasets/ic13.py +2 -1
- doctr/datasets/iiit5k.py +4 -1
- doctr/datasets/imgur5k.py +9 -2
- doctr/datasets/ocr.py +1 -1
- doctr/datasets/recognition.py +1 -1
- doctr/datasets/svhn.py +1 -1
- doctr/datasets/svt.py +2 -2
- doctr/datasets/synthtext.py +15 -2
- doctr/datasets/utils.py +7 -6
- doctr/datasets/vocabs.py +1100 -54
- doctr/file_utils.py +2 -92
- doctr/io/elements.py +37 -3
- doctr/io/image/__init__.py +1 -7
- doctr/io/image/pytorch.py +1 -1
- doctr/models/_utils.py +4 -4
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/__init__.py +1 -6
- doctr/models/classification/magc_resnet/pytorch.py +3 -4
- doctr/models/classification/mobilenet/__init__.py +1 -6
- doctr/models/classification/mobilenet/pytorch.py +15 -1
- doctr/models/classification/predictor/__init__.py +1 -6
- doctr/models/classification/predictor/pytorch.py +2 -2
- doctr/models/classification/resnet/__init__.py +1 -6
- doctr/models/classification/resnet/pytorch.py +26 -3
- doctr/models/classification/textnet/__init__.py +1 -6
- doctr/models/classification/textnet/pytorch.py +11 -2
- doctr/models/classification/vgg/__init__.py +1 -6
- doctr/models/classification/vgg/pytorch.py +16 -1
- doctr/models/classification/vip/__init__.py +1 -0
- doctr/models/classification/vip/layers/__init__.py +1 -0
- doctr/models/classification/vip/layers/pytorch.py +615 -0
- doctr/models/classification/vip/pytorch.py +505 -0
- doctr/models/classification/vit/__init__.py +1 -6
- doctr/models/classification/vit/pytorch.py +12 -3
- doctr/models/classification/zoo.py +7 -8
- doctr/models/detection/_utils/__init__.py +1 -6
- doctr/models/detection/core.py +1 -1
- doctr/models/detection/differentiable_binarization/__init__.py +1 -6
- doctr/models/detection/differentiable_binarization/base.py +7 -16
- doctr/models/detection/differentiable_binarization/pytorch.py +13 -4
- doctr/models/detection/fast/__init__.py +1 -6
- doctr/models/detection/fast/base.py +6 -17
- doctr/models/detection/fast/pytorch.py +17 -8
- doctr/models/detection/linknet/__init__.py +1 -6
- doctr/models/detection/linknet/base.py +5 -15
- doctr/models/detection/linknet/pytorch.py +12 -3
- doctr/models/detection/predictor/__init__.py +1 -6
- doctr/models/detection/predictor/pytorch.py +1 -1
- doctr/models/detection/zoo.py +15 -32
- doctr/models/factory/hub.py +9 -22
- doctr/models/kie_predictor/__init__.py +1 -6
- doctr/models/kie_predictor/pytorch.py +3 -7
- doctr/models/modules/layers/__init__.py +1 -6
- doctr/models/modules/layers/pytorch.py +52 -4
- doctr/models/modules/transformer/__init__.py +1 -6
- doctr/models/modules/transformer/pytorch.py +2 -2
- doctr/models/modules/vision_transformer/__init__.py +1 -6
- doctr/models/predictor/__init__.py +1 -6
- doctr/models/predictor/base.py +3 -8
- doctr/models/predictor/pytorch.py +3 -6
- doctr/models/preprocessor/__init__.py +1 -6
- doctr/models/preprocessor/pytorch.py +27 -32
- doctr/models/recognition/__init__.py +1 -0
- doctr/models/recognition/crnn/__init__.py +1 -6
- doctr/models/recognition/crnn/pytorch.py +16 -7
- doctr/models/recognition/master/__init__.py +1 -6
- doctr/models/recognition/master/pytorch.py +15 -6
- doctr/models/recognition/parseq/__init__.py +1 -6
- doctr/models/recognition/parseq/pytorch.py +26 -8
- doctr/models/recognition/predictor/__init__.py +1 -6
- doctr/models/recognition/predictor/_utils.py +100 -47
- doctr/models/recognition/predictor/pytorch.py +4 -5
- doctr/models/recognition/sar/__init__.py +1 -6
- doctr/models/recognition/sar/pytorch.py +13 -4
- doctr/models/recognition/utils.py +56 -47
- doctr/models/recognition/viptr/__init__.py +1 -0
- doctr/models/recognition/viptr/pytorch.py +277 -0
- doctr/models/recognition/vitstr/__init__.py +1 -6
- doctr/models/recognition/vitstr/pytorch.py +13 -4
- doctr/models/recognition/zoo.py +13 -8
- doctr/models/utils/__init__.py +1 -6
- doctr/models/utils/pytorch.py +29 -19
- doctr/transforms/functional/__init__.py +1 -6
- doctr/transforms/functional/pytorch.py +4 -4
- doctr/transforms/modules/__init__.py +1 -7
- doctr/transforms/modules/base.py +26 -92
- doctr/transforms/modules/pytorch.py +28 -26
- doctr/utils/data.py +1 -1
- doctr/utils/geometry.py +7 -11
- doctr/utils/visualization.py +1 -1
- doctr/version.py +1 -1
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/METADATA +22 -63
- python_doctr-1.0.0.dist-info/RECORD +149 -0
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/WHEEL +1 -1
- doctr/datasets/datasets/tensorflow.py +0 -59
- doctr/datasets/generator/tensorflow.py +0 -58
- doctr/datasets/loader.py +0 -94
- doctr/io/image/tensorflow.py +0 -101
- doctr/models/classification/magc_resnet/tensorflow.py +0 -196
- doctr/models/classification/mobilenet/tensorflow.py +0 -433
- doctr/models/classification/predictor/tensorflow.py +0 -60
- doctr/models/classification/resnet/tensorflow.py +0 -397
- doctr/models/classification/textnet/tensorflow.py +0 -266
- doctr/models/classification/vgg/tensorflow.py +0 -116
- doctr/models/classification/vit/tensorflow.py +0 -192
- doctr/models/detection/_utils/tensorflow.py +0 -34
- doctr/models/detection/differentiable_binarization/tensorflow.py +0 -414
- doctr/models/detection/fast/tensorflow.py +0 -419
- doctr/models/detection/linknet/tensorflow.py +0 -369
- doctr/models/detection/predictor/tensorflow.py +0 -70
- doctr/models/kie_predictor/tensorflow.py +0 -187
- doctr/models/modules/layers/tensorflow.py +0 -171
- doctr/models/modules/transformer/tensorflow.py +0 -235
- doctr/models/modules/vision_transformer/tensorflow.py +0 -100
- doctr/models/predictor/tensorflow.py +0 -155
- doctr/models/preprocessor/tensorflow.py +0 -122
- doctr/models/recognition/crnn/tensorflow.py +0 -308
- doctr/models/recognition/master/tensorflow.py +0 -313
- doctr/models/recognition/parseq/tensorflow.py +0 -508
- doctr/models/recognition/predictor/tensorflow.py +0 -79
- doctr/models/recognition/sar/tensorflow.py +0 -416
- doctr/models/recognition/vitstr/tensorflow.py +0 -278
- doctr/models/utils/tensorflow.py +0 -182
- doctr/transforms/functional/tensorflow.py +0 -254
- doctr/transforms/modules/tensorflow.py +0 -562
- python_doctr-0.11.0.dist-info/RECORD +0 -173
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info/licenses}/LICENSE +0 -0
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.11.0.dist-info → python_doctr-1.0.0.dist-info}/zip-safe +0 -0
doctr/models/detection/differentiable_binarization/pytorch.py
CHANGED

@@ -179,6 +179,15 @@ class DBNet(_DBNet, nn.Module):
                 m.weight.data.fill_(1.0)
                 m.bias.data.zero_()
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def forward(
         self,
         x: torch.Tensor,
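Every detection model in 1.0.0 gains this `from_pretrained` method as a thin wrapper around `load_pretrained_params`. A minimal usage sketch (the checkpoint URL is a placeholder, not a published artifact):

```python
from doctr.models import detection

# Instantiate the architecture without the default pretrained weights
model = detection.db_resnet50(pretrained=False)

# Load a checkpoint from a local path or URL (placeholder URL for illustration)
model.from_pretrained("https://example.com/checkpoints/db_resnet50_custom.pt")
```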
@@ -206,7 +215,7 @@ class DBNet(_DBNet, nn.Module):
 
         if target is None or return_preds:
             # Disable for torch.compile compatibility
-            @torch.compiler.disable
+            @torch.compiler.disable
             def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
                 return [
                     dict(zip(self.class_names, preds))
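The postprocessing closure stays decorated with `torch.compiler.disable`, which makes Dynamo run it eagerly instead of tracing its data-dependent, OpenCV-backed logic. A standalone sketch of the pattern (names are illustrative):

```python
import torch

@torch.compiler.disable
def postprocess(prob_map: torch.Tensor) -> list[float]:
    # Data-dependent work that would break the compiled graph runs eagerly here
    return prob_map.flatten().tolist()

@torch.compile
def model_step(x: torch.Tensor) -> torch.Tensor:
    return torch.sigmoid(x @ x.T)

print(postprocess(model_step(torch.randn(4, 4))))
```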
@@ -252,7 +261,7 @@ class DBNet(_DBNet, nn.Module):
         prob_map = torch.sigmoid(out_map)
         thresh_map = torch.sigmoid(thresh_map)
 
-        targets = self.build_target(target, out_map.shape[1:], False)  # type: ignore[arg-type]
+        targets = self.build_target(target, out_map.shape[1:])  # type: ignore[arg-type]
 
         seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1])
         seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)
@@ -276,7 +285,7 @@ class DBNet(_DBNet, nn.Module):
             dice_map = torch.softmax(out_map, dim=1)
         else:
             # compute binary map instead
-            dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
+            dice_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
         # Class reduced
         inter = (seg_mask * dice_map * seg_target).sum((0, 2, 3))
         cardinality = (seg_mask * (dice_map + seg_target)).sum((0, 2, 3))
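This line is the approximate binary map of Differentiable Binarization, B̂ = 1 / (1 + e^(-k(P - T))) with amplification factor k = 50, which is numerically identical to `torch.sigmoid(50 * (P - T))`. A quick equivalence check:

```python
import torch

prob_map, thresh_map = torch.rand(1, 1, 8, 8), torch.rand(1, 1, 8, 8)

manual = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
sigmoid = torch.sigmoid(50.0 * (prob_map - thresh_map))
assert torch.allclose(manual, sigmoid)
```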
@@ -328,7 +337,7 @@ def _dbnet(
         _ignore_keys = (
             ignore_keys if kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]) else None
         )
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     return model
doctr/models/detection/fast/base.py
CHANGED

@@ -56,9 +56,8 @@ class FASTPostProcessor(DetectionPostProcessor):
             area = (rect[1][0] + 1) * (1 + rect[1][1])
             length = 2 * (rect[1][0] + rect[1][1]) + 2
         else:
-            poly = Polygon(points)
-            area = poly.area
-            length = poly.length
+            area = cv2.contourArea(points)
+            length = cv2.arcLength(points, closed=True)
         distance = area * self.unclip_ratio / length  # compute distance to expand polygon
         offset = pyclipper.PyclipperOffset()
         offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
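The box expansion follows DBNet's unclipping rule D = A·r / L (area A, perimeter L, unclip ratio r); the rewrite computes A and L with OpenCV instead of Shapely. A self-contained sketch of the expansion step, assuming `opencv-python` and `pyclipper` are installed:

```python
import cv2
import numpy as np
import pyclipper

points = np.array([[10, 10], [60, 12], [58, 40], [8, 38]], dtype=np.int32)
unclip_ratio = 1.0

area = cv2.contourArea(points)
length = cv2.arcLength(points, closed=True)
distance = area * unclip_ratio / length  # D = A * r / L

offset = pyclipper.PyclipperOffset()
offset.AddPath(points.tolist(), pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
expanded = np.array(offset.Execute(distance)[0])  # expanded polygon vertices
```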
@@ -154,14 +153,12 @@ class _FAST(BaseModel):
         self,
         target: list[dict[str, np.ndarray]],
         output_shape: tuple[int, int, int],
-        channels_last: bool = True,
     ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
         """Build the target, and it's mask to be used from loss computation.
 
         Args:
             target: target coming from dataset
             output_shape: shape of the output of the model without batch_size
-            channels_last: whether channels are last or not
 
         Returns:
             the new formatted target, mask and shrunken text kernel
@@ -173,10 +170,8 @@ class _FAST(BaseModel):
 
         h: int
         w: int
-        if channels_last:
-            h, w, num_classes = output_shape
-        else:
-            num_classes, h, w = output_shape
+
+        num_classes, h, w = output_shape
         target_shape = (len(target), num_classes, h, w)
 
         seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
@@ -236,14 +231,8 @@ class _FAST(BaseModel):
             if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
                 seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
                 continue
-            cv2.fillPoly(shrunken_kernel[idx, class_idx], [shrunken.astype(np.int32)], 1.0)
+            cv2.fillPoly(shrunken_kernel[idx, class_idx], [shrunken.astype(np.int32)], 1.0)
             # draw the original polygon on the segmentation target
-            cv2.fillPoly(seg_target[idx, class_idx], [poly.astype(np.int32)], 1.0)
-
-        # Don't forget to switch back to channel last if Tensorflow is used
-        if channels_last:
-            seg_target = seg_target.transpose((0, 2, 3, 1))
-            seg_mask = seg_mask.transpose((0, 2, 3, 1))
-            shrunken_kernel = shrunken_kernel.transpose((0, 2, 3, 1))
+            cv2.fillPoly(seg_target[idx, class_idx], [poly.astype(np.int32)], 1.0)
 
         return seg_target, seg_mask, shrunken_kernel
doctr/models/detection/fast/pytorch.py
CHANGED

@@ -170,6 +170,15 @@ class FAST(_FAST, nn.Module):
                 m.weight.data.fill_(1.0)
                 m.bias.data.zero_()
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def forward(
         self,
         x: torch.Tensor,
@@ -197,7 +206,7 @@ class FAST(_FAST, nn.Module):
 
         if target is None or return_preds:
             # Disable for torch.compile compatibility
-            @torch.compiler.disable
+            @torch.compiler.disable
             def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
                 return [
                     dict(zip(self.class_names, preds))
@@ -229,7 +238,7 @@ class FAST(_FAST, nn.Module):
         Returns:
             A loss tensor
         """
-        targets = self.build_target(target, out_map.shape[1:], False)  # type: ignore[arg-type]
+        targets = self.build_target(target, out_map.shape[1:])  # type: ignore[arg-type]
 
         seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1])
         shrunken_kernel = torch.from_numpy(targets[2]).to(out_map.device)
@@ -294,7 +303,7 @@ def reparameterize(model: FAST | nn.Module) -> FAST:
 
     for module in model.modules():
         if hasattr(module, "reparameterize_layer"):
-            module.reparameterize_layer()
+            module.reparameterize_layer()  # type: ignore[operator]
 
     for name, child in model.named_children():
         if isinstance(child, nn.BatchNorm2d):
@@ -302,12 +311,12 @@ def reparameterize(model: FAST | nn.Module) -> FAST:
             if last_conv is None:
                 continue
             conv_w = last_conv.weight
-            conv_b = last_conv.bias if last_conv.bias is not None else torch.zeros_like(child.running_mean)
+            conv_b = last_conv.bias if last_conv.bias is not None else torch.zeros_like(child.running_mean)  # type: ignore[arg-type]
 
-            factor = child.weight / torch.sqrt(child.running_var + child.eps)
+            factor = child.weight / torch.sqrt(child.running_var + child.eps)  # type: ignore
             last_conv.weight = nn.Parameter(conv_w * factor.reshape([last_conv.out_channels, 1, 1, 1]))
-            last_conv.bias = nn.Parameter((conv_b - child.running_mean) * factor + child.bias)
-            model._modules[last_conv_name] = last_conv
+            last_conv.bias = nn.Parameter((conv_b - child.running_mean) * factor + child.bias)  # type: ignore[operator]
+            model._modules[last_conv_name] = last_conv  # type: ignore[index]
             model._modules[name] = nn.Identity()
             last_conv = None
         elif isinstance(child, nn.Conv2d):
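`reparameterize` folds each BatchNorm into the convolution that precedes it: with scale γ, shift β, and running statistics μ and σ², the fused parameters are W' = W·γ/σ and b' = (b − μ)·γ/σ + β. A standalone numerical check of that identity (not doctr code):

```python
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, 3, padding=1, bias=False).eval()
bn = nn.BatchNorm2d(8).eval()
bn.running_mean.uniform_(-1.0, 1.0)  # give the BN non-trivial statistics
bn.running_var.uniform_(0.5, 2.0)

factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
fused = nn.Conv2d(3, 8, 3, padding=1).eval()
fused.weight.data = conv.weight * factor.reshape(-1, 1, 1, 1)
fused.bias.data = -bn.running_mean * factor + bn.bias  # conv bias is zero here

x = torch.randn(1, 3, 16, 16)
assert torch.allclose(bn(conv(x)), fused(x), atol=1e-5)
```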
@@ -349,7 +358,7 @@ def _fast(
         _ignore_keys = (
             ignore_keys if kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]) else None
         )
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     return model
doctr/models/detection/linknet/base.py
CHANGED

@@ -56,9 +56,8 @@ class LinkNetPostProcessor(DetectionPostProcessor):
             area = (rect[1][0] + 1) * (1 + rect[1][1])
             length = 2 * (rect[1][0] + rect[1][1]) + 2
         else:
-            poly = Polygon(points)
-            area = poly.area
-            length = poly.length
+            area = cv2.contourArea(points)
+            length = cv2.arcLength(points, closed=True)
         distance = area * self.unclip_ratio / length  # compute distance to expand polygon
         offset = pyclipper.PyclipperOffset()
         offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
@@ -157,14 +156,12 @@ class _LinkNet(BaseModel):
         self,
         target: list[dict[str, np.ndarray]],
         output_shape: tuple[int, int, int],
-        channels_last: bool = True,
     ) -> tuple[np.ndarray, np.ndarray]:
         """Build the target, and it's mask to be used from loss computation.
 
         Args:
             target: target coming from dataset
             output_shape: shape of the output of the model without batch_size
-            channels_last: whether channels are last or not
 
         Returns:
             the new formatted target and the mask
@@ -176,10 +173,8 @@ class _LinkNet(BaseModel):
 
         h: int
         w: int
-        if channels_last:
-            h, w, num_classes = output_shape
-        else:
-            num_classes, h, w = output_shape
+
+        num_classes, h, w = output_shape
         target_shape = (len(target), num_classes, h, w)
 
         seg_target: np.ndarray = np.zeros(target_shape, dtype=np.uint8)
@@ -238,11 +233,6 @@ class _LinkNet(BaseModel):
             if shrunken.shape[0] <= 2 or not Polygon(shrunken).is_valid:
                 seg_mask[idx, class_idx, box[1] : box[3] + 1, box[0] : box[2] + 1] = False
                 continue
-            cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0)
-
-        # Don't forget to switch back to channel last if Tensorflow is used
-        if channels_last:
-            seg_target = seg_target.transpose((0, 2, 3, 1))
-            seg_mask = seg_mask.transpose((0, 2, 3, 1))
+            cv2.fillPoly(seg_target[idx, class_idx], [shrunken.astype(np.int32)], 1.0)
 
         return seg_target, seg_mask
doctr/models/detection/linknet/pytorch.py
CHANGED

@@ -160,6 +160,15 @@ class LinkNet(nn.Module, _LinkNet):
                 m.weight.data.fill_(1.0)
                 m.bias.data.zero_()
 
+    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
+        """Load pretrained parameters onto the model
+
+        Args:
+            path_or_url: the path or URL to the model parameters (checkpoint)
+            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
+        """
+        load_pretrained_params(self, path_or_url, **kwargs)
+
     def forward(
         self,
         x: torch.Tensor,
@@ -184,7 +193,7 @@ class LinkNet(nn.Module, _LinkNet):
 
         if target is None or return_preds:
             # Disable for torch.compile compatibility
-            @torch.compiler.disable
+            @torch.compiler.disable
             def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:
                 return [
                     dict(zip(self.class_names, preds))
@@ -221,7 +230,7 @@ class LinkNet(nn.Module, _LinkNet):
         Returns:
             A loss tensor
         """
-        _target, _mask = self.build_target(target, out_map.shape[1:], False)  # type: ignore[arg-type]
+        _target, _mask = self.build_target(target, out_map.shape[1:])  # type: ignore[arg-type]
 
         seg_target, seg_mask = torch.from_numpy(_target).to(dtype=out_map.dtype), torch.from_numpy(_mask)
         seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)
@@ -282,7 +291,7 @@ def _linknet(
         _ignore_keys = (
             ignore_keys if kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]) else None
         )
-        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
+        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)
 
     return model
doctr/models/detection/predictor/pytorch.py
CHANGED

@@ -36,7 +36,7 @@ class DetectionPredictor(nn.Module):
     @torch.inference_mode()
     def forward(
         self,
-        pages: list[np.ndarray | torch.Tensor],
+        pages: list[np.ndarray],
         return_maps: bool = False,
         **kwargs: Any,
     ) -> list[dict[str, np.ndarray]] | tuple[list[dict[str, np.ndarray]], list[np.ndarray]]:
doctr/models/detection/zoo.py
CHANGED

@@ -5,7 +5,7 @@
 
 from typing import Any
 
-from doctr.file_utils import is_tf_available, is_torch_available
+from doctr.models.utils import _CompiledModule
 
 from .. import detection
 from ..detection.fast import reparameterize
@@ -16,30 +16,17 @@ __all__ = ["detection_predictor"]
 
 ARCHS: list[str]
 
-
-if is_tf_available():
-    ARCHS = [
-        "db_resnet50",
-        "db_mobilenet_v3_large",
-        "linknet_resnet18",
-        "linknet_resnet34",
-        "linknet_resnet50",
-        "fast_tiny",
-        "fast_small",
-        "fast_base",
-    ]
-elif is_torch_available():
-    ARCHS = [
-        "db_resnet34",
-        "db_resnet50",
-        "db_mobilenet_v3_large",
-        "linknet_resnet18",
-        "linknet_resnet34",
-        "linknet_resnet50",
-        "fast_tiny",
-        "fast_small",
-        "fast_base",
-    ]
+ARCHS = [
+    "db_resnet34",
+    "db_resnet50",
+    "db_mobilenet_v3_large",
+    "linknet_resnet18",
+    "linknet_resnet34",
+    "linknet_resnet50",
+    "fast_tiny",
+    "fast_small",
+    "fast_base",
+]
 
 
 def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True, **kwargs: Any) -> DetectionPredictor:
@@ -56,12 +43,8 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True,
         if isinstance(_model, detection.FAST):
             _model = reparameterize(_model)
     else:
-        allowed_archs = [detection.DBNet, detection.LinkNet, detection.FAST]
-        if is_torch_available():
-            # Adding the type for torch compiled models to the allowed architectures
-            from doctr.models.utils import _CompiledModule
-
-            allowed_archs.append(_CompiledModule)
+        # Adding the type for torch compiled models to the allowed architectures
+        allowed_archs = [detection.DBNet, detection.LinkNet, detection.FAST, _CompiledModule]
 
         if not isinstance(arch, tuple(allowed_archs)):
             raise ValueError(f"unknown architecture: {type(arch)}")
@@ -76,7 +59,7 @@ def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True,
     kwargs["std"] = kwargs.get("std", _model.cfg["std"])
     kwargs["batch_size"] = kwargs.get("batch_size", 2)
     predictor = DetectionPredictor(
-        PreProcessor(_model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:], **kwargs),
+        PreProcessor(_model.cfg["input_shape"][1:], **kwargs),
         _model,
     )
     return predictor
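With the TensorFlow branch gone, the preprocessor is always built from the channels-first `input_shape[1:]`. Building a detection predictor works as before; a minimal sketch:

```python
import numpy as np
from doctr.models import detection_predictor

predictor = detection_predictor("fast_base", pretrained=True)
# Pages are plain (H, W, C) numpy arrays
pages = [np.zeros((1024, 1024, 3), dtype=np.uint8)]
out = predictor(pages)
```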
doctr/models/factory/hub.py
CHANGED

@@ -13,6 +13,7 @@ import textwrap
 from pathlib import Path
 from typing import Any
 
+import torch
 from huggingface_hub import (
     HfApi,
     Repository,
@@ -23,10 +24,6 @@ from huggingface_hub import (
 )
 
 from doctr import models
-from doctr.file_utils import is_tf_available, is_torch_available
-
-if is_torch_available():
-    import torch
 
 
 __all__ = ["login_to_hub", "push_to_hf_hub", "from_hub", "_save_model_and_config_for_hf_hub"]
@@ -61,19 +58,14 @@ def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task
     """Save model and config to disk for pushing to huggingface hub
 
     Args:
-        model: TF or PyTorch model to be saved
+        model: PyTorch model to be saved
         save_dir: directory to save model and config
         arch: architecture name
         task: task name
     """
     save_directory = Path(save_dir)
-
-    if is_torch_available():
-        weights_path = save_directory / "pytorch_model.bin"
-        torch.save(model.state_dict(), weights_path)
-    elif is_tf_available():
-        weights_path = save_directory / "tf_model.weights.h5"
-        model.save_weights(str(weights_path))
+    weights_path = save_directory / "pytorch_model.bin"
+    torch.save(model.state_dict(), weights_path)
 
     config_path = save_directory / "config.json"
@@ -96,7 +88,7 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None:  #
     >>> push_to_hf_hub(model, 'my-model', 'recognition', arch='crnn_mobilenet_v3_small')
 
     Args:
-        model: TF or PyTorch model to be saved
+        model: PyTorch model to be saved
         model_name: name of the model which is also the repository name
         task: task name
         **kwargs: keyword arguments for push_to_hf_hub
@@ -120,7 +112,7 @@ def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None:  #
     <img src="https://doctr-static.mindee.com/models?id=v0.3.1/Logo_doctr.gif&src=0" width="60%">
     </p>
 
-    **Optical Character Recognition made seamless & accessible to anyone, powered by TensorFlow 2 & PyTorch**
+    **Optical Character Recognition made seamless & accessible to anyone, powered by PyTorch**
 
     ## Task: {task}
@@ -214,13 +206,8 @@ def from_hub(repo_id: str, **kwargs: Any):
 
     # update model cfg
     model.cfg = cfg
-
-    # load the weights
-    if is_torch_available():
-        state_dict = torch.load(hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs), map_location="cpu")
-        model.load_state_dict(state_dict)
-    else:  # tf
-        weights = hf_hub_download(repo_id, filename="tf_model.weights.h5", **kwargs)
-        model.load_weights(weights)
+    # load the weights
+    weights = hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs)
+    model.from_pretrained(weights)
 
     return model
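`from_hub` now always downloads `pytorch_model.bin` and routes it through the new `from_pretrained` helper. A usage sketch (the repository id is a placeholder):

```python
from doctr.models.factory.hub import from_hub

model = from_hub("my-org/my-doctr-model")  # placeholder repo id
```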
doctr/models/kie_predictor/pytorch.py
CHANGED

@@ -68,14 +68,14 @@ class KIEPredictor(nn.Module, _KIEPredictor):
     @torch.inference_mode()
     def forward(
         self,
-        pages: list[np.ndarray | torch.Tensor],
+        pages: list[np.ndarray],
         **kwargs: Any,
     ) -> Document:
         # Dimension check
         if any(page.ndim != 3 for page in pages):
             raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.")
 
-        origin_page_shapes = [page.shape[:2] if isinstance(page, np.ndarray) else page.shape[-2:] for page in pages]
+        origin_page_shapes = [page.shape[:2] for page in pages]
 
         # Localize text elements
         loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)
@@ -113,9 +113,6 @@ class KIEPredictor(nn.Module, _KIEPredictor):
             dict_loc_preds[class_name] = _loc_preds
             objectness_scores[class_name] = _scores
 
-        # Check whether crop mode should be switched to channels first
-        channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray)
-
         # Apply hooks to loc_preds if any
         for hook in self.hooks:
             dict_loc_preds = hook(dict_loc_preds)
@@ -126,7 +123,6 @@ class KIEPredictor(nn.Module, _KIEPredictor):
             crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
                 pages,
                 dict_loc_preds[class_name],
-                channels_last=channels_last,
                 assume_straight_pages=self.assume_straight_pages,
                 assume_horizontal=self._page_orientation_disabled,
             )
@@ -173,7 +169,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
             boxes_per_page,
             objectness_scores_per_page,
             text_preds_per_page,
-            origin_page_shapes,
+            origin_page_shapes,
             crop_orientations_per_page,
             orientations,
             languages_dict,
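With tensors no longer accepted, pages flow through the KIE predictor as plain numpy arrays. A minimal sketch using the standard doctr entry points (the image path is a placeholder):

```python
from doctr.io import DocumentFile
from doctr.models import kie_predictor

predictor = kie_predictor(det_arch="db_resnet50", reco_arch="crnn_vgg16_bn", pretrained=True)
pages = DocumentFile.from_images("page.png")  # list of (H, W, C) numpy arrays
result = predictor(pages)
```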
doctr/models/modules/layers/pytorch.py
CHANGED

@@ -8,7 +8,55 @@ import numpy as np
 import torch
 import torch.nn as nn
 
-__all__ = ["FASTConvLayer"]
+__all__ = ["FASTConvLayer", "DropPath", "AdaptiveAvgPool2d"]
+
+
+class DropPath(nn.Module):
+    """
+    DropPath (Drop Connect) layer. This is a stochastic version of the identity layer.
+    """
+
+    # Borrowed from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/drop.py
+    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+        self.scale_by_keep = scale_by_keep
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.drop_prob == 0.0 or not self.training:
+            return x
+        keep_prob = 1 - self.drop_prob
+        shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with different dimensions
+        random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+        if keep_prob > 0.0 and self.scale_by_keep:
+            random_tensor.div_(keep_prob)
+        return x * random_tensor
+
+
+class AdaptiveAvgPool2d(nn.Module):
+    """
+    Custom AdaptiveAvgPool2d implementation which is ONNX and `torch.compile` compatible.
+
+    """
+
+    def __init__(self, output_size):
+        super().__init__()
+        self.output_size = output_size
+
+    def forward(self, x: torch.Tensor):
+        H_out, W_out = self.output_size
+        N, C, H, W = x.shape
+
+        out = torch.empty((N, C, H_out, W_out), device=x.device, dtype=x.dtype)
+        for oh in range(H_out):
+            start_h = (oh * H) // H_out
+            end_h = ((oh + 1) * H + H_out - 1) // H_out  # ceil((oh+1)*H / H_out)
+            for ow in range(W_out):
+                start_w = (ow * W) // W_out
+                end_w = ((ow + 1) * W + W_out - 1) // W_out  # ceil((ow+1)*W / W_out)
+                # average over the window
+                out[:, :, oh, ow] = x[:, :, start_h:end_h, start_w:end_w].mean(dim=(-2, -1))
+        return out
 
 
 class FASTConvLayer(nn.Module):
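The loop reproduces PyTorch's adaptive pooling window arithmetic (start = ⌊i·H/H_out⌋, end = ⌈(i+1)·H/H_out⌉), so it should agree with the built-in layer; a quick check under that assumption:

```python
import torch
from doctr.models.modules.layers.pytorch import AdaptiveAvgPool2d

x = torch.randn(2, 3, 17, 23)
custom = AdaptiveAvgPool2d((5, 7))
reference = torch.nn.AdaptiveAvgPool2d((5, 7))
assert torch.allclose(custom(x), reference(x), atol=1e-6)
```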
@@ -103,16 +151,16 @@ class FASTConvLayer(nn.Module):
             id_tensor = torch.from_numpy(kernel_value).to(identity.weight.device)
             self.id_tensor = self._pad_to_mxn_tensor(id_tensor)
         kernel = self.id_tensor
-        std = (identity.running_var + identity.eps).sqrt()
+        std = (identity.running_var + identity.eps).sqrt()
         t = (identity.weight / std).reshape(-1, 1, 1, 1)
-        return kernel * t, identity.bias - identity.running_mean * identity.weight / std
+        return kernel * t, identity.bias - identity.running_mean * identity.weight / std  # type: ignore[operator]
 
     def _fuse_bn_tensor(self, conv: nn.Conv2d, bn: nn.BatchNorm2d) -> tuple[torch.Tensor, torch.Tensor]:
         kernel = conv.weight
         kernel = self._pad_to_mxn_tensor(kernel)
         std = (bn.running_var + bn.eps).sqrt()  # type: ignore
         t = (bn.weight / std).reshape(-1, 1, 1, 1)
-        return kernel * t, bn.bias - bn.running_mean * bn.weight / std
+        return kernel * t, bn.bias - bn.running_mean * bn.weight / std  # type: ignore[operator]
 
     def _get_equivalent_kernel_bias(self) -> tuple[torch.Tensor, torch.Tensor]:
         kernel_mxn, bias_mxn = self._fuse_bn_tensor(self.conv, self.bn)
doctr/models/modules/transformer/pytorch.py
CHANGED

@@ -50,8 +50,8 @@ def scaled_dot_product_attention(
     scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
     if mask is not None:
         # NOTE: to ensure the ONNX compatibility, masked_fill works only with int equal condition
-        scores = scores.masked_fill(mask == 0, float("-inf"))
-    p_attn = torch.softmax(scores, dim=-1)
+        scores = scores.masked_fill(mask == 0, float("-inf"))
+    p_attn = torch.softmax(scores, dim=-1)
     return torch.matmul(p_attn, value), p_attn
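This helper implements standard scaled dot-product attention, softmax(QKᵀ/√d)·V, with masked positions pushed to −inf before the softmax; the integer-equality mask test is kept for ONNX export. A self-contained sketch of the same computation:

```python
import math
import torch

def attention(query, key, value, mask=None):
    # softmax(Q @ K^T / sqrt(d)) @ V, masked positions forced to -inf
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    p_attn = torch.softmax(scores, dim=-1)
    return torch.matmul(p_attn, value), p_attn

q = k = v = torch.randn(1, 4, 8)
causal_mask = torch.tril(torch.ones(4, 4, dtype=torch.int64))
out, weights = attention(q, k, v, causal_mask)
```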