docling-ibm-models 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- docling_ibm_models/layoutmodel/layout_predictor.py +171 -0
- docling_ibm_models/tableformer/__init__.py +0 -0
- docling_ibm_models/tableformer/common.py +200 -0
- docling_ibm_models/tableformer/data_management/__init__.py +0 -0
- docling_ibm_models/tableformer/data_management/data_transformer.py +504 -0
- docling_ibm_models/tableformer/data_management/functional.py +574 -0
- docling_ibm_models/tableformer/data_management/matching_post_processor.py +1325 -0
- docling_ibm_models/tableformer/data_management/tf_cell_matcher.py +596 -0
- docling_ibm_models/tableformer/data_management/tf_dataset.py +1233 -0
- docling_ibm_models/tableformer/data_management/tf_predictor.py +1020 -0
- docling_ibm_models/tableformer/data_management/transforms.py +396 -0
- docling_ibm_models/tableformer/models/__init__.py +0 -0
- docling_ibm_models/tableformer/models/common/__init__.py +0 -0
- docling_ibm_models/tableformer/models/common/base_model.py +279 -0
- docling_ibm_models/tableformer/models/table04_rs/__init__.py +0 -0
- docling_ibm_models/tableformer/models/table04_rs/bbox_decoder_rs.py +163 -0
- docling_ibm_models/tableformer/models/table04_rs/encoder04_rs.py +72 -0
- docling_ibm_models/tableformer/models/table04_rs/tablemodel04_rs.py +324 -0
- docling_ibm_models/tableformer/models/table04_rs/transformer_rs.py +203 -0
- docling_ibm_models/tableformer/otsl.py +541 -0
- docling_ibm_models/tableformer/settings.py +90 -0
- docling_ibm_models/tableformer/test_dataset_cache.py +37 -0
- docling_ibm_models/tableformer/test_prepare_image.py +99 -0
- docling_ibm_models/tableformer/utils/__init__.py +0 -0
- docling_ibm_models/tableformer/utils/app_profiler.py +243 -0
- docling_ibm_models/tableformer/utils/torch_utils.py +216 -0
- docling_ibm_models/tableformer/utils/utils.py +376 -0
- docling_ibm_models/tableformer/utils/variance.py +175 -0
- docling_ibm_models-0.1.0.dist-info/LICENSE +21 -0
- docling_ibm_models-0.1.0.dist-info/METADATA +172 -0
- docling_ibm_models-0.1.0.dist-info/RECORD +32 -0
- docling_ibm_models-0.1.0.dist-info/WHEEL +4 -0
@@ -0,0 +1,163 @@
|
|
1
|
+
#
|
2
|
+
# Copyright IBM Corp. 2024 - 2024
|
3
|
+
# SPDX-License-Identifier: MIT
|
4
|
+
#
|
5
|
+
import logging
|
6
|
+
|
7
|
+
import torch
|
8
|
+
import torch.nn as nn
|
9
|
+
|
10
|
+
import docling_ibm_models.tableformer.settings as s
|
11
|
+
import docling_ibm_models.tableformer.utils.utils as u
|
12
|
+
|
13
|
+
# from scipy.optimize import linear_sum_assignment
|
14
|
+
|
15
|
+
# Level handed to ``s.get_custom_logger`` by the ``_log`` helpers in this module.
LOG_LEVEL: int = logging.INFO
|
16
|
+
|
17
|
+
|
18
|
+
class CellAttention(nn.Module):
    """
    Additive attention over the pixels of an encoded image.

    Every pixel score is conditioned on two per-cell vectors: the tag decoder
    output for the cell and a second ("language") state.
    """

    def __init__(self, encoder_dim, tag_decoder_dim, language_dim, attention_dim):
        """
        :param encoder_dim: feature size of encoded images
        :param tag_decoder_dim: feature size of the tag decoder output
        :param language_dim: feature size of the language (second conditioning) state
        :param attention_dim: size of the attention network
        """
        super().__init__()
        # One projection per input into the shared attention space. The
        # attribute names are state_dict keys and the creation order is kept.
        self._encoder_att = nn.Linear(encoder_dim, attention_dim)
        self._tag_decoder_att = nn.Linear(tag_decoder_dim, attention_dim)
        self._language_att = nn.Linear(language_dim, attention_dim)
        # Reduces every attention vector to one scalar score
        self._full_att = nn.Linear(attention_dim, 1)
        self._relu = nn.ReLU()
        # Normalises the scores along dim 1 (the pixel axis)
        self._softmax = nn.Softmax(dim=1)

    def _log(self):
        # Setup a custom logger
        return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL)

    def forward(self, encoder_out, decoder_hidden, language_out):
        """
        Forward propagation.

        :param encoder_out: encoded image, tensor of shape (1, num_pixels, encoder_dim)
        :param decoder_hidden: tag decoder output, tensor of shape (num_cells, tag_decoder_dim)
        :param language_out: language model output, tensor of shape (num_cells, language_dim)
        :return: tuple (attention weighted encoding of shape (num_cells, encoder_dim),
                 attention weights of shape (num_cells, num_pixels))
        """
        pixel_term = self._encoder_att(encoder_out)  # (1, num_pixels, attention_dim)
        # unsqueeze(1) adds a pixel axis so the per-cell terms broadcast over pixels
        tag_term = self._tag_decoder_att(decoder_hidden).unsqueeze(1)  # (num_cells, 1, attention_dim)
        language_term = self._language_att(language_out).unsqueeze(1)  # (num_cells, 1, attention_dim)

        hidden = self._relu(pixel_term + tag_term + language_term)
        scores = self._full_att(hidden).squeeze(2)  # (num_cells, num_pixels)
        alpha = self._softmax(scores)

        # Per-cell weighted sum of the pixel features
        weighted = encoder_out * alpha.unsqueeze(2)  # (num_cells, num_pixels, encoder_dim)
        return weighted.sum(dim=1), alpha
