python-doctr 0.7.0__py3-none-any.whl → 0.8.1__py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registries.
- doctr/datasets/__init__.py +2 -0
- doctr/datasets/cord.py +6 -4
- doctr/datasets/datasets/base.py +3 -2
- doctr/datasets/datasets/pytorch.py +4 -2
- doctr/datasets/datasets/tensorflow.py +4 -2
- doctr/datasets/detection.py +6 -3
- doctr/datasets/doc_artefacts.py +2 -1
- doctr/datasets/funsd.py +7 -8
- doctr/datasets/generator/base.py +3 -2
- doctr/datasets/generator/pytorch.py +3 -1
- doctr/datasets/generator/tensorflow.py +3 -1
- doctr/datasets/ic03.py +3 -2
- doctr/datasets/ic13.py +2 -1
- doctr/datasets/iiit5k.py +6 -4
- doctr/datasets/iiithws.py +2 -1
- doctr/datasets/imgur5k.py +3 -2
- doctr/datasets/loader.py +4 -2
- doctr/datasets/mjsynth.py +2 -1
- doctr/datasets/ocr.py +2 -1
- doctr/datasets/orientation.py +40 -0
- doctr/datasets/recognition.py +3 -2
- doctr/datasets/sroie.py +2 -1
- doctr/datasets/svhn.py +2 -1
- doctr/datasets/svt.py +3 -2
- doctr/datasets/synthtext.py +2 -1
- doctr/datasets/utils.py +27 -11
- doctr/datasets/vocabs.py +26 -1
- doctr/datasets/wildreceipt.py +111 -0
- doctr/file_utils.py +3 -1
- doctr/io/elements.py +52 -35
- doctr/io/html.py +5 -3
- doctr/io/image/base.py +5 -4
- doctr/io/image/pytorch.py +12 -7
- doctr/io/image/tensorflow.py +11 -6
- doctr/io/pdf.py +5 -4
- doctr/io/reader.py +13 -5
- doctr/models/_utils.py +30 -53
- doctr/models/artefacts/barcode.py +4 -3
- doctr/models/artefacts/face.py +4 -2
- doctr/models/builder.py +58 -43
- doctr/models/classification/__init__.py +1 -0
- doctr/models/classification/magc_resnet/pytorch.py +5 -2
- doctr/models/classification/magc_resnet/tensorflow.py +5 -2
- doctr/models/classification/mobilenet/pytorch.py +16 -4
- doctr/models/classification/mobilenet/tensorflow.py +29 -20
- doctr/models/classification/predictor/pytorch.py +3 -2
- doctr/models/classification/predictor/tensorflow.py +2 -1
- doctr/models/classification/resnet/pytorch.py +23 -13
- doctr/models/classification/resnet/tensorflow.py +33 -26
- doctr/models/classification/textnet/__init__.py +6 -0
- doctr/models/classification/textnet/pytorch.py +275 -0
- doctr/models/classification/textnet/tensorflow.py +267 -0
- doctr/models/classification/vgg/pytorch.py +4 -2
- doctr/models/classification/vgg/tensorflow.py +5 -2
- doctr/models/classification/vit/pytorch.py +9 -3
- doctr/models/classification/vit/tensorflow.py +9 -3
- doctr/models/classification/zoo.py +7 -2
- doctr/models/core.py +1 -1
- doctr/models/detection/__init__.py +1 -0
- doctr/models/detection/_utils/pytorch.py +7 -1
- doctr/models/detection/_utils/tensorflow.py +7 -3
- doctr/models/detection/core.py +9 -3
- doctr/models/detection/differentiable_binarization/base.py +37 -25
- doctr/models/detection/differentiable_binarization/pytorch.py +80 -104
- doctr/models/detection/differentiable_binarization/tensorflow.py +74 -55
- doctr/models/detection/fast/__init__.py +6 -0
- doctr/models/detection/fast/base.py +256 -0
- doctr/models/detection/fast/pytorch.py +442 -0
- doctr/models/detection/fast/tensorflow.py +428 -0
- doctr/models/detection/linknet/base.py +12 -5
- doctr/models/detection/linknet/pytorch.py +28 -15
- doctr/models/detection/linknet/tensorflow.py +68 -88
- doctr/models/detection/predictor/pytorch.py +16 -6
- doctr/models/detection/predictor/tensorflow.py +13 -5
- doctr/models/detection/zoo.py +19 -16
- doctr/models/factory/hub.py +20 -10
- doctr/models/kie_predictor/base.py +2 -1
- doctr/models/kie_predictor/pytorch.py +28 -36
- doctr/models/kie_predictor/tensorflow.py +27 -27
- doctr/models/modules/__init__.py +1 -0
- doctr/models/modules/layers/__init__.py +6 -0
- doctr/models/modules/layers/pytorch.py +166 -0
- doctr/models/modules/layers/tensorflow.py +175 -0
- doctr/models/modules/transformer/pytorch.py +24 -22
- doctr/models/modules/transformer/tensorflow.py +6 -4
- doctr/models/modules/vision_transformer/pytorch.py +2 -4
- doctr/models/modules/vision_transformer/tensorflow.py +2 -4
- doctr/models/obj_detection/faster_rcnn/pytorch.py +4 -2
- doctr/models/predictor/base.py +14 -3
- doctr/models/predictor/pytorch.py +26 -29
- doctr/models/predictor/tensorflow.py +25 -22
- doctr/models/preprocessor/pytorch.py +14 -9
- doctr/models/preprocessor/tensorflow.py +10 -5
- doctr/models/recognition/core.py +4 -1
- doctr/models/recognition/crnn/pytorch.py +23 -16
- doctr/models/recognition/crnn/tensorflow.py +25 -17
- doctr/models/recognition/master/base.py +4 -1
- doctr/models/recognition/master/pytorch.py +20 -9
- doctr/models/recognition/master/tensorflow.py +20 -8
- doctr/models/recognition/parseq/base.py +4 -1
- doctr/models/recognition/parseq/pytorch.py +28 -22
- doctr/models/recognition/parseq/tensorflow.py +22 -11
- doctr/models/recognition/predictor/_utils.py +3 -2
- doctr/models/recognition/predictor/pytorch.py +3 -2
- doctr/models/recognition/predictor/tensorflow.py +2 -1
- doctr/models/recognition/sar/pytorch.py +14 -7
- doctr/models/recognition/sar/tensorflow.py +23 -14
- doctr/models/recognition/utils.py +5 -1
- doctr/models/recognition/vitstr/base.py +4 -1
- doctr/models/recognition/vitstr/pytorch.py +22 -13
- doctr/models/recognition/vitstr/tensorflow.py +21 -10
- doctr/models/recognition/zoo.py +4 -2
- doctr/models/utils/pytorch.py +24 -6
- doctr/models/utils/tensorflow.py +22 -3
- doctr/models/zoo.py +21 -3
- doctr/transforms/functional/base.py +8 -3
- doctr/transforms/functional/pytorch.py +23 -6
- doctr/transforms/functional/tensorflow.py +25 -5
- doctr/transforms/modules/base.py +12 -5
- doctr/transforms/modules/pytorch.py +10 -12
- doctr/transforms/modules/tensorflow.py +17 -9
- doctr/utils/common_types.py +1 -1
- doctr/utils/data.py +4 -2
- doctr/utils/fonts.py +3 -2
- doctr/utils/geometry.py +95 -26
- doctr/utils/metrics.py +36 -22
- doctr/utils/multithreading.py +5 -3
- doctr/utils/repr.py +3 -1
- doctr/utils/visualization.py +31 -8
- doctr/version.py +1 -1
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/METADATA +67 -31
- python_doctr-0.8.1.dist-info/RECORD +173 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/WHEEL +1 -1
- python_doctr-0.7.0.dist-info/RECORD +0 -161
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/LICENSE +0 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/top_level.txt +0 -0
- {python_doctr-0.7.0.dist-info → python_doctr-0.8.1.dist-info}/zip-safe +0 -0
doctr/models/kie_predictor/pytorch.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -13,7 +13,7 @@ from doctr.io.elements import Document
 from doctr.models._utils import estimate_orientation, get_language, invert_data_structure
 from doctr.models.detection.predictor import DetectionPredictor
 from doctr.models.recognition.predictor import RecognitionPredictor
-from doctr.utils.geometry import rotate_boxes, rotate_image
+from doctr.utils.geometry import rotate_image
 
 from .base import _KIEPredictor
 
@@ -24,6 +24,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
     """Implements an object able to localize and identify text elements in a set of documents
 
     Args:
+    ----
         det_predictor: detection module
         reco_predictor: recognition module
         assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
@@ -35,7 +36,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
             page. Doing so will slightly deteriorate the overall latency.
         detect_language: if True, the language prediction will be added to the predictions for each
             page. Doing so will slightly deteriorate the overall latency.
-        kwargs: keyword args of `DocumentBuilder`
+        **kwargs: keyword args of `DocumentBuilder`
     """
 
     def __init__(
@@ -59,7 +60,7 @@ class KIEPredictor(nn.Module, _KIEPredictor):
         self.detect_orientation = detect_orientation
         self.detect_language = detect_language
 
-    @torch.no_grad()
+    @torch.inference_mode()
     def forward(
         self,
         pages: List[Union[np.ndarray, torch.Tensor]],
@@ -71,11 +72,20 @@ class KIEPredictor(nn.Module, _KIEPredictor):
 
         origin_page_shapes = [page.shape[:2] if isinstance(page, np.ndarray) else page.shape[-2:] for page in pages]
 
+        # Localize text elements
+        loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)
+
         # Detect document rotation and rotate pages
+        seg_maps = [
+            np.where(np.expand_dims(np.amax(out_map, axis=-1), axis=-1) > kwargs.get("bin_thresh", 0.3), 255, 0).astype(
+                np.uint8
+            )
+            for out_map in out_maps
+        ]
         if self.detect_orientation:
-            origin_page_orientations = [estimate_orientation(
+            origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps]
             orientations = [
-                {"value": orientation_page, "confidence":
+                {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
             ]
         else:
             orientations = None
@@ -83,29 +93,28 @@ class KIEPredictor(nn.Module, _KIEPredictor):
             origin_page_orientations = (
                 origin_page_orientations
                 if self.detect_orientation
-                else [estimate_orientation(
+                else [estimate_orientation(seq_map) for seq_map in seg_maps]
             )
-            pages = [
-
-
-            ]
+            pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
+            # Forward again to get predictions on straight pages
+            loc_preds = self.det_predictor(pages, **kwargs)
 
-        # Localize text elements
-        loc_preds = self.det_predictor(pages, **kwargs)
         dict_loc_preds: Dict[str, List[np.ndarray]] = invert_data_structure(loc_preds)  # type: ignore[assignment]
         # Check whether crop mode should be switched to channels first
         channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray)
 
         # Rectify crops if aspect ratio
-        dict_loc_preds = {
-
-
+        dict_loc_preds = {k: self._remove_padding(pages, loc_pred) for k, loc_pred in dict_loc_preds.items()}
+
+        # Apply hooks to loc_preds if any
+        for hook in self.hooks:
+            dict_loc_preds = hook(dict_loc_preds)
 
         # Crop images
         crops = {}
         for class_name in dict_loc_preds.keys():
             crops[class_name], dict_loc_preds[class_name] = self._prepare_crops(
-                pages,
+                pages,
                 dict_loc_preds[class_name],
                 channels_last=channels_last,
                 assume_straight_pages=self.assume_straight_pages,
@@ -136,29 +145,12 @@ class KIEPredictor(nn.Module, _KIEPredictor):
             languages_dict = [{"value": lang[0], "confidence": lang[1]} for lang in languages]
         else:
             languages_dict = None
-        # Rotate back pages and boxes while keeping original image size
-        if self.straighten_pages:
-            boxes_per_page = [
-                {
-                    k: rotate_boxes(
-                        page_boxes,
-                        angle,
-                        orig_shape=page.shape[:2]
-                        if isinstance(page, np.ndarray)
-                        else page.shape[1:],  # type: ignore[arg-type]
-                        target_shape=mask,  # type: ignore[arg-type]
-                    )
-                    for k, page_boxes in page_boxes_dict.items()
-                }
-                for page_boxes_dict, page, angle, mask in zip(
-                    boxes_per_page, pages, origin_page_orientations, origin_page_shapes
-                )
-            ]
 
         out = self.doc_builder(
+            pages,
             boxes_per_page,
             text_preds_per_page,
-
+            origin_page_shapes,
             orientations,
             languages_dict,
         )
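The behavioural change in this file (and its TensorFlow counterpart below) is that the detector now runs once with return_maps=True, and page orientation is estimated from binarized segmentation maps instead of the raw pages. A minimal sketch of that binarization step, with a random array standing in for a real detector output map (the 0.3 value mirrors the bin_thresh fallback in the hunk):

import numpy as np

# Hypothetical detector output map: (H, W, num_classes) with scores in [0, 1]
out_map = np.random.rand(1024, 768, 2).astype(np.float32)
bin_thresh = 0.3  # predictor default, overridable via kwargs

# Collapse the class axis, threshold, and cast to a uint8 mask (0 or 255),
# exactly as the new seg_maps list comprehension does
seg_map = np.where(
    np.expand_dims(np.amax(out_map, axis=-1), axis=-1) > bin_thresh, 255, 0
).astype(np.uint8)

print(seg_map.shape, seg_map.dtype)  # (1024, 768, 1) uint8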
doctr/models/kie_predictor/tensorflow.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -12,7 +12,7 @@ from doctr.io.elements import Document
 from doctr.models._utils import estimate_orientation, get_language, invert_data_structure
 from doctr.models.detection.predictor import DetectionPredictor
 from doctr.models.recognition.predictor import RecognitionPredictor
-from doctr.utils.geometry import rotate_boxes, rotate_image
+from doctr.utils.geometry import rotate_image
 from doctr.utils.repr import NestedObject
 
 from .base import _KIEPredictor
@@ -24,6 +24,7 @@ class KIEPredictor(NestedObject, _KIEPredictor):
     """Implements an object able to localize and identify text elements in a set of documents
 
     Args:
+    ----
         det_predictor: detection module
         reco_predictor: recognition module
         assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages
@@ -35,7 +36,7 @@ class KIEPredictor(NestedObject, _KIEPredictor):
             page. Doing so will slightly deteriorate the overall latency.
         detect_language: if True, the language prediction will be added to the predictions for each
             page. Doing so will slightly deteriorate the overall latency.
-        kwargs: keyword args of `DocumentBuilder`
+        **kwargs: keyword args of `DocumentBuilder`
     """
 
     _children_names = ["det_predictor", "reco_predictor", "doc_builder"]
@@ -71,27 +72,41 @@ class KIEPredictor(NestedObject, _KIEPredictor):
 
         origin_page_shapes = [page.shape[:2] for page in pages]
 
+        # Localize text elements
+        loc_preds, out_maps = self.det_predictor(pages, return_maps=True, **kwargs)
+
         # Detect document rotation and rotate pages
+        seg_maps = [
+            np.where(np.expand_dims(np.amax(out_map, axis=-1), axis=-1) > kwargs.get("bin_thresh", 0.3), 255, 0).astype(
+                np.uint8
+            )
+            for out_map in out_maps
+        ]
         if self.detect_orientation:
-            origin_page_orientations = [estimate_orientation(
+            origin_page_orientations = [estimate_orientation(seq_map) for seq_map in seg_maps]
             orientations = [
-                {"value": orientation_page, "confidence":
+                {"value": orientation_page, "confidence": None} for orientation_page in origin_page_orientations
             ]
         else:
             orientations = None
         if self.straighten_pages:
             origin_page_orientations = (
-                origin_page_orientations
+                origin_page_orientations
+                if self.detect_orientation
+                else [estimate_orientation(seq_map) for seq_map in seg_maps]
             )
-            pages = [rotate_image(page, -angle, expand=
-
-
-            loc_preds = self.det_predictor(pages, **kwargs)
+            pages = [rotate_image(page, -angle, expand=False) for page, angle in zip(pages, origin_page_orientations)]
+            # Forward again to get predictions on straight pages
+            loc_preds = self.det_predictor(pages, **kwargs)  # type: ignore[assignment]
 
-        dict_loc_preds: Dict[str, List[np.ndarray]] = invert_data_structure(loc_preds)  # type: ignore
+        dict_loc_preds: Dict[str, List[np.ndarray]] = invert_data_structure(loc_preds)  # type: ignore
         # Rectify crops if aspect ratio
         dict_loc_preds = {k: self._remove_padding(pages, loc_pred) for k, loc_pred in dict_loc_preds.items()}
 
+        # Apply hooks to loc_preds if any
+        for hook in self.hooks:
+            dict_loc_preds = hook(dict_loc_preds)
+
         # Crop images
         crops = {}
         for class_name in dict_loc_preds.keys():
@@ -126,24 +141,9 @@ class KIEPredictor(NestedObject, _KIEPredictor):
             languages_dict = [{"value": lang[0], "confidence": lang[1]} for lang in languages]
         else:
             languages_dict = None
-        # Rotate back pages and boxes while keeping original image size
-        if self.straighten_pages:
-            boxes_per_page = [
-                {
-                    k: rotate_boxes(
-                        page_boxes,
-                        angle,
-                        orig_shape=page.shape[:2] if isinstance(page, np.ndarray) else page.shape[-2:],
-                        target_shape=mask,  # type: ignore[arg-type]
-                    )
-                    for k, page_boxes in page_boxes_dict.items()
-                }
-                for page_boxes_dict, page, angle, mask in zip(
-                    boxes_per_page, pages, origin_page_orientations, origin_page_shapes
-                )
-            ]
 
         out = self.doc_builder(
+            pages,
             boxes_per_page,
             text_preds_per_page,
             origin_page_shapes,  # type: ignore[arg-type]
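Both backends also gain a hook stage: every callable registered on the predictor is applied to the localization dict after padding removal. A hook is anything that maps {class_name: [boxes per page]} to a dict of the same shape. A sketch of a custom hook, assuming a straight-box (N, 4) relative-coordinate layout per page (the layout and the area threshold are illustrative assumptions, not part of the diff):

from typing import Dict, List

import numpy as np

def drop_small_boxes(loc_preds: Dict[str, List[np.ndarray]]) -> Dict[str, List[np.ndarray]]:
    # Assumed box layout: (N, 4) relative (xmin, ymin, xmax, ymax) per page
    filtered: Dict[str, List[np.ndarray]] = {}
    for class_name, pages in loc_preds.items():
        filtered[class_name] = [
            boxes[(boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) > 1e-4]
            for boxes in pages
        ]
    return filtered

# The predictor then runs it via the loop added above:
# for hook in self.hooks:
#     dict_loc_preds = hook(dict_loc_preds)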
doctr/models/modules/__init__.py CHANGED
doctr/models/modules/layers/pytorch.py ADDED
@@ -0,0 +1,166 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+__all__ = ["FASTConvLayer"]
+
+
+class FASTConvLayer(nn.Module):
+    """Convolutional layer used in the TextNet and FAST architectures"""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, Tuple[int, int]],
+        stride: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = False,
+    ) -> None:
+        super().__init__()
+
+        self.groups = groups
+        self.in_channels = in_channels
+        self.converted_ks = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
+
+        self.hor_conv, self.hor_bn = None, None
+        self.ver_conv, self.ver_bn = None, None
+
+        padding = (int(((self.converted_ks[0] - 1) * dilation) / 2), int(((self.converted_ks[1] - 1) * dilation) / 2))
+
+        self.activation = nn.ReLU(inplace=True)
+        self.conv = nn.Conv2d(
+            in_channels,
+            out_channels,
+            kernel_size=self.converted_ks,
+            stride=stride,
+            padding=padding,
+            dilation=dilation,
+            groups=groups,
+            bias=bias,
+        )
+
+        self.bn = nn.BatchNorm2d(out_channels)
+
+        if self.converted_ks[1] != 1:
+            self.ver_conv = nn.Conv2d(
+                in_channels,
+                out_channels,
+                kernel_size=(self.converted_ks[0], 1),
+                padding=(int(((self.converted_ks[0] - 1) * dilation) / 2), 0),
+                stride=stride,
+                dilation=dilation,
+                groups=groups,
+                bias=bias,
+            )
+            self.ver_bn = nn.BatchNorm2d(out_channels)
+
+        if self.converted_ks[0] != 1:
+            self.hor_conv = nn.Conv2d(
+                in_channels,
+                out_channels,
+                kernel_size=(1, self.converted_ks[1]),
+                padding=(0, int(((self.converted_ks[1] - 1) * dilation) / 2)),
+                stride=stride,
+                dilation=dilation,
+                groups=groups,
+                bias=bias,
+            )
+            self.hor_bn = nn.BatchNorm2d(out_channels)
+
+        self.rbr_identity = nn.BatchNorm2d(in_channels) if out_channels == in_channels and stride == 1 else None
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if hasattr(self, "fused_conv"):
+            return self.activation(self.fused_conv(x))
+
+        main_outputs = self.bn(self.conv(x))
+        vertical_outputs = self.ver_bn(self.ver_conv(x)) if self.ver_conv is not None and self.ver_bn is not None else 0
+        horizontal_outputs = (
+            self.hor_bn(self.hor_conv(x)) if self.hor_bn is not None and self.hor_conv is not None else 0
+        )
+        id_out = self.rbr_identity(x) if self.rbr_identity is not None and self.ver_bn is not None else 0
+
+        return self.activation(main_outputs + vertical_outputs + horizontal_outputs + id_out)
+
+    # The following logic is used to reparametrize the layer
+    # Borrowed from: https://github.com/czczup/FAST/blob/main/models/utils/nas_utils.py
+    def _identity_to_conv(
+        self, identity: Union[nn.BatchNorm2d, None]
+    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
+        if identity is None or identity.running_var is None:
+            return 0, 0
+        if not hasattr(self, "id_tensor"):
+            input_dim = self.in_channels // self.groups
+            kernel_value = np.zeros((self.in_channels, input_dim, 1, 1), dtype=np.float32)
+            for i in range(self.in_channels):
+                kernel_value[i, i % input_dim, 0, 0] = 1
+            id_tensor = torch.from_numpy(kernel_value).to(identity.weight.device)
+            self.id_tensor = self._pad_to_mxn_tensor(id_tensor)
+        kernel = self.id_tensor
+        std = (identity.running_var + identity.eps).sqrt()  # type: ignore[attr-defined]
+        t = (identity.weight / std).reshape(-1, 1, 1, 1)
+        return kernel * t, identity.bias - identity.running_mean * identity.weight / std
+
+    def _fuse_bn_tensor(self, conv: nn.Conv2d, bn: nn.BatchNorm2d) -> Tuple[torch.Tensor, torch.Tensor]:
+        kernel = conv.weight
+        kernel = self._pad_to_mxn_tensor(kernel)
+        std = (bn.running_var + bn.eps).sqrt()  # type: ignore
+        t = (bn.weight / std).reshape(-1, 1, 1, 1)
+        return kernel * t, bn.bias - bn.running_mean * bn.weight / std
+
+    def _get_equivalent_kernel_bias(self) -> Tuple[torch.Tensor, torch.Tensor]:
+        kernel_mxn, bias_mxn = self._fuse_bn_tensor(self.conv, self.bn)
+        if self.ver_conv is not None:
+            kernel_mx1, bias_mx1 = self._fuse_bn_tensor(self.ver_conv, self.ver_bn)  # type: ignore[arg-type]
+        else:
+            kernel_mx1, bias_mx1 = 0, 0  # type: ignore[assignment]
+        if self.hor_conv is not None:
+            kernel_1xn, bias_1xn = self._fuse_bn_tensor(self.hor_conv, self.hor_bn)  # type: ignore[arg-type]
+        else:
+            kernel_1xn, bias_1xn = 0, 0  # type: ignore[assignment]
+        kernel_id, bias_id = self._identity_to_conv(self.rbr_identity)
+        kernel_mxn = kernel_mxn + kernel_mx1 + kernel_1xn + kernel_id
+        bias_mxn = bias_mxn + bias_mx1 + bias_1xn + bias_id
+        return kernel_mxn, bias_mxn
+
+    def _pad_to_mxn_tensor(self, kernel: torch.Tensor) -> torch.Tensor:
+        kernel_height, kernel_width = self.converted_ks
+        height, width = kernel.shape[2:]
+        pad_left_right = (kernel_width - width) // 2
+        pad_top_down = (kernel_height - height) // 2
+        return torch.nn.functional.pad(kernel, [pad_left_right, pad_left_right, pad_top_down, pad_top_down], value=0)
+
+    def reparameterize_layer(self):
+        if hasattr(self, "fused_conv"):
+            return
+        kernel, bias = self._get_equivalent_kernel_bias()
+        self.fused_conv = nn.Conv2d(
+            in_channels=self.conv.in_channels,
+            out_channels=self.conv.out_channels,
+            kernel_size=self.conv.kernel_size,  # type: ignore[arg-type]
+            stride=self.conv.stride,  # type: ignore[arg-type]
+            padding=self.conv.padding,  # type: ignore[arg-type]
+            dilation=self.conv.dilation,  # type: ignore[arg-type]
+            groups=self.conv.groups,
+            bias=True,
+        )
+        self.fused_conv.weight.data = kernel
+        self.fused_conv.bias.data = bias  # type: ignore[union-attr]
+        self.deploy = True
+        for para in self.parameters():
+            para.detach_()
+        for attr in ["conv", "bn", "ver_conv", "ver_bn", "hor_conv", "hor_bn"]:
+            if hasattr(self, attr):
+                self.__delattr__(attr)
+
+        if hasattr(self, "rbr_identity"):
+            self.__delattr__("rbr_identity")
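FASTConvLayer follows the RepVGG-style reparameterization recipe: at train time it sums a k x k, a k x 1 and a 1 x k convolution plus an identity branch, each with its own batch norm, and reparameterize_layer() folds every conv+BN pair into a single fused_conv (kernel scaled by gamma / sqrt(var + eps), bias shifted to beta - mean * gamma / sqrt(var + eps)). A usage sketch, assuming the backend-specific import path shipped in this wheel; in eval mode the fused layer should reproduce the multi-branch output up to numerical tolerance:

import torch

from doctr.models.modules.layers.pytorch import FASTConvLayer

layer = FASTConvLayer(in_channels=8, out_channels=8, kernel_size=3).eval()
x = torch.rand(1, 8, 32, 32)

with torch.inference_mode():
    y_branched = layer(x)

layer.reparameterize_layer()  # folds all four branches into a single `fused_conv`

with torch.inference_mode():
    y_fused = layer(x)

print(torch.allclose(y_branched, y_fused, atol=1e-5))  # expected: True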
doctr/models/modules/layers/tensorflow.py ADDED
@@ -0,0 +1,175 @@
+# Copyright (C) 2021-2024, Mindee.
+
+# This program is licensed under the Apache License 2.0.
+# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
+
+from typing import Any, Tuple, Union
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.keras import layers
+
+from doctr.utils.repr import NestedObject
+
+__all__ = ["FASTConvLayer"]
+
+
+class FASTConvLayer(layers.Layer, NestedObject):
+    """Convolutional layer used in the TextNet and FAST architectures"""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, Tuple[int, int]],
+        stride: int = 1,
+        dilation: int = 1,
+        groups: int = 1,
+        bias: bool = False,
+    ) -> None:
+        super().__init__()
+
+        self.groups = groups
+        self.in_channels = in_channels
+        self.converted_ks = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
+
+        self.hor_conv, self.hor_bn = None, None
+        self.ver_conv, self.ver_bn = None, None
+
+        padding = ((self.converted_ks[0] - 1) * dilation // 2, (self.converted_ks[1] - 1) * dilation // 2)
+
+        self.activation = layers.ReLU()
+        self.conv_pad = layers.ZeroPadding2D(padding=padding)
+
+        self.conv = layers.Conv2D(
+            filters=out_channels,
+            kernel_size=self.converted_ks,
+            strides=stride,
+            dilation_rate=dilation,
+            groups=groups,
+            use_bias=bias,
+        )
+
+        self.bn = layers.BatchNormalization()
+
+        if self.converted_ks[1] != 1:
+            self.ver_pad = layers.ZeroPadding2D(
+                padding=(int(((self.converted_ks[0] - 1) * dilation) / 2), 0),
+            )
+            self.ver_conv = layers.Conv2D(
+                filters=out_channels,
+                kernel_size=(self.converted_ks[0], 1),
+                strides=stride,
+                dilation_rate=dilation,
+                groups=groups,
+                use_bias=bias,
+            )
+            self.ver_bn = layers.BatchNormalization()
+
+        if self.converted_ks[0] != 1:
+            self.hor_pad = layers.ZeroPadding2D(
+                padding=(0, int(((self.converted_ks[1] - 1) * dilation) / 2)),
+            )
+            self.hor_conv = layers.Conv2D(
+                filters=out_channels,
+                kernel_size=(1, self.converted_ks[1]),
+                strides=stride,
+                dilation_rate=dilation,
+                groups=groups,
+                use_bias=bias,
+            )
+            self.hor_bn = layers.BatchNormalization()
+
+        self.rbr_identity = layers.BatchNormalization() if out_channels == in_channels and stride == 1 else None
+
+    def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
+        if hasattr(self, "fused_conv"):
+            return self.activation(self.fused_conv(self.conv_pad(x, **kwargs), **kwargs))
+
+        main_outputs = self.bn(self.conv(self.conv_pad(x, **kwargs), **kwargs), **kwargs)
+        vertical_outputs = (
+            self.ver_bn(self.ver_conv(self.ver_pad(x, **kwargs), **kwargs), **kwargs)
+            if self.ver_conv is not None and self.ver_bn is not None
+            else 0
+        )
+        horizontal_outputs = (
+            self.hor_bn(self.hor_conv(self.hor_pad(x, **kwargs), **kwargs), **kwargs)
+            if self.hor_bn is not None and self.hor_conv is not None
+            else 0
+        )
+        id_out = self.rbr_identity(x, **kwargs) if self.rbr_identity is not None and self.ver_bn is not None else 0
+
+        return self.activation(main_outputs + vertical_outputs + horizontal_outputs + id_out)
+
+    # The following logic is used to reparametrize the layer
+    # Adapted from: https://github.com/mindee/doctr/blob/main/doctr/models/modules/layers/pytorch.py
+    def _identity_to_conv(
+        self, identity: layers.BatchNormalization
+    ) -> Union[Tuple[tf.Tensor, tf.Tensor], Tuple[int, int]]:
+        if identity is None or not hasattr(identity, "moving_mean") or not hasattr(identity, "moving_variance"):
+            return 0, 0
+        if not hasattr(self, "id_tensor"):
+            input_dim = self.in_channels // self.groups
+            kernel_value = np.zeros((self.in_channels, input_dim, 1, 1), dtype=np.float32)
+            for i in range(self.in_channels):
+                kernel_value[i, i % input_dim, 0, 0] = 1
+            id_tensor = tf.constant(kernel_value, dtype=tf.float32)
+            self.id_tensor = self._pad_to_mxn_tensor(id_tensor)
+        kernel = self.id_tensor
+        std = tf.sqrt(identity.moving_variance + identity.epsilon)
+        t = tf.reshape(identity.gamma / std, (-1, 1, 1, 1))
+        return kernel * t, identity.beta - identity.moving_mean * identity.gamma / std
+
+    def _fuse_bn_tensor(self, conv: layers.Conv2D, bn: layers.BatchNormalization) -> Tuple[tf.Tensor, tf.Tensor]:
+        kernel = conv.kernel
+        kernel = self._pad_to_mxn_tensor(kernel)
+        std = tf.sqrt(bn.moving_variance + bn.epsilon)
+        t = tf.reshape(bn.gamma / std, (1, 1, 1, -1))
+        return kernel * t, bn.beta - bn.moving_mean * bn.gamma / std
+
+    def _get_equivalent_kernel_bias(self):
+        kernel_mxn, bias_mxn = self._fuse_bn_tensor(self.conv, self.bn)
+        if self.ver_conv is not None:
+            kernel_mx1, bias_mx1 = self._fuse_bn_tensor(self.ver_conv, self.ver_bn)
+        else:
+            kernel_mx1, bias_mx1 = 0, 0
+        if self.hor_conv is not None:
+            kernel_1xn, bias_1xn = self._fuse_bn_tensor(self.hor_conv, self.hor_bn)
+        else:
+            kernel_1xn, bias_1xn = 0, 0
+        kernel_id, bias_id = self._identity_to_conv(self.rbr_identity)
+        if not isinstance(kernel_id, int):
+            kernel_id = tf.transpose(kernel_id, (2, 3, 0, 1))
+        kernel_mxn = kernel_mxn + kernel_mx1 + kernel_1xn + kernel_id
+        bias_mxn = bias_mxn + bias_mx1 + bias_1xn + bias_id
+        return kernel_mxn, bias_mxn
+
+    def _pad_to_mxn_tensor(self, kernel: tf.Tensor) -> tf.Tensor:
+        kernel_height, kernel_width = self.converted_ks
+        height, width = kernel.shape[2:]
+        pad_left_right = tf.maximum(0, (kernel_width - width) // 2)
+        pad_top_down = tf.maximum(0, (kernel_height - height) // 2)
+        return tf.pad(kernel, [[0, 0], [0, 0], [pad_top_down, pad_top_down], [pad_left_right, pad_left_right]])
+
+    def reparameterize_layer(self):
+        kernel, bias = self._get_equivalent_kernel_bias()
+        self.fused_conv = layers.Conv2D(
+            filters=self.conv.filters,
+            kernel_size=self.conv.kernel_size,
+            strides=self.conv.strides,
+            padding=self.conv.padding,
+            dilation_rate=self.conv.dilation_rate,
+            groups=self.conv.groups,
+            use_bias=True,
+        )
+        # build layer to initialize weights and biases
+        self.fused_conv.build(input_shape=(None, None, None, kernel.shape[-2]))
+        self.fused_conv.set_weights([kernel.numpy(), bias.numpy()])
+        for para in self.trainable_variables:
+            para._trainable = False
+        for attr in ["conv", "bn", "ver_conv", "ver_bn", "hor_conv", "hor_bn"]:
+            if hasattr(self, attr):
+                delattr(self, attr)
+
+        if hasattr(self, "rbr_identity"):
+            delattr(self, "rbr_identity")
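The TensorFlow port mirrors the PyTorch logic, with one layout wrinkle: Keras Conv2D kernels are stored as (kH, kW, in, out), while the identity kernel above is first built in the PyTorch-style (out, in, 1, 1) layout, hence the tf.transpose(kernel_id, (2, 3, 0, 1)) before the branches are summed. The batch-norm folding both files rely on can be sketched in isolation (every tensor below is a stand-in, not taken from the package):

import tensorflow as tf

# Stand-in conv kernel in Keras layout (kH, kW, in, out) and BN statistics
kernel = tf.random.normal((3, 3, 8, 8))
gamma, beta = tf.ones((8,)), tf.zeros((8,))
moving_mean, moving_var, eps = tf.zeros((8,)), tf.ones((8,)), 1e-3

# A conv followed by BN equals one conv with a rescaled kernel and shifted bias
std = tf.sqrt(moving_var + eps)
fused_kernel = kernel * tf.reshape(gamma / std, (1, 1, 1, -1))
fused_bias = beta - moving_mean * gamma / std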
doctr/models/modules/transformer/pytorch.py CHANGED
@@ -1,4 +1,4 @@
-# Copyright (C) 2021-2023, Mindee.
+# Copyright (C) 2021-2024, Mindee.
 
 # This program is licensed under the Apache License 2.0.
 # See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
@@ -30,14 +30,17 @@ class PositionalEncoding(nn.Module):
         self.register_buffer("pe", pe.unsqueeze(0))
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
+        """Forward pass
+
         Args:
+        ----
            x: embeddings (batch, max_len, d_model)
 
-        Returns
+        Returns
+        -------
            positional embeddings (batch, max_len, d_model)
        """
-        x = x + self.pe[:, : x.size(1)]
+        x = x + self.pe[:, : x.size(1)]
        return self.dropout(x)
 
 
@@ -45,12 +48,11 @@ def scaled_dot_product_attention(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None
 ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Scaled Dot-Product Attention"""
-
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))
    if mask is not None:
        # NOTE: to ensure the ONNX compatibility, masked_fill works only with int equal condition
-        scores = scores.masked_fill(mask == 0, float("-inf"))
-    p_attn = torch.softmax(scores, dim=-1)
+        scores = scores.masked_fill(mask == 0, float("-inf"))  # type: ignore[attr-defined]
+    p_attn = torch.softmax(scores, dim=-1)  # type: ignore[call-overload]
    return torch.matmul(p_attn, value), p_attn
 
 
@@ -121,12 +123,12 @@ class EncoderBlock(nn.Module):
         self.layer_norm_output = nn.LayerNorm(d_model, eps=1e-5)
         self.dropout = nn.Dropout(dropout)
 
-        self.attention = nn.ModuleList(
-
-        )
-        self.position_feed_forward = nn.ModuleList(
-
-        )
+        self.attention = nn.ModuleList([
+            MultiHeadAttention(num_heads, d_model, dropout) for _ in range(self.num_layers)
+        ])
+        self.position_feed_forward = nn.ModuleList([
+            PositionwiseFeedForward(d_model, dff, dropout, activation_fct) for _ in range(self.num_layers)
+        ])
 
     def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
         output = x
@@ -167,15 +169,15 @@ class Decoder(nn.Module):
         self.embed = nn.Embedding(vocab_size, d_model)
         self.positional_encoding = PositionalEncoding(d_model, dropout, maximum_position_encoding)
 
-        self.attention = nn.ModuleList(
-
-        )
-        self.source_attention = nn.ModuleList(
-
-        )
-        self.position_feed_forward = nn.ModuleList(
-
-        )
+        self.attention = nn.ModuleList([
+            MultiHeadAttention(num_heads, d_model, dropout) for _ in range(self.num_layers)
+        ])
+        self.source_attention = nn.ModuleList([
+            MultiHeadAttention(num_heads, d_model, dropout) for _ in range(self.num_layers)
+        ])
+        self.position_feed_forward = nn.ModuleList([
+            PositionwiseFeedForward(d_model, dff, dropout) for _ in range(self.num_layers)
+        ])
 
     def forward(
         self,
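Aside from the cosmetic nn.ModuleList([...]) reshuffle, the attention hunks keep the integer-mask convention called out in the ONNX note: scores are filled with -inf wherever the mask equals 0, so those positions receive exactly zero weight after the softmax. A quick standalone check of that behaviour, using the same operations as the function above:

import math

import torch

q = k = v = torch.rand(1, 4, 8)        # (batch, seq_len, d_model)
mask = torch.tensor([[[1, 1, 0, 0]]])  # integer mask, broadcast over the score rows

scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
scores = scores.masked_fill(mask == 0, float("-inf"))
p_attn = torch.softmax(scores, dim=-1)

print(p_attn[0, 0])  # the last two weights are exactly 0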