docling_ibm_models-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. docling_ibm_models/layoutmodel/layout_predictor.py +171 -0
  2. docling_ibm_models/tableformer/__init__.py +0 -0
  3. docling_ibm_models/tableformer/common.py +200 -0
  4. docling_ibm_models/tableformer/data_management/__init__.py +0 -0
  5. docling_ibm_models/tableformer/data_management/data_transformer.py +504 -0
  6. docling_ibm_models/tableformer/data_management/functional.py +574 -0
  7. docling_ibm_models/tableformer/data_management/matching_post_processor.py +1325 -0
  8. docling_ibm_models/tableformer/data_management/tf_cell_matcher.py +596 -0
  9. docling_ibm_models/tableformer/data_management/tf_dataset.py +1233 -0
  10. docling_ibm_models/tableformer/data_management/tf_predictor.py +1020 -0
  11. docling_ibm_models/tableformer/data_management/transforms.py +396 -0
  12. docling_ibm_models/tableformer/models/__init__.py +0 -0
  13. docling_ibm_models/tableformer/models/common/__init__.py +0 -0
  14. docling_ibm_models/tableformer/models/common/base_model.py +279 -0
  15. docling_ibm_models/tableformer/models/table04_rs/__init__.py +0 -0
  16. docling_ibm_models/tableformer/models/table04_rs/bbox_decoder_rs.py +163 -0
  17. docling_ibm_models/tableformer/models/table04_rs/encoder04_rs.py +72 -0
  18. docling_ibm_models/tableformer/models/table04_rs/tablemodel04_rs.py +324 -0
  19. docling_ibm_models/tableformer/models/table04_rs/transformer_rs.py +203 -0
  20. docling_ibm_models/tableformer/otsl.py +541 -0
  21. docling_ibm_models/tableformer/settings.py +90 -0
  22. docling_ibm_models/tableformer/test_dataset_cache.py +37 -0
  23. docling_ibm_models/tableformer/test_prepare_image.py +99 -0
  24. docling_ibm_models/tableformer/utils/__init__.py +0 -0
  25. docling_ibm_models/tableformer/utils/app_profiler.py +243 -0
  26. docling_ibm_models/tableformer/utils/torch_utils.py +216 -0
  27. docling_ibm_models/tableformer/utils/utils.py +376 -0
  28. docling_ibm_models/tableformer/utils/variance.py +175 -0
  29. docling_ibm_models-0.1.0.dist-info/LICENSE +21 -0
  30. docling_ibm_models-0.1.0.dist-info/METADATA +172 -0
  31. docling_ibm_models-0.1.0.dist-info/RECORD +32 -0
  32. docling_ibm_models-0.1.0.dist-info/WHEEL +4 -0
docling_ibm_models/tableformer/models/table04_rs/transformer_rs.py
@@ -0,0 +1,203 @@
+ #
+ # Copyright IBM Corp. 2024 - 2024
+ # SPDX-License-Identifier: MIT
+ #
+ import logging
+ import math
+ from typing import Optional, Tuple
+ 
+ import torch
+ from torch import Tensor, nn
+ 
+ import docling_ibm_models.tableformer.utils.utils as u
+ 
+ LOG_LEVEL = logging.INFO
+ # LOG_LEVEL = logging.DEBUG
+ 
+ 
+ class PositionalEncoding(nn.Module):
+     def __init__(self, d_model, dropout=0.1, max_len=1024):
+         super(PositionalEncoding, self).__init__()
+         self.dropout = nn.Dropout(p=dropout)
+ 
+         # Precompute the sinusoidal position table once and register it as
+         # a buffer so it moves with the module across devices.
+         pe = torch.zeros(max_len, d_model)
+         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+         div_term = torch.exp(
+             torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
+         )
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         pe = pe.unsqueeze(0).transpose(0, 1)  # -> (max_len, 1, d_model)
+         self.register_buffer("pe", pe)
+ 
+     def forward(self, x):
+         # x: (seq_len, bsz, d_model); add the encoding for each position.
+         x = x + self.pe[: x.size(0), :]
+         return self.dropout(x)
+ 
+ 
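+ # Shape sketch for PositionalEncoding (hypothetical usage, values
+ # illustrative): the buffer is the sinusoidal table of "Attention Is All
+ # You Need", pe[pos, 2i] = sin(pos / 10000^(2i/d_model)) and cos for odd
+ # dims, and forward() keeps the input shape:
+ #     pe = PositionalEncoding(d_model=512)
+ #     y = pe(torch.zeros(10, 2, 512))  # (seq_len, bsz, d_model) -> unchanged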
+ class TMTransformerDecoder(nn.TransformerDecoder):
+     def forward(
+         self,
+         tgt: Tensor,
+         memory: Optional[Tensor] = None,
+         cache: Optional[Tensor] = None,
+         memory_mask: Optional[Tensor] = None,
+         tgt_key_padding_mask: Optional[Tensor] = None,
+         memory_key_padding_mask: Optional[Tensor] = None,
+     ) -> Tuple[Tensor, Tensor]:
+         """
+         Args:
+             tgt (Tensor): encoded tags (tags_len, bsz, hidden_dim)
+             memory (Tensor): encoded image (enc_image_size, bsz, hidden_dim)
+             cache (Optional[Tensor]): per-layer outputs of the previous
+                 decoding steps; None during training, used only during
+                 inference
+         Returns:
+             output (Tensor): (1, bsz, hidden_dim) when cache is None,
+                 else (cache_len + 1, bsz, hidden_dim)
+             out_cache (Tensor): (num_layers, t, bsz, hidden_dim), where t
+                 is the number of tags decoded so far
+         """
+         output = tgt
+ 
+         # Cache every layer's output so that later inference steps can
+         # reuse it as keys/values instead of recomputing the prefix.
+         tag_cache = []
+         for i, mod in enumerate(self.layers):
+             output = mod(output, memory)
+             tag_cache.append(output)
+             if cache is not None:
+                 output = torch.cat([cache[i], output], dim=0)
+ 
+         if cache is not None:
+             out_cache = torch.cat([cache, torch.stack(tag_cache, dim=0)], dim=1)
+         else:
+             out_cache = torch.stack(tag_cache, dim=0)
+ 
+         return output, out_cache
+ 
+ 
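+ # Incremental decoding sketch (hypothetical caller): the first step runs
+ # without a cache; afterwards the returned cache lets each layer reuse its
+ # outputs for the already-decoded prefix as keys/values and attend only
+ # from the newest tag:
+ #     out, cache = decoder(tgt, memory=enc_image)               # first step
+ #     out, cache = decoder(tgt, memory=enc_image, cache=cache)  # later steps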
+ class TMTransformerDecoderLayer(nn.TransformerDecoderLayer):
+     def forward(
+         self,
+         tgt: Tensor,
+         memory: Optional[Tensor] = None,
+         memory_mask: Optional[Tensor] = None,
+         tgt_key_padding_mask: Optional[Tensor] = None,
+         memory_key_padding_mask: Optional[Tensor] = None,
+     ) -> Tensor:
+         """
+         Args:
+             same as TMTransformerDecoder
+         Returns:
+             Tensor: embedding of the last tag only, (1, bsz, hidden_dim);
+                 the full sequence is still used as keys/values
+         """
+ 
+         # Adapted from PyTorch's TransformerDecoderLayer, but modified to
+         # query the attention blocks with the last tag only.
+         tgt_last_tok = tgt[-1:, :, :]
+ 
+         # Self-attention: the newest tag attends over the whole prefix.
+         tmp_tgt = self.self_attn(
+             tgt_last_tok,
+             tgt,
+             tgt,
+             attn_mask=None,  # None, because we only care about the last tag
+             key_padding_mask=tgt_key_padding_mask,
+         )[0]
+         tgt_last_tok = tgt_last_tok + self.dropout1(tmp_tgt)
+         tgt_last_tok = self.norm1(tgt_last_tok)
+ 
+         # Cross-attention over the encoded image, if provided.
+         if memory is not None:
+             tmp_tgt = self.multihead_attn(
+                 tgt_last_tok,
+                 memory,
+                 memory,
+                 attn_mask=memory_mask,
+                 key_padding_mask=memory_key_padding_mask,
+             )[0]
+             tgt_last_tok = tgt_last_tok + self.dropout2(tmp_tgt)
+             tgt_last_tok = self.norm2(tgt_last_tok)
+ 
+         # Position-wise feed-forward block.
+         tmp_tgt = self.linear2(
+             self.dropout(self.activation(self.linear1(tgt_last_tok)))
+         )
+         tgt_last_tok = tgt_last_tok + self.dropout3(tmp_tgt)
+         tgt_last_tok = self.norm3(tgt_last_tok)
+         return tgt_last_tok
+ 
+ 
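+ # Cost note: because only the last tag is used as the attention query, one
+ # decoding step costs O(seq_len) in self-attention rather than the
+ # O(seq_len^2) of a full forward pass, at the price of producing a single
+ # output position per call.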
+ class Tag_Transformer(nn.Module):
+     """
+     "Attention Is All You Need" - https://arxiv.org/abs/1706.03762
+     """
+ 
+     def __init__(
+         self,
+         device,
+         vocab_size,
+         td_encode,
+         embed_dim,
+         encoder_layers,
+         decoder_layers,
+         enc_image_size,
+         dropout=0.1,
+         n_heads=4,
+         dim_ff=1024,
+     ):
+         super(Tag_Transformer, self).__init__()
+ 
+         self._device = device
+         self._n_heads = n_heads
+         self._embedding = nn.Embedding(vocab_size, embed_dim)
+         self._positional_encoding = PositionalEncoding(embed_dim)
+         self._td_encode = td_encode
+ 
+         self._encoder = nn.TransformerEncoder(
+             nn.TransformerEncoderLayer(
+                 d_model=embed_dim, nhead=n_heads, dim_feedforward=dim_ff
+             ),
+             num_layers=encoder_layers,
+         )
+ 
+         self._decoder = TMTransformerDecoder(
+             TMTransformerDecoderLayer(
+                 d_model=embed_dim,
+                 nhead=n_heads,
+                 dim_feedforward=dim_ff,
+             ),
+             num_layers=decoder_layers,
+         )
+ 
+         self._decoder_dim = embed_dim
+         self._enc_image_size = enc_image_size
+         self._input_filter = u.resnet_block(stride=1)
+         self._fc = nn.Linear(embed_dim, vocab_size)
+ 
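+     # Hypothetical construction (values illustrative; `td_encode` comes
+     # from the caller's tag-vocabulary config):
+     #     model = Tag_Transformer(
+     #         device="cpu", vocab_size=30, td_encode=[...], embed_dim=512,
+     #         encoder_layers=4, decoder_layers=2, enc_image_size=28,
+     #     )
+ 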
+     def inference(self, enc_inputs, tags, tag_lens, num_cells):
+         # CNN backbone image encoding: (bsz, H, W, C) -> channels-first
+         # for the resnet block, then back to channels-last.
+         enc_inputs = self._input_filter(enc_inputs.permute(0, 3, 1, 2)).permute(
+             0, 2, 3, 1
+         )
+ 
+         batch_size = enc_inputs.size(0)
+         encoder_dim = enc_inputs.size(-1)
+ 
+         # Flatten the spatial grid into a sequence: (positions, bsz, dim).
+         enc_inputs = enc_inputs.view(batch_size, -1, encoder_dim).to(self._device)
+         enc_inputs = enc_inputs.permute(1, 0, 2)
+         positions = enc_inputs.shape[0]
+ 
+         # All-False mask over the encoded image, i.e. nothing is masked;
+         # TODO: check whether this mask is useful at all.
+         encoder_mask = torch.zeros(
+             (batch_size * self._n_heads, positions, positions),
+             dtype=torch.bool,
+             device=self._device,
+         )
+ 
+         # Transformer Encoder
+         encoder_out = self._encoder(enc_inputs, mask=encoder_mask)
+ 
+         decode_lengths = (tag_lens - 1).tolist()
+ 
+         tgt = self._positional_encoding(self._embedding(tags).permute(1, 0, 2))
+ 
+         # The decoder returns (output, cache); the cache is not needed here.
+         decoded, _ = self._decoder(tgt, memory=encoder_out)
+         decoded = decoded.permute(1, 0, 2)
+         predictions = self._fc(decoded)
+         return predictions, decode_lengths
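+ 
+ 
+ if __name__ == "__main__":
+     # Hypothetical smoke test (illustrative only, not part of the package
+     # API): run the decoder once without a cache and check the shapes
+     # documented above: output (1, bsz, dim), cache (layers, 1, bsz, dim).
+     layer = TMTransformerDecoderLayer(d_model=64, nhead=4)
+     decoder = TMTransformerDecoder(layer, num_layers=2)
+     memory = torch.randn(10, 2, 64)  # encoded image: (positions, bsz, dim)
+     tgt = torch.randn(5, 2, 64)      # embedded tags: (tags_len, bsz, dim)
+     out, cache = decoder(tgt, memory=memory)
+     assert out.shape == (1, 2, 64)
+     assert cache.shape == (2, 1, 2, 64)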