|
66
|
+
|
67
|
+
|
68
|
+
class BBoxDecoder(nn.Module):
    """
    Predict one bounding box and one class-logit vector per table cell.

    For every cell, the tag decoder's hidden state for that cell attends over
    the encoded image. The gated attention output, multiplied with an
    image-level hidden state, feeds an MLP box head and a linear class head.
    """

    def __init__(
        self,
        device,
        attention_dim,
        embed_dim,
        tag_decoder_dim,
        decoder_dim,
        num_classes,
        encoder_dim=512,
        dropout=0.5,
        cnn_layer_stride=1,
    ):
        """
        :param device: device the model is meant to run on (only stored)
        :param attention_dim: size of attention network
        :param embed_dim: embedding size (only stored)
        :param tag_decoder_dim: feature size of the tag decoder output
        :param decoder_dim: size of the decoder hidden state; also the input size
            of the class and bbox heads
        :param num_classes: number of bbox classes; the class head emits
            num_classes + 1 logits
        :param encoder_dim: feature size of encoded images
        :param dropout: dropout probability
        :param cnn_layer_stride: stride of the ``u.resnet_block`` applied to the
            encoder output before attention; ``None`` disables that block
        """
        super().__init__()
        self._device = device
        self._encoder_dim = encoder_dim
        self._attention_dim = attention_dim
        self._embed_dim = embed_dim
        self._decoder_dim = decoder_dim
        self._dropout = dropout
        self._num_classes = num_classes

        if cnn_layer_stride is not None:
            self._input_filter = u.resnet_block(stride=cnn_layer_stride)
        # attention network
        self._attention = CellAttention(
            encoder_dim, tag_decoder_dim, decoder_dim, attention_dim
        )
        # projects the mean image feature to the initial hidden state
        self._init_h = nn.Linear(encoder_dim, decoder_dim)

        # linear layer to create a sigmoid-activated gate
        self._f_beta = nn.Linear(decoder_dim, encoder_dim)
        self._sigmoid = nn.Sigmoid()
        # Replaces the float stored above with a module; not applied in ``inference``
        self._dropout = nn.Dropout(p=self._dropout)
        # The heads consume ``h`` from ``inference``, whose size is decoder_dim
        # (inference requires encoder_dim == decoder_dim). This used to be a
        # hard-coded 512, which only worked for that one size.
        self._class_embed = nn.Linear(decoder_dim, self._num_classes + 1)
        self._bbox_embed = u.MLP(decoder_dim, 256, 4, 3)

    def _init_hidden_state(self, encoder_out, batch_size):
        """Project the pixel-mean of ``encoder_out`` and expand it to ``batch_size`` rows."""
        mean_encoder_out = encoder_out.mean(dim=1)
        h = self._init_h(mean_encoder_out).expand(batch_size, -1)
        return h

    def _log(self):
        # Setup a custom logger
        return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL)

    def inference(self, encoder_out, tag_H):
        """
        Predict a bbox and class logits for every cell (greedy; no beam search).

        :param encoder_out: image features laid out channels-last
            (batch, H, W, channels); they are flattened to
            (1, num_pixels, encoder_dim), so only one image is supported.
        :param tag_H: sequence of per-cell tag-decoder hidden states; each
            element is passed to the attention as ``decoder_hidden``
            (presumably shaped (1, tag_decoder_dim) - TODO confirm).
        :return: ``(predictions_classes, predictions_bboxes)``. For a non-empty
            ``tag_H`` these are tensors built by stacking row 0 of each per-cell
            output; boxes are passed through a sigmoid. For an empty ``tag_H``
            both are empty lists.
        """
        if hasattr(self, "_input_filter"):
            encoder_out = self._input_filter(encoder_out.permute(0, 3, 1, 2)).permute(
                0, 2, 3, 1
            )

        encoder_dim = encoder_out.size(3)

        # Flatten encoding (1, num_pixels, encoder_dim)
        encoder_out = encoder_out.view(1, -1, encoder_dim)

        if len(tag_H) == 0:
            return [], []

        # The initial hidden state and its gate depend only on the image, not on
        # the cell, so compute them once instead of once per cell.
        h_init = self._init_hidden_state(encoder_out, 1)
        gate = self._sigmoid(self._f_beta(h_init))

        predictions_bboxes = []
        predictions_classes = []
        for cell_tag_H in tag_H:
            awe, _ = self._attention(encoder_out, cell_tag_H, h_init)
            awe = gate * awe
            h = awe * h_init

            predictions_bboxes.append(self._bbox_embed(h).sigmoid())
            predictions_classes.append(self._class_embed(h))

        predictions_bboxes = torch.stack([x[0] for x in predictions_bboxes])
        predictions_classes = torch.stack([x[0] for x in predictions_classes])

        return predictions_classes, predictions_bboxes
