docling_ibm_models-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. docling_ibm_models/layoutmodel/layout_predictor.py +171 -0
  2. docling_ibm_models/tableformer/__init__.py +0 -0
  3. docling_ibm_models/tableformer/common.py +200 -0
  4. docling_ibm_models/tableformer/data_management/__init__.py +0 -0
  5. docling_ibm_models/tableformer/data_management/data_transformer.py +504 -0
  6. docling_ibm_models/tableformer/data_management/functional.py +574 -0
  7. docling_ibm_models/tableformer/data_management/matching_post_processor.py +1325 -0
  8. docling_ibm_models/tableformer/data_management/tf_cell_matcher.py +596 -0
  9. docling_ibm_models/tableformer/data_management/tf_dataset.py +1233 -0
  10. docling_ibm_models/tableformer/data_management/tf_predictor.py +1020 -0
  11. docling_ibm_models/tableformer/data_management/transforms.py +396 -0
  12. docling_ibm_models/tableformer/models/__init__.py +0 -0
  13. docling_ibm_models/tableformer/models/common/__init__.py +0 -0
  14. docling_ibm_models/tableformer/models/common/base_model.py +279 -0
  15. docling_ibm_models/tableformer/models/table04_rs/__init__.py +0 -0
  16. docling_ibm_models/tableformer/models/table04_rs/bbox_decoder_rs.py +163 -0
  17. docling_ibm_models/tableformer/models/table04_rs/encoder04_rs.py +72 -0
  18. docling_ibm_models/tableformer/models/table04_rs/tablemodel04_rs.py +324 -0
  19. docling_ibm_models/tableformer/models/table04_rs/transformer_rs.py +203 -0
  20. docling_ibm_models/tableformer/otsl.py +541 -0
  21. docling_ibm_models/tableformer/settings.py +90 -0
  22. docling_ibm_models/tableformer/test_dataset_cache.py +37 -0
  23. docling_ibm_models/tableformer/test_prepare_image.py +99 -0
  24. docling_ibm_models/tableformer/utils/__init__.py +0 -0
  25. docling_ibm_models/tableformer/utils/app_profiler.py +243 -0
  26. docling_ibm_models/tableformer/utils/torch_utils.py +216 -0
  27. docling_ibm_models/tableformer/utils/utils.py +376 -0
  28. docling_ibm_models/tableformer/utils/variance.py +175 -0
  29. docling_ibm_models-0.1.0.dist-info/LICENSE +21 -0
  30. docling_ibm_models-0.1.0.dist-info/METADATA +172 -0
  31. docling_ibm_models-0.1.0.dist-info/RECORD +32 -0
  32. docling_ibm_models-0.1.0.dist-info/WHEEL +4 -0
docling_ibm_models/tableformer/models/table04_rs/bbox_decoder_rs.py
@@ -0,0 +1,163 @@
+ #
+ # Copyright IBM Corp. 2024 - 2024
+ # SPDX-License-Identifier: MIT
+ #
+ import logging
+
+ import torch
+ import torch.nn as nn
+
+ import docling_ibm_models.tableformer.settings as s
+ import docling_ibm_models.tableformer.utils.utils as u
+
+ # from scipy.optimize import linear_sum_assignment
+
+ LOG_LEVEL = logging.INFO
+
+
+ class CellAttention(nn.Module):
+     """
+     Attention Network.
+     """
+
+     def __init__(self, encoder_dim, tag_decoder_dim, language_dim, attention_dim):
+         """
+         :param encoder_dim: feature size of encoded images
+         :param tag_decoder_dim: size of tag decoder's RNN
+         :param language_dim: size of language model's RNN
+         :param attention_dim: size of the attention network
+         """
+         super(CellAttention, self).__init__()
+         # linear layer to transform encoded image
+         self._encoder_att = nn.Linear(encoder_dim, attention_dim)
+         # linear layer to transform tag decoder output
+         self._tag_decoder_att = nn.Linear(tag_decoder_dim, attention_dim)
+         # linear layer to transform the language model's output
+         self._language_att = nn.Linear(language_dim, attention_dim)
+         # linear layer to calculate values to be softmax-ed
+         self._full_att = nn.Linear(attention_dim, 1)
+         self._relu = nn.ReLU()
+         self._softmax = nn.Softmax(dim=1)  # softmax layer to calculate weights
+
+     def _log(self):
+         # Setup a custom logger
+         return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL)
+
+     def forward(self, encoder_out, decoder_hidden, language_out):
+         """
+         Forward propagation.
+         :param encoder_out: encoded images, a tensor of dimension (1, num_pixels, encoder_dim)
+         :param decoder_hidden: tag decoder output, a tensor of dimension [(num_cells,
+                                tag_decoder_dim)]
+         :param language_out: language model output, a tensor of dimension (num_cells,
+                              language_dim)
+         :return: attention weighted encoding, weights
+         """
+         att1 = self._encoder_att(encoder_out)  # (1, num_pixels, attention_dim)
+         att2 = self._tag_decoder_att(decoder_hidden)  # (num_cells, attention_dim)
+         att3 = self._language_att(language_out)  # (num_cells, attention_dim)
+         att = self._full_att(
+             self._relu(att1 + att2.unsqueeze(1) + att3.unsqueeze(1))
+         ).squeeze(2)
+         alpha = self._softmax(att)  # (num_cells, num_pixels)
+         # (num_cells, encoder_dim)
+         attention_weighted_encoding = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)
+         return attention_weighted_encoding, alpha
+
+
+ class BBoxDecoder(nn.Module):
+     """
+     CellDecoder generates cell content
+     """
+
+     def __init__(
+         self,
+         device,
+         attention_dim,
+         embed_dim,
+         tag_decoder_dim,
+         decoder_dim,
+         num_classes,
+         encoder_dim=512,
+         dropout=0.5,
+         cnn_layer_stride=1,
+     ):
+         """
+         :param attention_dim: size of attention network
+         :param embed_dim: embedding size
+         :param tag_decoder_dim: size of tag decoder's RNN
+         :param decoder_dim: size of decoder's RNN
+         :param num_classes: number of bbox classes
+         :param encoder_dim: feature size of encoded images
+         :param dropout: dropout
+         :param cnn_layer_stride: stride of the CNN input filter (None to disable the filter)
+         """
+         super(BBoxDecoder, self).__init__()
+         self._device = device
+         self._encoder_dim = encoder_dim
+         self._attention_dim = attention_dim
+         self._embed_dim = embed_dim
+         self._decoder_dim = decoder_dim
+         self._dropout = dropout
+         self._num_classes = num_classes
+
+         if cnn_layer_stride is not None:
+             self._input_filter = u.resnet_block(stride=cnn_layer_stride)
+         # attention network
+         self._attention = CellAttention(
+             encoder_dim, tag_decoder_dim, decoder_dim, attention_dim
+         )
+         # decoder LSTMCell
+         self._init_h = nn.Linear(encoder_dim, decoder_dim)
+
+         # linear layer to create a sigmoid-activated gate
+         self._f_beta = nn.Linear(decoder_dim, encoder_dim)
+         self._sigmoid = nn.Sigmoid()
+         self._dropout = nn.Dropout(p=self._dropout)
+         self._class_embed = nn.Linear(512, self._num_classes + 1)
+         self._bbox_embed = u.MLP(512, 256, 4, 3)
+
+     def _init_hidden_state(self, encoder_out, batch_size):
+         mean_encoder_out = encoder_out.mean(dim=1)
+         h = self._init_h(mean_encoder_out).expand(batch_size, -1)
+         return h
+
+     def _log(self):
+         # Setup a custom logger
+         return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL)
+
+     def inference(self, encoder_out, tag_H):
+         """
+         Inference on test images with beam search
+         """
+         if hasattr(self, "_input_filter"):
+             encoder_out = self._input_filter(encoder_out.permute(0, 3, 1, 2)).permute(
+                 0, 2, 3, 1
+             )
+
+         encoder_dim = encoder_out.size(3)
+
+         # Flatten encoding (1, num_pixels, encoder_dim)
+         encoder_out = encoder_out.view(1, -1, encoder_dim)
+
+         num_cells = len(tag_H)
+         predictions_bboxes = []
+         predictions_classes = []
+
+         for c_id in range(num_cells):
+             # Start decoding
+             h = self._init_hidden_state(encoder_out, 1)
+             cell_tag_H = tag_H[c_id]
+             awe, _ = self._attention(encoder_out, cell_tag_H, h)
+             gate = self._sigmoid(self._f_beta(h))
+             awe = gate * awe
+             h = awe * h
+
+             predictions_bboxes.append(self._bbox_embed(h).sigmoid())
+             predictions_classes.append(self._class_embed(h))
+         if len(predictions_bboxes) > 0:
+             predictions_bboxes = torch.stack([x[0] for x in predictions_bboxes])
+         if len(predictions_classes) > 0:
+             predictions_classes = torch.stack([x[0] for x in predictions_classes])
+
+         return predictions_classes, predictions_bboxes
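
The tensor shapes involved in CellAttention are easiest to see with dummy inputs. The sketch below is not part of the released diff: the 512/256 dimensions, the 28x28 feature grid and the 10 cells are arbitrary illustrative choices, and it assumes the wheel above is installed.

import torch
from docling_ibm_models.tableformer.models.table04_rs.bbox_decoder_rs import CellAttention

# Hypothetical sizes for illustration; the real values come from the model config.
attn = CellAttention(encoder_dim=512, tag_decoder_dim=512, language_dim=512, attention_dim=256)
encoder_out = torch.randn(1, 28 * 28, 512)   # (1, num_pixels, encoder_dim)
decoder_hidden = torch.randn(10, 512)        # (num_cells, tag_decoder_dim)
language_out = torch.randn(10, 512)          # (num_cells, language_dim)
awe, alpha = attn(encoder_out, decoder_hidden, language_out)
print(awe.shape)    # torch.Size([10, 512]) - one attention-weighted encoding per cell
print(alpha.shape)  # torch.Size([10, 784]) - attention weights over the image pixels
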
docling_ibm_models/tableformer/models/table04_rs/encoder04_rs.py
@@ -0,0 +1,72 @@
+ #
+ # Copyright IBM Corp. 2024 - 2024
+ # SPDX-License-Identifier: MIT
+ #
+ import logging
+
+ import torch.nn as nn
+ import torchvision
+
+ import docling_ibm_models.tableformer.settings as s
+
+ LOG_LEVEL = logging.INFO
+ # LOG_LEVEL = logging.DEBUG
+
+
+ class Encoder04(nn.Module):
+     """
+     Encoder based on resnet-18
+     """
+
+     def __init__(self, enc_image_size, enc_dim=512):
+         r"""
+         Parameters
+         ----------
+         enc_image_size : int
+             Assuming that the encoded image is a square, this is the length of the image side
+         """
+
+         super(Encoder04, self).__init__()
+         self.enc_image_size = enc_image_size
+         self._encoder_dim = enc_dim
+
+         resnet = torchvision.models.resnet18(pretrained=False)
+         modules = list(resnet.children())[:-3]
+
+         self._resnet = nn.Sequential(*modules)
+         self._adaptive_pool = nn.AdaptiveAvgPool2d(
+             (self.enc_image_size, self.enc_image_size)
+         )
+
+     def _log(self):
+         # Setup a custom logger
+         return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL)
+
+     def get_encoder_dim(self):
+         return self._encoder_dim
+
+     def forward(self, images):
+         """
+         Forward propagation
+         The encoder_dim 512 is decided by the structure of the image network (modified resnet-18)
+
+         Parameters
+         ----------
+         images : tensor (batch_size, image_channels, resized_image, resized_image)
+             images input
+
+         Returns
+         -------
+         tensor : (batch_size, enc_image_size, enc_image_size, 256)
+             encoded images
+         """
+         out = self._resnet(images)  # (batch_size, 256, 28, 28)
+         self._log().debug("forward: resnet out: {}".format(out.size()))
+         out = self._adaptive_pool(out)
+         out = out.permute(
+             0, 2, 3, 1
+         )  # (batch_size, enc_image_size, enc_image_size, 256)
+
+         self._log().debug("enc forward: final out: {}".format(out.size()))
+
+         return out
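
A minimal sketch of running the encoder on a dummy image, assuming the wheel is installed; the 448x448 input size and batch size of 1 are illustrative (matching the shape quoted in predict() below), and the trailing 256 channels come from the truncated ResNet-18, which keeps everything up to layer3.

import torch
from docling_ibm_models.tableformer.models.table04_rs.encoder04_rs import Encoder04

enc = Encoder04(enc_image_size=28)
enc.eval()
with torch.no_grad():
    out = enc(torch.randn(1, 3, 448, 448))  # dummy batch of one RGB image
print(out.shape)  # torch.Size([1, 28, 28, 256])
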
docling_ibm_models/tableformer/models/table04_rs/tablemodel04_rs.py
@@ -0,0 +1,324 @@
+ #
+ # Copyright IBM Corp. 2024 - 2024
+ # SPDX-License-Identifier: MIT
+ #
+ import logging
+
+ import torch
+ import torch.nn as nn
+
+ import docling_ibm_models.tableformer.settings as s
+ from docling_ibm_models.tableformer.models.common.base_model import BaseModel
+ from docling_ibm_models.tableformer.models.table04_rs.bbox_decoder_rs import BBoxDecoder
+ from docling_ibm_models.tableformer.models.table04_rs.encoder04_rs import Encoder04
+ from docling_ibm_models.tableformer.models.table04_rs.transformer_rs import (
+     Tag_Transformer,
+ )
+ from docling_ibm_models.tableformer.utils.app_profiler import AggProfiler
+
+ LOG_LEVEL = logging.WARN
+ # LOG_LEVEL = logging.INFO
+ # LOG_LEVEL = logging.DEBUG
+
+
+ class TableModel04_rs(BaseModel, nn.Module):
+     r"""
+     TableModel04_rs encoder, dual-decoder model with OTSL+ support
+     """
+
+     def __init__(self, config, init_data, purpose, device):
+         super(TableModel04_rs, self).__init__(config, init_data, device)
+
+         self._prof = config["predict"].get("profiling", False)
+         self._device = device
+         # Extract the word_map from the init_data
+         word_map = init_data["word_map"]
+
+         # Encoder
+         self._enc_image_size = config["model"]["enc_image_size"]
+         self._encoder_dim = config["model"]["hidden_dim"]
+         self._encoder = Encoder04(self._enc_image_size, self._encoder_dim).to(device)
+
+         tag_vocab_size = len(word_map["word_map_tag"])
+
+         td_encode = []
+         for t in ["ecel", "fcel", "ched", "rhed", "srow"]:
+             if t in word_map["word_map_tag"]:
+                 td_encode.append(word_map["word_map_tag"][t])
+         self._log().debug("td_encode length: {}".format(len(td_encode)))
+         self._log().debug("td_encode: {}".format(td_encode))
+
+         self._tag_attention_dim = config["model"]["tag_attention_dim"]
+         self._tag_embed_dim = config["model"]["tag_embed_dim"]
+         self._tag_decoder_dim = config["model"]["tag_decoder_dim"]
+         self._decoder_dim = config["model"]["hidden_dim"]
+         self._dropout = config["model"]["dropout"]
+
+         self._bbox = config["train"]["bbox"]
+         self._bbox_attention_dim = config["model"]["bbox_attention_dim"]
+         self._bbox_embed_dim = config["model"]["bbox_embed_dim"]
+         self._bbox_decoder_dim = config["model"]["hidden_dim"]
+
+         self._enc_layers = config["model"]["enc_layers"]
+         self._dec_layers = config["model"]["dec_layers"]
+         self._n_heads = config["model"]["nheads"]
+
+         self._num_classes = config["model"]["bbox_classes"]
+         self._enc_image_size = config["model"]["enc_image_size"]
+
+         self._max_pred_len = config["predict"]["max_steps"]
+
+         self._tag_transformer = Tag_Transformer(
+             device,
+             tag_vocab_size,
+             td_encode,
+             self._decoder_dim,
+             self._enc_layers,
+             self._dec_layers,
+             self._enc_image_size,
+             n_heads=self._n_heads,
+         ).to(device)
+
+         self._bbox_decoder = BBoxDecoder(
+             device,
+             self._bbox_attention_dim,
+             self._bbox_embed_dim,
+             self._tag_decoder_dim,
+             self._bbox_decoder_dim,
+             self._num_classes,
+             self._encoder_dim,
+             self._dropout,
+         ).to(device)
+
+     def _log(self):
+         # Setup a custom logger
+         return s.get_custom_logger(self.__class__.__name__, LOG_LEVEL)
+
+     def mergebboxes(self, bbox1, bbox2):
+         new_w = (bbox2[0] + bbox2[2] / 2) - (bbox1[0] - bbox1[2] / 2)
+         new_h = (bbox2[1] + bbox2[3] / 2) - (bbox1[1] - bbox1[3] / 2)
+
+         new_left = bbox1[0] - bbox1[2] / 2
+         new_top = min((bbox2[1] - bbox2[3] / 2), (bbox1[1] - bbox1[3] / 2))
+
+         new_cx = new_left + new_w / 2
+         new_cy = new_top + new_h / 2
+
+         bboxm = torch.tensor([new_cx, new_cy, new_w, new_h])
+         return bboxm
+
+     def predict(self, imgs, max_steps, k, return_attention=False):
+         r"""
+         Inference.
+         The input image must be preprocessed and transformed.
+
+         Parameters
+         ----------
+         imgs : FloatTensor - torch.Size([1, 3, 448, 448])
+             Input image for the inference
+
+         Returns
+         -------
+         seq : list
+             Predictions for the tags as indices over the word_map
+         outputs_class : tensor(x, 3)
+             Classes of predicted bboxes. x is the number of bboxes. There are 3 bbox classes
+
+         outputs_coord : tensor(x, 4)
+             Coords of predicted bboxes. x is the number of bboxes. Each bbox is in [cxcywh] format
+         """
+         AggProfiler().begin("predict_total", self._prof)
+
+         # Invoke encoder
+         self._tag_transformer.eval()
+         enc_out = self._encoder(imgs)
+         AggProfiler().end("model_encoder", self._prof)
+
+         word_map = self._init_data["word_map"]["word_map_tag"]
+         n_heads = self._tag_transformer._n_heads
+         # [1, 28, 28, 512]
+         encoder_out = self._tag_transformer._input_filter(
+             enc_out.permute(0, 3, 1, 2)
+         ).permute(0, 2, 3, 1)
+
+         batch_size = encoder_out.size(0)
+         encoder_dim = encoder_out.size(-1)
+         enc_inputs = encoder_out.view(batch_size, -1, encoder_dim).to(self._device)
+         enc_inputs = enc_inputs.permute(1, 0, 2)
+         positions = enc_inputs.shape[0]
+
+         encoder_mask = torch.zeros(
+             (batch_size * n_heads, positions, positions), device=self._device
+         ) == torch.ones(
+             (batch_size * n_heads, positions, positions), device=self._device
+         )
+
+         # Invoking tag transformer encoder before the loop to save time
+         AggProfiler().begin("model_tag_transformer_encoder", self._prof)
+         encoder_out = self._tag_transformer._encoder(enc_inputs, mask=encoder_mask)
+         AggProfiler().end("model_tag_transformer_encoder", self._prof)
+
+         decoded_tags = (
+             torch.LongTensor([word_map["<start>"]]).to(self._device).unsqueeze(1)
+         )
+         output_tags = []
+         cache = None
+         tag_H_buf = []
+
+         skip_next_tag = True
+         prev_tag_ucel = False
+         line_num = 0
+
+         # Populate bboxes_to_merge, indexes of first lcel, and last cell in a span
+         first_lcel = True
+         bboxes_to_merge = {}
+         cur_bbox_ind = -1
+         bbox_ind = 0
+
+         # i = 0
+         while len(output_tags) < self._max_pred_len:
+             decoded_embedding = self._tag_transformer._embedding(decoded_tags)
+             decoded_embedding = self._tag_transformer._positional_encoding(
+                 decoded_embedding
+             )
+             AggProfiler().begin("model_tag_transformer_decoder", self._prof)
+             decoded, cache = self._tag_transformer._decoder(
+                 decoded_embedding,
+                 encoder_out,
+                 cache,
+                 memory_key_padding_mask=encoder_mask,
+             )
+             AggProfiler().end("model_tag_transformer_decoder", self._prof)
+             # Grab last feature to produce token
+             AggProfiler().begin("model_tag_transformer_fc", self._prof)
+             logits = self._tag_transformer._fc(decoded[-1, :, :])  # 1, vocab_size
+             AggProfiler().end("model_tag_transformer_fc", self._prof)
+             new_tag = logits.argmax(1).item()
+
+             # STRUCTURE ERROR CORRECTION
+             # Correction for first line xcel...
+             if line_num == 0:
+                 if new_tag == word_map["xcel"]:
+                     new_tag = word_map["lcel"]
+
+             # Correction for ucel, lcel sequence...
+             if prev_tag_ucel:
+                 if new_tag == word_map["lcel"]:
+                     new_tag = word_map["fcel"]
+
+             # End of generation
+             if new_tag == word_map["<end>"]:
+                 output_tags.append(new_tag)
+                 decoded_tags = torch.cat(
+                     [
+                         decoded_tags,
+                         torch.LongTensor([new_tag]).unsqueeze(1).to(self._device),
+                     ],
+                     dim=0,
+                 )  # current_output_len, 1
+                 break
+             output_tags.append(new_tag)
+
+             # BBOX PREDICTION
+
+             # MAKE SURE TO SYNC NUMBER OF CELLS WITH NUMBER OF BBOXes
+             if not skip_next_tag:
+                 if new_tag in [
+                     word_map["fcel"],
+                     word_map["ecel"],
+                     word_map["ched"],
+                     word_map["rhed"],
+                     word_map["srow"],
+                     word_map["nl"],
+                     word_map["ucel"],
+                 ]:
+                     # GENERATE BBOX HERE TOO (All other cases)...
+                     tag_H_buf.append(decoded[-1, :, :])
+                     if first_lcel is not True:
+                         # Mark end index for horizontal cell bbox merge
+                         bboxes_to_merge[cur_bbox_ind] = bbox_ind
+                     bbox_ind += 1
+
+             # Treat horizontal span bboxes...
+             if new_tag != word_map["lcel"]:
+                 first_lcel = True
+             else:
+                 if first_lcel:
+                     # GENERATE BBOX HERE (Beginning of horizontal span)...
+                     tag_H_buf.append(decoded[-1, :, :])
+                     first_lcel = False
+                     # Mark start index for cell bbox merge
+                     cur_bbox_ind = bbox_ind
+                     bboxes_to_merge[cur_bbox_ind] = -1
+                     bbox_ind += 1
+
+             if new_tag in [word_map["nl"], word_map["ucel"], word_map["xcel"]]:
+                 skip_next_tag = True
+             else:
+                 skip_next_tag = False
+
+             # Register ucel in sequence...
+             if new_tag == word_map["ucel"]:
+                 prev_tag_ucel = True
+             else:
+                 prev_tag_ucel = False
+
+             decoded_tags = torch.cat(
+                 [
+                     decoded_tags,
+                     torch.LongTensor([new_tag]).unsqueeze(1).to(self._device),
+                 ],
+                 dim=0,
+             )  # current_output_len, 1
+         seq = decoded_tags.squeeze().tolist()
+
+         if self._bbox:
+             AggProfiler().begin("model_bbox_decoder", self._prof)
+             outputs_class, outputs_coord = self._bbox_decoder.inference(
+                 enc_out, tag_H_buf
+             )
+             AggProfiler().end("model_bbox_decoder", self._prof)
+         else:
+             outputs_class, outputs_coord = None, None
+
+         outputs_class.to(self._device)
+         outputs_coord.to(self._device)
+
+         ####################################################################################
+         # Merge First and Last predicted BBOX for each span, according to bboxes_to_merge
+         ####################################################################################
+
+         outputs_class1 = []
+         outputs_coord1 = []
+         boxes_to_skip = []
+
+         for box_ind in range(len(outputs_coord)):
+             box1 = outputs_coord[box_ind].to(self._device)
+             cls1 = outputs_class[box_ind].to(self._device)
+             if box_ind in bboxes_to_merge:
+                 box2 = outputs_coord[bboxes_to_merge[box_ind]].to(self._device)
+                 boxes_to_skip.append(bboxes_to_merge[box_ind])
+                 boxm = self.mergebboxes(box1, box2).to(self._device)
+                 outputs_coord1.append(boxm)
+                 outputs_class1.append(cls1)
+             else:
+                 if box_ind not in boxes_to_skip:
+                     outputs_coord1.append(box1)
+                     outputs_class1.append(cls1)
+
+         if len(outputs_coord1) > 0:
+             outputs_coord1 = torch.stack(outputs_coord1)
+         if len(outputs_class1) > 0:
+             outputs_class1 = torch.stack(outputs_class1)
+
+         outputs_class = outputs_class1
+         outputs_coord = outputs_coord1
+
+         # Do the rest of the steps...
+         AggProfiler().end("predict_total", self._prof)
+         num_tab_cells = seq.count(4) + seq.count(5)
+         num_rows = seq.count(9)
+         self._log().info(
+             "OTSL predicted table cells#: {}; rows#: {}".format(num_tab_cells, num_rows)
+         )
+         return seq, outputs_class, outputs_coord
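
For reference, the span-merge arithmetic from mergebboxes can be reproduced standalone. This sketch is not part of the released diff; the normalized [cx, cy, w, h] boxes are made-up values, purely to show how the first and last box of a horizontal span are collapsed into one box covering both.

import torch

def merge_cxcywh(bbox1, bbox2):
    # Width/height span from the left/top edge of bbox1 to the right/bottom edge of bbox2
    new_w = (bbox2[0] + bbox2[2] / 2) - (bbox1[0] - bbox1[2] / 2)
    new_h = (bbox2[1] + bbox2[3] / 2) - (bbox1[1] - bbox1[3] / 2)
    new_left = bbox1[0] - bbox1[2] / 2
    new_top = min(bbox2[1] - bbox2[3] / 2, bbox1[1] - bbox1[3] / 2)
    return torch.tensor([new_left + new_w / 2, new_top + new_h / 2, new_w, new_h])

first = torch.tensor([0.20, 0.50, 0.10, 0.05])  # first cell of the span
last = torch.tensor([0.40, 0.50, 0.10, 0.05])   # last cell of the span
print(merge_cxcywh(first, last))                # tensor([0.3000, 0.5000, 0.3000, 0.0500])
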