python-doctr 0.8.1__py3-none-any.whl → 0.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- doctr/__init__.py +1 -1
- doctr/contrib/__init__.py +0 -0
- doctr/contrib/artefacts.py +131 -0
- doctr/contrib/base.py +105 -0
- doctr/datasets/cord.py +10 -1
- doctr/datasets/datasets/pytorch.py +2 -2
- doctr/datasets/funsd.py +11 -1
- doctr/datasets/generator/base.py +6 -5
- doctr/datasets/ic03.py +11 -1
- doctr/datasets/ic13.py +10 -1
- doctr/datasets/iiit5k.py +26 -16
- doctr/datasets/imgur5k.py +11 -2
- doctr/datasets/loader.py +1 -6
- doctr/datasets/sroie.py +11 -1
- doctr/datasets/svhn.py +11 -1
- doctr/datasets/svt.py +11 -1
- doctr/datasets/synthtext.py +11 -1
- doctr/datasets/utils.py +9 -3
- doctr/datasets/vocabs.py +15 -4
- doctr/datasets/wildreceipt.py +12 -1
- doctr/file_utils.py +45 -12
- doctr/io/elements.py +52 -10
- doctr/io/html.py +2 -2
- doctr/io/image/pytorch.py +6 -8
- doctr/io/image/tensorflow.py +1 -1
- doctr/io/pdf.py +5 -2
- doctr/io/reader.py +6 -0
- doctr/models/__init__.py +0 -1
- doctr/models/_utils.py +57 -20
- doctr/models/builder.py +73 -15
- doctr/models/classification/magc_resnet/tensorflow.py +13 -6
- doctr/models/classification/mobilenet/pytorch.py +47 -9
- doctr/models/classification/mobilenet/tensorflow.py +51 -14
- doctr/models/classification/predictor/pytorch.py +28 -17
- doctr/models/classification/predictor/tensorflow.py +26 -16
- doctr/models/classification/resnet/tensorflow.py +21 -8
- doctr/models/classification/textnet/pytorch.py +3 -3
- doctr/models/classification/textnet/tensorflow.py +11 -5
- doctr/models/classification/vgg/tensorflow.py +9 -3
- doctr/models/classification/vit/tensorflow.py +10 -4
- doctr/models/classification/zoo.py +55 -19
- doctr/models/detection/_utils/__init__.py +1 -0
- doctr/models/detection/_utils/base.py +66 -0
- doctr/models/detection/differentiable_binarization/base.py +4 -3
- doctr/models/detection/differentiable_binarization/pytorch.py +2 -2
- doctr/models/detection/differentiable_binarization/tensorflow.py +34 -12
- doctr/models/detection/fast/base.py +6 -5
- doctr/models/detection/fast/pytorch.py +4 -4
- doctr/models/detection/fast/tensorflow.py +15 -12
- doctr/models/detection/linknet/base.py +4 -3
- doctr/models/detection/linknet/tensorflow.py +23 -11
- doctr/models/detection/predictor/pytorch.py +15 -1
- doctr/models/detection/predictor/tensorflow.py +17 -3
- doctr/models/detection/zoo.py +7 -2
- doctr/models/factory/hub.py +8 -18
- doctr/models/kie_predictor/base.py +13 -3
- doctr/models/kie_predictor/pytorch.py +45 -20
- doctr/models/kie_predictor/tensorflow.py +44 -17
- doctr/models/modules/layers/pytorch.py +2 -3
- doctr/models/modules/layers/tensorflow.py +6 -8
- doctr/models/modules/transformer/pytorch.py +2 -2
- doctr/models/modules/transformer/tensorflow.py +0 -2
- doctr/models/modules/vision_transformer/pytorch.py +1 -1
- doctr/models/modules/vision_transformer/tensorflow.py +1 -1
- doctr/models/predictor/base.py +97 -58
- doctr/models/predictor/pytorch.py +35 -20
- doctr/models/predictor/tensorflow.py +35 -18
- doctr/models/preprocessor/pytorch.py +4 -4
- doctr/models/preprocessor/tensorflow.py +3 -2
- doctr/models/recognition/crnn/tensorflow.py +8 -6
- doctr/models/recognition/master/pytorch.py +2 -2
- doctr/models/recognition/master/tensorflow.py +9 -4
- doctr/models/recognition/parseq/pytorch.py +4 -3
- doctr/models/recognition/parseq/tensorflow.py +14 -11
- doctr/models/recognition/sar/pytorch.py +7 -6
- doctr/models/recognition/sar/tensorflow.py +10 -12
- doctr/models/recognition/vitstr/pytorch.py +1 -1
- doctr/models/recognition/vitstr/tensorflow.py +9 -4
- doctr/models/recognition/zoo.py +1 -1
- doctr/models/utils/pytorch.py +1 -1
- doctr/models/utils/tensorflow.py +15 -15
- doctr/models/zoo.py +2 -2
- doctr/py.typed +0 -0
- doctr/transforms/functional/base.py +1 -1
- doctr/transforms/functional/pytorch.py +5 -5
- doctr/transforms/modules/base.py +37 -15
- doctr/transforms/modules/pytorch.py +73 -14
- doctr/transforms/modules/tensorflow.py +78 -19
- doctr/utils/fonts.py +7 -5
- doctr/utils/geometry.py +141 -31
- doctr/utils/metrics.py +34 -175
- doctr/utils/reconstitution.py +212 -0
- doctr/utils/visualization.py +5 -118
- doctr/version.py +1 -1
- {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/METADATA +85 -81
- python_doctr-0.10.0.dist-info/RECORD +173 -0
- {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/WHEEL +1 -1
- doctr/models/artefacts/__init__.py +0 -2
- doctr/models/artefacts/barcode.py +0 -74
- doctr/models/artefacts/face.py +0 -63
- doctr/models/obj_detection/__init__.py +0 -1
- doctr/models/obj_detection/faster_rcnn/__init__.py +0 -4
- doctr/models/obj_detection/faster_rcnn/pytorch.py +0 -81
- python_doctr-0.8.1.dist-info/RECORD +0 -173
- {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/LICENSE +0 -0
- {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/top_level.txt +0 -0
- {python_doctr-0.8.1.dist-info → python_doctr-0.10.0.dist-info}/zip-safe +0 -0
--- a/doctr/models/predictor/pytorch.py
+++ b/doctr/models/predictor/pytorch.py
@@ -10,10 +10,10 @@ import torch
 from torch import nn
 
 from doctr.io.elements import Document
-from doctr.models._utils import
+from doctr.models._utils import get_language
 from doctr.models.detection.predictor import DetectionPredictor
 from doctr.models.recognition.predictor import RecognitionPredictor
-from doctr.utils.geometry import
+from doctr.utils.geometry import detach_scores
 
 from .base import _OCRPredictor
 
@@ -55,7 +55,13 @@ class OCRPredictor(nn.Module, _OCRPredictor):
         self.det_predictor = det_predictor.eval()  # type: ignore[attr-defined]
         self.reco_predictor = reco_predictor.eval()  # type: ignore[attr-defined]
         _OCRPredictor.__init__(
-            self,
+            self,
+            assume_straight_pages,
+            straighten_pages,
+            preserve_aspect_ratio,
+            symmetric_pad,
+            detect_orientation,
+            **kwargs,
         )
         self.detect_orientation = detect_orientation
         self.detect_language = detect_language
@@ -81,19 +87,19 @@ class OCRPredictor(nn.Module, _OCRPredictor):
                 for out_map in out_maps
             ]
         if self.detect_orientation:
-
+            general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)  # type: ignore[arg-type]
             orientations = [
-                {"value": orientation_page, "confidence": None} for orientation_page in
+                {"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
             ]
         else:
             orientations = None
+            general_pages_orientations = None
+            origin_pages_orientations = None
         if self.straighten_pages:
-
-
-
-
-            )
-            pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
+            pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)  # type: ignore
+            # update page shapes after straightening
+            origin_page_shapes = [page.shape[:2] for page in pages]
+
             # Forward again to get predictions on straight pages
             loc_preds = self.det_predictor(pages, **kwargs)
 
@@ -102,30 +108,37 @@ class OCRPredictor(nn.Module, _OCRPredictor):
         ), "Detection Model in ocr_predictor should output only one class"
 
         loc_preds = [list(loc_pred.values())[0] for loc_pred in loc_preds]
+        # Detach objectness scores from loc_preds
+        loc_preds, objectness_scores = detach_scores(loc_preds)
         # Check whether crop mode should be switched to channels first
         channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray)
 
-        # Rectify crops if aspect ratio
-        loc_preds = self._remove_padding(pages, loc_preds)
-
         # Apply hooks to loc_preds if any
         for hook in self.hooks:
             loc_preds = hook(loc_preds)
 
         # Crop images
         crops, loc_preds = self._prepare_crops(
-            pages,
+            pages,  # type: ignore[arg-type]
             loc_preds,
             channels_last=channels_last,
             assume_straight_pages=self.assume_straight_pages,
+            assume_horizontal=self._page_orientation_disabled,
         )
-        # Rectify crop orientation
+        # Rectify crop orientation and get crop orientation predictions
+        crop_orientations: Any = []
         if not self.assume_straight_pages:
-            crops, loc_preds = self._rectify_crops(crops, loc_preds)
+            crops, loc_preds, _crop_orientations = self._rectify_crops(crops, loc_preds)
+            crop_orientations = [
+                {"value": orientation[0], "confidence": orientation[1]} for orientation in _crop_orientations
+            ]
+
         # Identify character sequences
         word_preds = self.reco_predictor([crop for page_crops in crops for crop in page_crops], **kwargs)
+        if not crop_orientations:
+            crop_orientations = [{"value": 0, "confidence": None} for _ in word_preds]
 
-        boxes, text_preds = self._process_predictions(loc_preds, word_preds)
+        boxes, text_preds, crop_orientations = self._process_predictions(loc_preds, word_preds, crop_orientations)
 
         if self.detect_language:
             languages = [get_language(" ".join([item[0] for item in text_pred])) for text_pred in text_preds]
@@ -134,10 +147,12 @@ class OCRPredictor(nn.Module, _OCRPredictor):
             languages_dict = None
 
         out = self.doc_builder(
-            pages,
+            pages,  # type: ignore[arg-type]
             boxes,
+            objectness_scores,
            text_preds,
-            origin_page_shapes,
+            origin_page_shapes,  # type: ignore[arg-type]
+            crop_orientations,
             orientations,
             languages_dict,
         )
--- a/doctr/models/predictor/tensorflow.py
+++ b/doctr/models/predictor/tensorflow.py
@@ -9,10 +9,10 @@ import numpy as np
 import tensorflow as tf
 
 from doctr.io.elements import Document
-from doctr.models._utils import
+from doctr.models._utils import get_language
 from doctr.models.detection.predictor import DetectionPredictor
 from doctr.models.recognition.predictor import RecognitionPredictor
-from doctr.utils.geometry import
+from doctr.utils.geometry import detach_scores
 from doctr.utils.repr import NestedObject
 
 from .base import _OCRPredictor
@@ -56,7 +56,13 @@ class OCRPredictor(NestedObject, _OCRPredictor):
         self.det_predictor = det_predictor
         self.reco_predictor = reco_predictor
         _OCRPredictor.__init__(
-            self,
+            self,
+            assume_straight_pages,
+            straighten_pages,
+            preserve_aspect_ratio,
+            symmetric_pad,
+            detect_orientation,
+            **kwargs,
         )
         self.detect_orientation = detect_orientation
         self.detect_language = detect_language
@@ -81,19 +87,19 @@ class OCRPredictor(NestedObject, _OCRPredictor):
                 for out_map in out_maps
             ]
         if self.detect_orientation:
-
+            general_pages_orientations, origin_pages_orientations = self._get_orientations(pages, seg_maps)
             orientations = [
-                {"value": orientation_page, "confidence": None} for orientation_page in
+                {"value": orientation_page, "confidence": None} for orientation_page in origin_pages_orientations
             ]
         else:
             orientations = None
+            general_pages_orientations = None
+            origin_pages_orientations = None
         if self.straighten_pages:
-
-
-
-
-            )
-            pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
+            pages = self._straighten_pages(pages, seg_maps, general_pages_orientations, origin_pages_orientations)
+            # update page shapes after straightening
+            origin_page_shapes = [page.shape[:2] for page in pages]
+
             # forward again to get predictions on straight pages
             loc_preds_dict = self.det_predictor(pages, **kwargs)  # type: ignore[assignment]
 
@@ -101,9 +107,8 @@ class OCRPredictor(NestedObject, _OCRPredictor):
             len(loc_pred) == 1 for loc_pred in loc_preds_dict
         ), "Detection Model in ocr_predictor should output only one class"
         loc_preds: List[np.ndarray] = [list(loc_pred.values())[0] for loc_pred in loc_preds_dict]  # type: ignore[union-attr]
-
-
-        loc_preds = self._remove_padding(pages, loc_preds)
+        # Detach objectness scores from loc_preds
+        loc_preds, objectness_scores = detach_scores(loc_preds)
 
         # Apply hooks to loc_preds if any
         for hook in self.hooks:
@@ -111,16 +116,26 @@ class OCRPredictor(NestedObject, _OCRPredictor):
 
         # Crop images
         crops, loc_preds = self._prepare_crops(
-            pages,
+            pages,
+            loc_preds,
+            channels_last=True,
+            assume_straight_pages=self.assume_straight_pages,
+            assume_horizontal=self._page_orientation_disabled,
         )
-        # Rectify crop orientation
+        # Rectify crop orientation and get crop orientation predictions
+        crop_orientations: Any = []
         if not self.assume_straight_pages:
-            crops, loc_preds = self._rectify_crops(crops, loc_preds)
+            crops, loc_preds, _crop_orientations = self._rectify_crops(crops, loc_preds)
+            crop_orientations = [
+                {"value": orientation[0], "confidence": orientation[1]} for orientation in _crop_orientations
+            ]
 
         # Identify character sequences
         word_preds = self.reco_predictor([crop for page_crops in crops for crop in page_crops], **kwargs)
+        if not crop_orientations:
+            crop_orientations = [{"value": 0, "confidence": None} for _ in word_preds]
 
-        boxes, text_preds = self._process_predictions(loc_preds, word_preds)
+        boxes, text_preds, crop_orientations = self._process_predictions(loc_preds, word_preds, crop_orientations)
 
         if self.detect_language:
             languages = [get_language(" ".join([item[0] for item in text_pred])) for text_pred in text_preds]
@@ -131,8 +146,10 @@ class OCRPredictor(NestedObject, _OCRPredictor):
         out = self.doc_builder(
             pages,
             boxes,
+            objectness_scores,
             text_preds,
             origin_page_shapes,  # type: ignore[arg-type]
+            crop_orientations,
             orientations,
             languages_dict,
         )
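Illustrative sketch (not part of the diff): the two extra arguments now passed to doc_builder above, objectness_scores and crop_orientations, surface as per-word fields on the rendered document. The attribute names objectness_score and crop_orientation below are assumptions to verify against the 0.10.0 doctr.io.elements API rather than something shown in this diff.

from doctr.io import DocumentFile
from doctr.models import ocr_predictor

predictor = ocr_predictor(pretrained=True)
pages = DocumentFile.from_images(["page.jpg"])
result = predictor(pages)

for page in result.pages:
    for block in page.blocks:
        for line in block.lines:
            for word in line.words:
                # per-word outputs added in 0.10.0 (assumed attribute names)
                print(word.value, word.confidence, word.objectness_score, word.crop_orientation)
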
--- a/doctr/models/preprocessor/pytorch.py
+++ b/doctr/models/preprocessor/pytorch.py
@@ -79,7 +79,7 @@ class PreProcessor(nn.Module):
         else:
             x = x.to(dtype=torch.float32)  # type: ignore[union-attr]
 
-        return x
+        return x
 
     def __call__(self, x: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]]) -> List[torch.Tensor]:
         """Prepare document data for model forwarding
@@ -103,7 +103,7 @@ class PreProcessor(nn.Module):
             elif x.dtype not in (torch.uint8, torch.float16, torch.float32):
                 raise TypeError("unsupported data type for torch.Tensor")
             # Resizing
-            if x.shape[-2] != self.resize.size[0] or x.shape[-1] != self.resize.size[1]:
+            if x.shape[-2] != self.resize.size[0] or x.shape[-1] != self.resize.size[1]:
                 x = F.resize(
                     x, self.resize.size, interpolation=self.resize.interpolation, antialias=self.resize.antialias
                 )
@@ -118,11 +118,11 @@ class PreProcessor(nn.Module):
             # Sample transform (to tensor, resize)
             samples = list(multithread_exec(self.sample_transforms, x))
             # Batching
-            batches = self.batch_inputs(samples)
+            batches = self.batch_inputs(samples)
         else:
             raise TypeError(f"invalid input type: {type(x)}")
 
         # Batch transforms (normalize)
         batches = list(multithread_exec(self.normalize, batches))
 
-        return batches
+        return batches
--- a/doctr/models/preprocessor/tensorflow.py
+++ b/doctr/models/preprocessor/tensorflow.py
@@ -41,6 +41,7 @@ class PreProcessor(NestedObject):
         self.resize = Resize(output_size, **kwargs)
         # Perform the division by 255 at the same time
         self.normalize = Normalize(mean, std)
+        self._runs_on_cuda = tf.config.list_physical_devices("GPU") != []
 
     def batch_inputs(self, samples: List[tf.Tensor]) -> List[tf.Tensor]:
         """Gather samples into batches for inference purposes
@@ -113,13 +114,13 @@ class PreProcessor(NestedObject):
 
         elif isinstance(x, list) and all(isinstance(sample, (np.ndarray, tf.Tensor)) for sample in x):
             # Sample transform (to tensor, resize)
-            samples = list(multithread_exec(self.sample_transforms, x))
+            samples = list(multithread_exec(self.sample_transforms, x, threads=1 if self._runs_on_cuda else None))
             # Batching
             batches = self.batch_inputs(samples)
         else:
             raise TypeError(f"invalid input type: {type(x)}")
 
         # Batch transforms (normalize)
-        batches = list(multithread_exec(self.normalize, batches))
+        batches = list(multithread_exec(self.normalize, batches, threads=1 if self._runs_on_cuda else None))
 
         return batches
--- a/doctr/models/recognition/crnn/tensorflow.py
+++ b/doctr/models/recognition/crnn/tensorflow.py
@@ -13,7 +13,7 @@ from tensorflow.keras.models import Model, Sequential
 from doctr.datasets import VOCABS
 
 from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r
-from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
+from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
 from ..core import RecognitionModel, RecognitionPostProcessor
 
 __all__ = ["CRNN", "crnn_vgg16_bn", "crnn_mobilenet_v3_small", "crnn_mobilenet_v3_large"]
@@ -24,21 +24,21 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 128, 3),
         "vocab": VOCABS["legacy_french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_vgg16_bn-9c188f45.weights.h5&src=0",
     },
     "crnn_mobilenet_v3_small": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 128, 3),
         "vocab": VOCABS["french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_mobilenet_v3_small-54850265.weights.h5&src=0",
     },
     "crnn_mobilenet_v3_large": {
         "mean": (0.694, 0.695, 0.693),
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 128, 3),
         "vocab": VOCABS["french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/crnn_mobilenet_v3_large-c64045e5.weights.h5&src=0",
     },
 }
 
@@ -128,7 +128,7 @@ class CRNN(RecognitionModel, Model):
 
     def __init__(
         self,
-        feature_extractor:
+        feature_extractor: Model,
         vocab: str,
         rnn_units: int = 128,
         exportable: bool = False,
@@ -245,9 +245,11 @@ def _crnn(
 
     # Build the model
     model = CRNN(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
     # Load pretrained parameters
     if pretrained:
-
+        # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
+        load_pretrained_params(model, _cfg["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"])
 
     return model
 
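A minimal sketch of what the skip_mismatch change in _crnn above (and the matching changes in master, parseq and sar below) enables, assuming the TensorFlow backend and that the recognition factory functions still accept a vocab keyword as in 0.8.x: loading pretrained weights with a custom vocab now skips the layers whose shapes no longer match instead of failing, leaving those layers freshly initialized for fine-tuning.

from doctr.models import crnn_vgg16_bn

# Pretrained weights are loaded for every layer except the vocab-dependent head,
# which is skipped (skip_mismatch=True) because the given vocab differs from the default one.
model = crnn_vgg16_bn(pretrained=True, vocab="0123456789")
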
--- a/doctr/models/recognition/master/pytorch.py
+++ b/doctr/models/recognition/master/pytorch.py
@@ -107,7 +107,7 @@ class MASTER(_MASTER, nn.Module):
         # NOTE: nn.TransformerDecoder takes the inverse from this implementation
         # [True, True, True, ..., False, False, False] -> False is masked
         # (N, 1, 1, max_length)
-        target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1)
+        target_pad_mask = (target != self.vocab_size + 2).unsqueeze(1).unsqueeze(1)
         target_length = target.size(1)
         # sub mask filled diagonal with True = see and False = masked (max_length, max_length)
         # NOTE: onnxruntime tril/triu works only with float currently (onnxruntime 1.11.1 - opset 14)
@@ -142,7 +142,7 @@ class MASTER(_MASTER, nn.Module):
         # Input length : number of timesteps
         input_len = model_output.shape[1]
         # Add one for additional <eos> token (sos disappear in shift!)
-        seq_len = seq_len + 1
+        seq_len = seq_len + 1
         # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
         # The "masked" first gt char is <sos>. Delete last logit of the model output.
         cce = F.cross_entropy(model_output[:, :-1, :].permute(0, 2, 1), gt[:, 1:], reduction="none")
--- a/doctr/models/recognition/master/tensorflow.py
+++ b/doctr/models/recognition/master/tensorflow.py
@@ -13,7 +13,7 @@ from doctr.datasets import VOCABS
 from doctr.models.classification import magc_resnet31
 from doctr.models.modules.transformer import Decoder, PositionalEncoding
 
-from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
+from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
 from .base import _MASTER, _MASTERPostProcessor
 
 __all__ = ["MASTER", "master"]
@@ -25,7 +25,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 128, 3),
         "vocab": VOCABS["french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/master-d7fdaeff.weights.h5&src=0",
     },
 }
 
@@ -51,7 +51,7 @@ class MASTER(_MASTER, Model):
 
     def __init__(
         self,
-        feature_extractor:
+        feature_extractor: Model,
         vocab: str,
         d_model: int = 512,
         dff: int = 2048,
@@ -290,9 +290,14 @@ def _master(arch: str, pretrained: bool, backbone_fn, pretrained_backbone: bool
         cfg=_cfg,
         **kwargs,
     )
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-
+        # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
+        )
 
     return model
 
--- a/doctr/models/recognition/parseq/pytorch.py
+++ b/doctr/models/recognition/parseq/pytorch.py
@@ -212,7 +212,7 @@ class PARSeq(_PARSeq, nn.Module):
 
         sos_idx = torch.zeros(len(final_perms), 1, device=seqlen.device)
         eos_idx = torch.full((len(final_perms), 1), max_num_chars + 1, device=seqlen.device)
-        combined = torch.cat([sos_idx, final_perms + 1, eos_idx], dim=1).int()
+        combined = torch.cat([sos_idx, final_perms + 1, eos_idx], dim=1).int()
         if len(combined) > 1:
             combined[1, 1:] = max_num_chars + 1 - torch.arange(max_num_chars + 1, device=seqlen.device)
         return combined
@@ -282,7 +282,8 @@ class PARSeq(_PARSeq, nn.Module):
             ys[:, i + 1] = pos_prob.squeeze().argmax(-1)
 
             # Stop decoding if all sequences have reached the EOS token
-
+            # NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export
+            if not self.exportable and max_len is None and (ys == self.vocab_size).any(dim=-1).all():
                 break
 
         logits = torch.cat(pos_logits, dim=1)  # (N, max_length, vocab_size + 1)
@@ -297,7 +298,7 @@ class PARSeq(_PARSeq, nn.Module):
 
         # Create padding mask for refined target input maskes all behind EOS token as False
         # (N, 1, 1, max_length)
-        target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)
+        target_pad_mask = ~((ys == self.vocab_size).int().cumsum(-1) > 0).unsqueeze(1).unsqueeze(1)
         mask = (target_pad_mask.bool() & query_mask[:, : ys.shape[1]].bool()).int()
         logits = self.head(self.decode(ys, features, mask, target_query=pos_queries))
 
--- a/doctr/models/recognition/parseq/tensorflow.py
+++ b/doctr/models/recognition/parseq/tensorflow.py
@@ -16,7 +16,7 @@ from doctr.datasets import VOCABS
 from doctr.models.modules.transformer import MultiHeadAttention, PositionwiseFeedForward
 
 from ...classification import vit_s
-from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
+from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
 from .base import _PARSeq, _PARSeqPostProcessor
 
 __all__ = ["PARSeq", "parseq"]
@@ -27,7 +27,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 128, 3),
         "vocab": VOCABS["french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/parseq-4152a87e.weights.h5&src=0",
     },
 }
 
@@ -43,7 +43,7 @@ class CharEmbedding(layers.Layer):
 
     def __init__(self, vocab_size: int, d_model: int):
         super(CharEmbedding, self).__init__()
-        self.embedding =
+        self.embedding = layers.Embedding(vocab_size, d_model)
         self.d_model = d_model
 
     def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
@@ -167,7 +167,6 @@ class PARSeq(_PARSeq, Model):
 
         self.postprocessor = PARSeqPostProcessor(vocab=self.vocab)
 
-    @tf.function
     def generate_permutations(self, seqlen: tf.Tensor) -> tf.Tensor:
         # Generates permutations of the target sequence.
         # Translated from https://github.com/baudm/parseq/blob/main/strhub/models/parseq/system.py
@@ -214,7 +213,6 @@ class PARSeq(_PARSeq, Model):
         )
         return combined
 
-    @tf.function
     def generate_permutations_attention_masks(self, permutation: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
         # Generate source and target mask for the decoder attention.
         sz = permutation.shape[0]
@@ -234,11 +232,10 @@ class PARSeq(_PARSeq, Model):
         target_mask = mask[1:, :-1]
         return tf.cast(source_mask, dtype=tf.bool), tf.cast(target_mask, dtype=tf.bool)
 
-    @tf.function
     def decode(
         self,
         target: tf.Tensor,
-        memory: tf,
+        memory: tf.Tensor,
         target_mask: Optional[tf.Tensor] = None,
         target_query: Optional[tf.Tensor] = None,
         **kwargs: Any,
@@ -288,10 +285,11 @@ class PARSeq(_PARSeq, Model):
             )
 
             # Stop decoding if all sequences have reached the EOS token
-            #
+            # NOTE: `break` isn't correctly translated to Onnx so we don't break here if we want to export
             if (
-
-                and
+                not self.exportable
+                and max_len is None
+                and tf.reduce_any(tf.reduce_all(tf.equal(ys, tf.constant(self.vocab_size)), axis=-1))
             ):
                 break
 
@@ -475,9 +473,14 @@ def _parseq(
 
     # Build the model
     model = PARSeq(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
+
     # Load pretrained parameters
     if pretrained:
-
+        # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
+        )
 
     return model
 
--- a/doctr/models/recognition/sar/pytorch.py
+++ b/doctr/models/recognition/sar/pytorch.py
@@ -125,25 +125,26 @@ class SARDecoder(nn.Module):
             if t == 0:
                 # step to init the first states of the LSTMCell
                 hidden_state_init = cell_state_init = torch.zeros(
-                    features.size(0), features.size(1), device=features.device
+                    features.size(0), features.size(1), device=features.device, dtype=features.dtype
                 )
                 hidden_state, cell_state = hidden_state_init, cell_state_init
                 prev_symbol = holistic
             elif t == 1:
                 # step to init a 'blank' sequence of length vocab_size + 1 filled with zeros
                 # (N, vocab_size + 1) --> (N, embedding_units)
-                prev_symbol = torch.zeros(
+                prev_symbol = torch.zeros(
+                    features.size(0), self.vocab_size + 1, device=features.device, dtype=features.dtype
+                )
                 prev_symbol = self.embed(prev_symbol)
             else:
-                if gt is not None:
+                if gt is not None and self.training:
                     # (N, embedding_units) -2 because of <bos> and <eos> (same)
                     prev_symbol = self.embed(gt_embedding[:, t - 2])
                 else:
                     # -1 to start at timestep where prev_symbol was initialized
                     index = logits_list[t - 1].argmax(-1)
                     # update prev_symbol with ones at the index of the previous logit vector
-
-                    prev_symbol = prev_symbol.scatter_(1, index.unsqueeze(1), 1)
+                    prev_symbol = self.embed(self.embed_tgt(index))
 
             # (N, C), (N, C) take the last hidden state and cell state from current timestep
             hidden_state_init, cell_state_init = self.lstm_cell(prev_symbol, (hidden_state_init, cell_state_init))
@@ -292,7 +293,7 @@ class SAR(nn.Module, RecognitionModel):
         # Input length : number of timesteps
         input_len = model_output.shape[1]
         # Add one for additional <eos> token
-        seq_len = seq_len + 1
+        seq_len = seq_len + 1
         # Compute loss
         # (N, L, vocab_size + 1)
         cce = F.cross_entropy(model_output.permute(0, 2, 1), gt, reduction="none")
--- a/doctr/models/recognition/sar/tensorflow.py
+++ b/doctr/models/recognition/sar/tensorflow.py
@@ -13,7 +13,7 @@ from doctr.datasets import VOCABS
 from doctr.utils.repr import NestedObject
 
 from ...classification import resnet31
-from ...utils.tensorflow import _bf16_to_float32, load_pretrained_params
+from ...utils.tensorflow import _bf16_to_float32, _build_model, load_pretrained_params
 from ..core import RecognitionModel, RecognitionPostProcessor
 
 __all__ = ["SAR", "sar_resnet31"]
@@ -24,7 +24,7 @@ default_cfgs: Dict[str, Dict[str, Any]] = {
         "std": (0.299, 0.296, 0.301),
         "input_shape": (32, 128, 3),
         "vocab": VOCABS["french"],
-        "url": "https://doctr-static.mindee.com/models?id=v0.
+        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/sar_resnet31-5a58806c.weights.h5&src=0",
     },
 }
 
@@ -177,23 +177,17 @@ class SARDecoder(layers.Layer, NestedObject):
             elif t == 1:
                 # step to init a 'blank' sequence of length vocab_size + 1 filled with zeros
                 # (N, vocab_size + 1) --> (N, embedding_units)
-                prev_symbol = tf.zeros([features.shape[0], self.vocab_size + 1])
+                prev_symbol = tf.zeros([features.shape[0], self.vocab_size + 1], dtype=features.dtype)
                 prev_symbol = self.embed(prev_symbol, **kwargs)
             else:
-                if gt is not None:
+                if gt is not None and kwargs.get("training", False):
                     # (N, embedding_units) -2 because of <bos> and <eos> (same)
                     prev_symbol = self.embed(gt_embedding[:, t - 2], **kwargs)
                 else:
                     # -1 to start at timestep where prev_symbol was initialized
                     index = tf.argmax(logits_list[t - 1], axis=-1)
                     # update prev_symbol with ones at the index of the previous logit vector
-
-                    index = tf.ones_like(index)
-                    prev_symbol = tf.scatter_nd(
-                        tf.expand_dims(index, axis=1),
-                        prev_symbol,
-                        tf.constant([features.shape[0], features.shape[-1]], dtype=tf.int64),
-                    )
+                    prev_symbol = self.embed(self.embed_tgt(index, **kwargs), **kwargs)
 
             # (N, C), (N, C) take the last hidden state and cell state from current timestep
             _, states = self.lstm_cells(prev_symbol, states, **kwargs)
@@ -398,9 +392,13 @@ def _sar(
 
     # Build the model
     model = SAR(feat_extractor, cfg=_cfg, **kwargs)
+    _build_model(model)
     # Load pretrained parameters
     if pretrained:
-
+        # The given vocab differs from the pretrained model => skip the mismatching layers for fine tuning
+        load_pretrained_params(
+            model, default_cfgs[arch]["url"], skip_mismatch=kwargs["vocab"] != default_cfgs[arch]["vocab"]
+        )
 
     return model
 
--- a/doctr/models/recognition/vitstr/pytorch.py
+++ b/doctr/models/recognition/vitstr/pytorch.py
@@ -137,7 +137,7 @@ class ViTSTR(_ViTSTR, nn.Module):
         # Input length : number of steps
         input_len = model_output.shape[1]
         # Add one for additional <eos> token (sos disappear in shift!)
-        seq_len = seq_len + 1
+        seq_len = seq_len + 1
         # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]!
         # The "masked" first gt char is <sos>.
         cce = F.cross_entropy(model_output.permute(0, 2, 1), gt[:, 1:], reduction="none")