|
@@ -0,0 +1,72 @@
|
|
1
|
+
#
|
2
|
+
# Copyright IBM Corp. 2024 - 2024
|
3
|
+
# SPDX-License-Identifier: MIT
|
4
|
+
#
|
5
|
+
import logging
|
6
|
+
|
7
|
+
import torch.nn as nn
|
8
|
+
import torchvision
|
9
|
+
|
10
|
+
import docling_ibm_models.tableformer.settings as s
|
11
|
+
|
12
|
+
# Level handed to ``s.get_custom_logger`` by ``Encoder04._log``.
LOG_LEVEL: int = logging.INFO
|
13
|
+
# LOG_LEVEL = logging.DEBUG
|
14
|
+
|
15
|
+
|
16
|
+
class Encoder04(nn.Module):
    """
    Image encoder built from a truncated ResNet-18.

    The last three children of torchvision's ResNet-18 (``layer4``, ``avgpool``
    and ``fc``) are dropped. The remaining feature map is adaptively pooled to a
    square grid and returned channels-last.
    """

    def __init__(self, enc_image_size, enc_dim=512):
        r"""
        Parameters
        ----------
        enc_image_size : int
            Assuming that the encoded image is a square, this is the length of the image side
        enc_dim : int
            Value reported by ``get_encoder_dim``. It does not configure any layer.
        """

        super().__init__()
        self.enc_image_size = enc_image_size
        self._encoder_dim = enc_dim

        # ``resnet18()`` without arguments builds a randomly initialised network in
        # every torchvision release. It avoids the ``pretrained`` keyword, which has
        # been deprecated since torchvision 0.13.
        resnet = torchvision.models.resnet18()
        modules = list(resnet.children())[:-3]

        # Attribute names are state_dict keys; keep them stable.
        self._resnet = nn.Sequential(*modules)
        self._adaptive_pool = nn.AdaptiveAvgPool2d(
            (self.enc_image_size, self.enc_image_size)
        )

    def _log(self):
        # Setup a custom logger
        return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL)

    def get_encoder_dim(self):
        """Return ``enc_dim`` as given to the constructor (not the backbone's channel count)."""
        return self._encoder_dim

    def forward(self, images):
        """
        Forward propagation.

        The channel count of the output is fixed by the truncated ResNet-18
        (256 channels after ``layer3``), not by ``enc_dim``.

        Parameters
        ----------
        images : tensor (batch_size, image_channels, resized_image, resized_image)
            images input

        Returns
        -------
        tensor : (batch_size, enc_image_size, enc_image_size, 256)
            encoded images
        """
        out = self._resnet(images)  # e.g. (batch_size, 256, 28, 28) for 448x448 input
        self._log().debug("forward: resnet out: {}".format(out.size()))
        out = self._adaptive_pool(out)
        # channels-first -> channels-last: (batch_size, enc_image_size, enc_image_size, 256)
        out = out.permute(0, 2, 3, 1)

        self._log().debug("enc forward: final out: {}".format(out.size()))

        return out
|
@@ -0,0 +1,324 @@
|
|
1
|
+
#
|
2
|
+
# Copyright IBM Corp. 2024 - 2024
|
3
|
+
# SPDX-License-Identifier: MIT
|
4
|
+
#
|
5
|
+
import logging
|
6
|
+
|
7
|
+
import torch
|
8
|
+
import torch.nn as nn
|
9
|
+
|
10
|
+
import docling_ibm_models.tableformer.settings as s
|
11
|
+
from docling_ibm_models.tableformer.models.common.base_model import BaseModel
|
12
|
+
from docling_ibm_models.tableformer.models.table04_rs.bbox_decoder_rs import BBoxDecoder
|
13
|
+
from docling_ibm_models.tableformer.models.table04_rs.encoder04_rs import Encoder04
|
14
|
+
from docling_ibm_models.tableformer.models.table04_rs.transformer_rs import (
|
15
|
+
Tag_Transformer,
|
16
|
+
)
|
17
|
+
from docling_ibm_models.tableformer.utils.app_profiler import AggProfiler
|
18
|
+
|
19
|
+
# Level handed to ``s.get_custom_logger`` by ``TableModel04_rs._log``.
LOG_LEVEL: int = logging.WARN
|
20
|
+
# LOG_LEVEL = logging.INFO
|
21
|
+
# LOG_LEVEL = logging.DEBUG
|
22
|
+
|
23
|
+
|
24
|
+
class TableModel04_rs(BaseModel, nn.Module):
    r"""
    TableNet04Model encoder, dual-decoder model with OTSL+ support.

    An image encoder feeds a transformer tag decoder that emits the table
    structure as OTSL tags. A bbox decoder then predicts one box and one class
    for each cell from the tag decoder's hidden states.
    """

    def __init__(self, config, init_data, purpose, device):
        """
        :param config: configuration dict; its "model", "train" and "predict"
            sections are read
        :param init_data: dict providing init_data["word_map"]["word_map_tag"],
            the tag vocabulary
        :param purpose: not used by this class
        :param device: device the sub-modules are moved to
        """
        super().__init__(config, init_data, device)

        self._prof = config["predict"].get("profiling", False)
        self._device = device
        # Extract the word_map from the init_data
        word_map = init_data["word_map"]

        # Encoder
        self._enc_image_size = config["model"]["enc_image_size"]
        self._encoder_dim = config["model"]["hidden_dim"]
        self._encoder = Encoder04(self._enc_image_size, self._encoder_dim).to(device)

        tag_vocab_size = len(word_map["word_map_tag"])

        # Vocabulary ids of the cell tags that exist in this word map
        td_encode = [
            word_map["word_map_tag"][t]
            for t in ["ecel", "fcel", "ched", "rhed", "srow"]
            if t in word_map["word_map_tag"]
        ]
        self._log().debug("td_encode length: {}".format(len(td_encode)))
        self._log().debug("td_encode: {}".format(td_encode))

        self._tag_attention_dim = config["model"]["tag_attention_dim"]
        self._tag_embed_dim = config["model"]["tag_embed_dim"]
        self._tag_decoder_dim = config["model"]["tag_decoder_dim"]
        self._decoder_dim = config["model"]["hidden_dim"]
        self._dropout = config["model"]["dropout"]

        self._bbox = config["train"]["bbox"]
        self._bbox_attention_dim = config["model"]["bbox_attention_dim"]
        self._bbox_embed_dim = config["model"]["bbox_embed_dim"]
        self._bbox_decoder_dim = config["model"]["hidden_dim"]

        self._enc_layers = config["model"]["enc_layers"]
        self._dec_layers = config["model"]["dec_layers"]
        self._n_heads = config["model"]["nheads"]

        self._num_classes = config["model"]["bbox_classes"]

        # Upper bound on the number of tags generated by ``predict``
        self._max_pred_len = config["predict"]["max_steps"]

        # Sub-module attribute names are state_dict keys; keep them stable.
        self._tag_transformer = Tag_Transformer(
            device,
            tag_vocab_size,
            td_encode,
            self._decoder_dim,
            self._enc_layers,
            self._dec_layers,
            self._enc_image_size,
            n_heads=self._n_heads,
        ).to(device)

        self._bbox_decoder = BBoxDecoder(
            device,
            self._bbox_attention_dim,
            self._bbox_embed_dim,
            self._tag_decoder_dim,
            self._bbox_decoder_dim,
            self._num_classes,
            self._encoder_dim,
            self._dropout,
        ).to(device)

    def _log(self):
        # Setup a custom logger
        return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL)

    def mergebboxes(self, bbox1, bbox2):
        """
        Merge the first (``bbox1``) and last (``bbox2``) box of a horizontal span.

        Boxes are [cx, cy, w, h]. The merged box starts at the left edge of
        ``bbox1`` and ends at the right edge of ``bbox2``. Its height runs from
        the top of ``bbox1`` to the bottom of ``bbox2``, while its top is the
        higher of the two tops.

        NOTE(review): when ``bbox2`` starts above ``bbox1``, height and top come
        from different boxes, so the merged box stops short of ``bbox2``'s
        bottom. Kept as-is to preserve model outputs.

        :return: new CPU tensor [cx, cy, w, h]
        """
        new_w = (bbox2[0] + bbox2[2] / 2) - (bbox1[0] - bbox1[2] / 2)
        new_h = (bbox2[1] + bbox2[3] / 2) - (bbox1[1] - bbox1[3] / 2)

        new_left = bbox1[0] - bbox1[2] / 2
        new_top = min((bbox2[1] - bbox2[3] / 2), (bbox1[1] - bbox1[3] / 2))

        new_cx = new_left + new_w / 2
        new_cy = new_top + new_h / 2

        bboxm = torch.tensor([new_cx, new_cy, new_w, new_h])
        return bboxm

    def _merge_span_bboxes(self, outputs_class, outputs_coord, bboxes_to_merge):
        """
        Collapse every horizontal span's first and last bbox into one box.

        :param outputs_class: per-bbox class logits (tensor, or an empty list)
        :param outputs_coord: per-bbox [cx, cy, w, h] coords (tensor, or an empty list)
        :param bboxes_to_merge: maps a span's first bbox index to its last bbox
            index; -1 means the span was never closed, and that box is kept unmerged
        :return: (classes, coords), stacked into tensors when non-empty,
            otherwise empty lists
        """
        merged_class = []
        merged_coord = []
        boxes_to_skip = set()

        for box_ind in range(len(outputs_coord)):
            box1 = outputs_coord[box_ind].to(self._device)
            cls1 = outputs_class[box_ind].to(self._device)
            end_ind = bboxes_to_merge.get(box_ind)
            # A negative end index marks an unclosed span. Indexing with it would
            # wrongly merge with the last bbox of the table.
            if end_ind is not None and end_ind >= 0:
                box2 = outputs_coord[end_ind].to(self._device)
                boxes_to_skip.add(end_ind)
                merged_coord.append(self.mergebboxes(box1, box2).to(self._device))
                merged_class.append(cls1)
            elif box_ind not in boxes_to_skip:
                merged_coord.append(box1)
                merged_class.append(cls1)

        if len(merged_coord) > 0:
            merged_coord = torch.stack(merged_coord)
        if len(merged_class) > 0:
            merged_class = torch.stack(merged_class)
        return merged_class, merged_coord

    def predict(self, imgs, max_steps, k, return_attention=False):
        r"""
        Inference.
        The input image must be preprocessed and transformed.

        Parameters
        ----------
        imgs : tensor FloatTensor - torch.Size([1, 3, 448, 448])
            Input image for the inference
        max_steps : int
            Not used; decoding is bounded by config["predict"]["max_steps"]
        k : int
            Not used
        return_attention : bool
            Not used

        Returns
        -------
        seq : list
            Predictions for the tags as indices over the word_map, starting with
            ``<start>`` (and ending with ``<end>`` if it was generated)
        outputs_class : tensor(x, num_classes + 1) or list or None
            Class logits of the predicted bboxes after merging horizontal spans.
            x is the number of bboxes. Empty list when no cell was predicted,
            ``None`` when bbox prediction is disabled.
        outputs_coord : tensor(x, 4) or list or None
            Coords of predicted bboxes. x is the number of bboxes. Each bbox is
            in [cxcywh] format. Empty list / ``None`` as for ``outputs_class``.
        """
        AggProfiler().begin("predict_total", self._prof)

        # Invoke encoder
        # NOTE(review): only the tag transformer is switched to eval mode here;
        # the encoder / bbox decoder modes are left to the caller.
        self._tag_transformer.eval()
        AggProfiler().begin("model_encoder", self._prof)
        enc_out = self._encoder(imgs)
        AggProfiler().end("model_encoder", self._prof)

        word_map = self._init_data["word_map"]["word_map_tag"]
        n_heads = self._tag_transformer._n_heads
        # [1, 28, 28, 512]
        encoder_out = self._tag_transformer._input_filter(
            enc_out.permute(0, 3, 1, 2)
        ).permute(0, 2, 3, 1)

        batch_size = encoder_out.size(0)
        encoder_dim = encoder_out.size(-1)
        enc_inputs = encoder_out.view(batch_size, -1, encoder_dim).to(self._device)
        enc_inputs = enc_inputs.permute(1, 0, 2)
        positions = enc_inputs.shape[0]

        # All-False boolean mask: no encoder position is masked out
        encoder_mask = torch.zeros(
            (batch_size * n_heads, positions, positions),
            dtype=torch.bool,
            device=self._device,
        )

        # Invoking tag transformer encoder before the loop to save time
        AggProfiler().begin("model_tag_transformer_encoder", self._prof)
        encoder_out = self._tag_transformer._encoder(enc_inputs, mask=encoder_mask)
        AggProfiler().end("model_tag_transformer_encoder", self._prof)

        # Tag ids used by the decoding loop, looked up once
        end_tag = word_map["<end>"]
        lcel_tag = word_map["lcel"]
        fcel_tag = word_map["fcel"]
        ucel_tag = word_map["ucel"]
        xcel_tag = word_map["xcel"]
        # Tags that receive a bbox (unless the previous tag was nl/ucel/xcel)
        bbox_tags = {
            word_map["fcel"],
            word_map["ecel"],
            word_map["ched"],
            word_map["rhed"],
            word_map["srow"],
            word_map["nl"],
            word_map["ucel"],
        }
        # Tags after which the next tag gets no bbox of its own
        skip_after_tags = {word_map["nl"], word_map["ucel"], word_map["xcel"]}

        decoded_tags = (
            torch.LongTensor([word_map["<start>"]]).to(self._device).unsqueeze(1)
        )
        output_tags = []
        cache = None
        tag_H_buf = []

        skip_next_tag = True
        prev_tag_ucel = False
        # NOTE(review): line_num is never incremented, so the "first line"
        # xcel -> lcel correction below currently applies to every row. Kept
        # as-is because changing it alters the predicted structure.
        line_num = 0

        # Populate bboxes_to_merge, indexes of first lcel, and last cell in a span
        first_lcel = True
        bboxes_to_merge = {}
        cur_bbox_ind = -1
        bbox_ind = 0

        while len(output_tags) < self._max_pred_len:
            decoded_embedding = self._tag_transformer._embedding(decoded_tags)
            decoded_embedding = self._tag_transformer._positional_encoding(
                decoded_embedding
            )
            AggProfiler().begin("model_tag_transformer_decoder", self._prof)
            decoded, cache = self._tag_transformer._decoder(
                decoded_embedding,
                encoder_out,
                cache,
                memory_key_padding_mask=encoder_mask,
            )
            AggProfiler().end("model_tag_transformer_decoder", self._prof)
            # Grab last feature to produce token
            AggProfiler().begin("model_tag_transformer_fc", self._prof)
            logits = self._tag_transformer._fc(decoded[-1, :, :])  # 1, vocab_size
            AggProfiler().end("model_tag_transformer_fc", self._prof)
            new_tag = logits.argmax(1).item()

            # STRUCTURE ERROR CORRECTION
            # Correction for first line xcel...
            if line_num == 0 and new_tag == xcel_tag:
                new_tag = lcel_tag

            # Correction for ucel, lcel sequence...
            if prev_tag_ucel and new_tag == lcel_tag:
                new_tag = fcel_tag

            output_tags.append(new_tag)
            decoded_tags = torch.cat(
                [
                    decoded_tags,
                    torch.LongTensor([new_tag]).unsqueeze(1).to(self._device),
                ],
                dim=0,
            )  # current_output_len, 1

            # End of generation
            if new_tag == end_tag:
                break

            # BBOX PREDICTION

            # MAKE SURE TO SYNC NUMBER OF CELLS WITH NUMBER OF BBOXes
            if not skip_next_tag and new_tag in bbox_tags:
                # GENERATE BBOX HERE TOO (All other cases)...
                tag_H_buf.append(decoded[-1, :, :])
                if not first_lcel:
                    # Mark end index for horizontal cell bbox merge
                    bboxes_to_merge[cur_bbox_ind] = bbox_ind
                bbox_ind += 1

            # Treat horizontal span bboxes...
            if new_tag != lcel_tag:
                first_lcel = True
            elif first_lcel:
                # GENERATE BBOX HERE (Beginning of horizontal span)...
                tag_H_buf.append(decoded[-1, :, :])
                first_lcel = False
                # Mark start index for cell bbox merge
                cur_bbox_ind = bbox_ind
                bboxes_to_merge[cur_bbox_ind] = -1
                bbox_ind += 1

            skip_next_tag = new_tag in skip_after_tags

            # Register ucel in sequence...
            prev_tag_ucel = new_tag == ucel_tag

        # squeeze(1) (not squeeze()) keeps a list even when only <start> is present
        seq = decoded_tags.squeeze(1).tolist()

        if self._bbox:
            AggProfiler().begin("model_bbox_decoder", self._prof)
            outputs_class, outputs_coord = self._bbox_decoder.inference(
                enc_out, tag_H_buf
            )
            AggProfiler().end("model_bbox_decoder", self._prof)

            # Merge First and Last predicted BBOX for each span, according to bboxes_to_merge
            outputs_class, outputs_coord = self._merge_span_bboxes(
                outputs_class, outputs_coord, bboxes_to_merge
            )
        else:
            outputs_class, outputs_coord = None, None

        # Do the rest of the steps...
        AggProfiler().end("predict_total", self._prof)
        num_tab_cells = seq.count(word_map["ecel"]) + seq.count(word_map["fcel"])
        num_rows = seq.count(word_map["nl"])
        self._log().info(
            "OTSL predicted table cells#: {}; rows#: {}".format(num_tab_cells, num_rows)
        )
        return seq, outputs_class, outputs_coord
|