tinydoc-vlm 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinydoc_vlm/__init__.py +40 -0
- tinydoc_vlm/attention.py +76 -0
- tinydoc_vlm/configuration.py +83 -0
- tinydoc_vlm/data.py +92 -0
- tinydoc_vlm/decoder.py +52 -0
- tinydoc_vlm/image_processing.py +117 -0
- tinydoc_vlm/losses.py +64 -0
- tinydoc_vlm/modeling.py +194 -0
- tinydoc_vlm/output_heads.py +135 -0
- tinydoc_vlm/processing.py +177 -0
- tinydoc_vlm/token_compressor.py +82 -0
- tinydoc_vlm/trainer.py +341 -0
- tinydoc_vlm/vision_encoder.py +64 -0
- tinydoc_vlm-0.2.0.dist-info/METADATA +255 -0
- tinydoc_vlm-0.2.0.dist-info/RECORD +18 -0
- tinydoc_vlm-0.2.0.dist-info/WHEEL +5 -0
- tinydoc_vlm-0.2.0.dist-info/licenses/LICENSE +201 -0
- tinydoc_vlm-0.2.0.dist-info/top_level.txt +1 -0
tinydoc_vlm/modeling.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from typing import Dict, Any, Optional, List, Tuple, Union
|
|
4
|
+
from transformers import PreTrainedModel, GenerationMixin
|
|
5
|
+
from transformers.modeling_outputs import CausalLMOutputWithPast
|
|
6
|
+
|
|
7
|
+
from .configuration import TinyDocVLMConfig
|
|
8
|
+
from .vision_encoder import SigLIPVisionEncoder
|
|
9
|
+
from .token_compressor import PixelShuffleTokenCompressor
|
|
10
|
+
from .decoder import TinyDocDecoder
|
|
11
|
+
from .output_heads import MultiTaskOutputHeads
|
|
12
|
+
|
|
13
|
+
class TinyDocVLMPreTrainedModel(PreTrainedModel):
    """
    Shared base class that plugs TinyDoc-VLM into the Hugging Face
    ``PreTrainedModel`` machinery (config class, weight init, checkpointing flag).
    """

    config_class = TinyDocVLMConfig
    base_model_prefix = "tinydoc_vlm"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """
        Initialise ``nn.Linear`` and ``nn.Embedding`` weights from N(0, std).

        ``std`` is ``config.initializer_range`` (0.02 when the config lacks it).
        Linear biases are zeroed, as is the embedding row at ``padding_idx``.
        Every other module type is left untouched.
        """
        init_std = getattr(self.config, "initializer_range", 0.02)

        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return

        weight = module.weight.data
        weight.normal_(mean=0.0, std=init_std)

        if isinstance(module, nn.Linear):
            if module.bias is not None:
                module.bias.data.zero_()
        elif module.padding_idx is not None:
            weight[module.padding_idx].zero_()
|
|
28
|
+
|
|
29
|
+
class TinyDocVLMForConditionalGeneration(TinyDocVLMPreTrainedModel, GenerationMixin):
    """
    TinyDoc-VLM: The World's Smallest Document Understanding Model.
    Coordinates SigLIP Vision Encoder, PixelShuffle Compressor, and SmolLM2 Decoder.

    Data flow as wired in this class:
      pixel_values -> vision_encoder -> compressor (+ visual_pos_embed)
      -> written into the embeddings of the ``image_token_id`` placeholder tokens
      -> decoder -> optionally one of the task-specific ``output_heads``.
    """

    def __init__(self, config: TinyDocVLMConfig):
        super().__init__(config)

        # 1. Vision Encoder
        self.vision_encoder = SigLIPVisionEncoder(config)

        # 2. Token Compressor / Connector
        self.compressor = PixelShuffleTokenCompressor(
            config,
            encoder_dim=config.vision_config.hidden_size,
            decoder_dim=config.decoder_config.hidden_size,
        )

        # 3. Decoder
        self.decoder = TinyDocDecoder(config.decoder_config)

        # Token ID of the image placeholder (a plain int, not a parameter); the
        # embeddings at these positions are overwritten with visual features.
        self.image_token_id = getattr(config, "image_token_id", 49152)

        # Number of compressed visual tokens produced per tile.
        s = config.pixel_shuffle_scale
        compressed_grid_size = (config.image_size // config.patch_size) // s
        compressed_patches = compressed_grid_size ** 2

        # Learnable positional embeddings for the compressed visual tokens.
        # Shape (1, 1, P, D) broadcasts over (batch, tiles); added to the
        # compressor output, i.e. after its projection to the decoder width.
        self.visual_pos_embed = nn.Parameter(
            torch.zeros(1, 1, compressed_patches, config.decoder_config.hidden_size)
        )

        # 4. Structured Output Heads (multi-task)
        self.output_heads = MultiTaskOutputHeads(
            hidden_size=config.decoder_config.hidden_size,
            vocab_size=config.decoder_config.vocab_size,
        )

        # Initialize weights
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        return self.decoder.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.decoder.lm.set_input_embeddings(value)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        task: Optional[str] = None,
    ) -> Union[Tuple, Dict, CausalLMOutputWithPast]:
        """
        Run the decoder, optionally fusing image features and applying a task head.

        Args:
            input_ids: Token IDs. Required together with ``pixel_values`` to
                locate the ``image_token_id`` placeholders.
            pixel_values: Image tiles for the vision encoder. When given, the
                compressed visual tokens replace the placeholder embeddings.
            task: If truthy, the decoder's final hidden states are routed to
                ``self.output_heads`` with this task name ("json", "kv", "table",
                "ocr" or "qa").
            Other arguments are forwarded to ``self.decoder`` unchanged.

        Returns:
            The decoder outputs, or, when ``task`` is set,
            ``{"lm_outputs": <decoder outputs>, "head_outputs": <head dict>}``.

        Raises:
            ValueError: if ``pixel_values`` is given without ``input_ids``, or if a
                sample has more placeholder tokens than available visual tokens.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Task heads consume the final hidden states, so make sure they are returned.
        output_hidden_states = True if task else output_hidden_states

        # Decoding pass (no new visual input, reuse cached states)
        if pixel_values is None and past_key_values is not None:
            outputs = self.decoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                labels=labels,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            return self._route_to_task_head(outputs, task, output_attentions)

        # Prefill pass: merge text and visual tokens into inputs_embeds
        if inputs_embeds is None:
            inputs_embeds = self.decoder.get_input_embeddings()(input_ids)
        elif pixel_values is not None:
            # Visual features are written in place below; never mutate the caller's tensor.
            inputs_embeds = inputs_embeds.clone()

        if pixel_values is not None:
            inputs_embeds = self._merge_image_features(input_ids, pixel_values, inputs_embeds)

        outputs = self.decoder(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return self._route_to_task_head(outputs, task, output_attentions)

    def _merge_image_features(
        self,
        input_ids: Optional[torch.LongTensor],
        pixel_values: torch.FloatTensor,
        inputs_embeds: torch.Tensor,
    ) -> torch.Tensor:
        """
        Encode ``pixel_values`` and write the visual tokens into ``inputs_embeds``.

        For each sample ``b`` with ``n`` placeholder tokens, the first ``n`` visual
        tokens of that sample (tiles flattened in order) are written into the
        placeholder positions. ``inputs_embeds`` is modified in place and returned.
        """
        if input_ids is None:
            raise ValueError(
                "input_ids are required when pixel_values are provided, "
                "to locate the image placeholder tokens."
            )

        visual_features = self.vision_encoder(pixel_values)
        compressed_features = self.compressor(visual_features)
        compressed_features = compressed_features + self.visual_pos_embed

        batch_size, num_tiles, compressed_patches, decoder_dim = compressed_features.shape
        available = num_tiles * compressed_patches
        # Cast to the text embeddings' dtype/device: index assignment requires
        # matching dtypes (e.g. fp32 vision features into bf16 embeddings).
        flat_visual_features = compressed_features.reshape(
            batch_size, available, decoder_dim
        ).to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)

        image_mask = input_ids == self.image_token_id
        for b in range(batch_size):
            num_places = int(image_mask[b].sum().item())
            if num_places == 0:
                continue
            if num_places > available:
                raise ValueError(
                    f"Sample {b} has {num_places} image placeholder tokens but only "
                    f"{available} visual tokens are available."
                )
            inputs_embeds[b, image_mask[b]] = flat_visual_features[b, :num_places]

        return inputs_embeds

    def _route_to_task_head(self, outputs, task: Optional[str], output_attentions: Optional[bool]):
        """Return ``outputs`` as-is, or wrap it with the requested task head's outputs."""
        if not task:
            return outputs
        hidden = self._final_hidden_state(outputs, output_attentions)
        head_outputs = self.output_heads(hidden, task=task)
        return {"lm_outputs": outputs, "head_outputs": head_outputs}

    @staticmethod
    def _final_hidden_state(outputs, output_attentions: Optional[bool]) -> torch.Tensor:
        """
        Extract the last layer's hidden states from decoder ``outputs``.

        Supports ModelOutput-style objects (``.hidden_states``), plain dicts and
        tuples. For tuples this assumes Hugging Face ordering, where the per-layer
        hidden states are the last entry, or second-to-last when attentions are
        also returned. NOTE(review): confirm against TinyDocDecoder.
        """
        hidden_states = getattr(outputs, "hidden_states", None)
        if hidden_states is None and isinstance(outputs, dict):
            hidden_states = outputs.get("hidden_states")
        if hidden_states is None and isinstance(outputs, (tuple, list)):
            hidden_states = outputs[-2] if output_attentions else outputs[-1]
        if hidden_states is None:
            raise ValueError(
                "The decoder did not return hidden states, which are required when `task` is set."
            )
        # A per-layer tuple/list -> keep the final layer's output.
        if isinstance(hidden_states, (tuple, list)):
            hidden_states = hidden_states[-1]
        return hidden_states

    @staticmethod
    def _cached_sequence_length(past_key_values) -> int:
        """Number of positions already stored in the KV cache (0 when empty or absent)."""
        if past_key_values is None:
            return 0
        if hasattr(past_key_values, "get_seq_length"):
            # transformers Cache objects (e.g. DynamicCache), which may be empty at prefill.
            return int(past_key_values.get_seq_length())
        # Legacy tuple-of-(key, value) cache; assumes keys are (batch, heads, seq, head_dim).
        if len(past_key_values) == 0 or past_key_values[0] is None:
            return 0
        return int(past_key_values[0][0].shape[-2])

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        pixel_values=None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Overridden to support KV caching during auto-regressive generation.

        The prefill step (empty or absent cache) receives the full prompt and
        ``pixel_values``. Once the cache holds tokens, only the newest token is fed
        and ``pixel_values`` is dropped. Hugging Face ``generate`` passes the same
        model kwargs (including ``pixel_values``) on every step. Keying off
        ``pixel_values is None`` would therefore re-encode the image and push the
        whole prompt through the decoder again on top of the cache.
        """
        is_decoding = self._cached_sequence_length(past_key_values) > 0

        if is_decoding:
            input_ids = input_ids[:, -1:]
            inputs_embeds = None
            pixel_values = None

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # Left-padding aware positions: count only attended tokens.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if is_decoding:
                position_ids = position_ids[:, -input_ids.shape[-1]:]

        return {
            "input_ids": input_ids,
            "inputs_embeds": inputs_embeds,
            "past_key_values": past_key_values,
            "pixel_values": pixel_values,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "use_cache": kwargs.get("use_cache"),
        }

    def _reorder_cache(self, past_key_values, beam_idx):
        return self.decoder.lm._reorder_cache(past_key_values, beam_idx)
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from typing import Dict
|
|
4
|
+
|
|
5
|
+
class JSONHead(nn.Module):
    """
    Structured JSON generation head.

    A two-layer MLP (Linear -> GELU -> Linear) mapping decoder hidden states of
    width ``hidden_size`` to logits over ``num_json_tokens`` JSON output tokens.
    Only this classifier is implemented here; there is no pointer network for
    field values in this module.
    """

    def __init__(self, hidden_size: int, num_json_tokens: int = 256):
        super().__init__()
        layers = [
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, num_json_tokens),
        ]
        # Keep the attribute name/indices (proj.0.*, proj.2.*) for checkpoint compatibility.
        self.proj = nn.Sequential(*layers)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        logits = self.proj(hidden_states)
        return logits
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class KVHead(nn.Module):
    """
    Key-Value extraction head.

    Three independent linear projections of the decoder hidden states:
      * ``key_logits``   - scores over ``num_keys`` key classes,
      * ``value_embeds`` - a ``hidden_size``-wide value representation,
      * ``confidence``   - a sigmoid score in (0, 1), one per position.
    """

    def __init__(self, hidden_size: int, num_keys: int = 128):
        super().__init__()
        self.key_classifier = nn.Linear(in_features=hidden_size, out_features=num_keys)
        self.value_proj = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.confidence = nn.Linear(in_features=hidden_size, out_features=1)

    def forward(self, hidden_states: torch.Tensor) -> Dict[str, torch.Tensor]:
        return dict(
            key_logits=self.key_classifier(hidden_states),
            value_embeds=self.value_proj(hidden_states),
            confidence=self.confidence(hidden_states).sigmoid(),
        )
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class TableHead(nn.Module):
    """
    Table generation head.

    A two-layer MLP (Linear -> GELU -> Linear) that maps decoder hidden states
    to logits over ``num_table_tokens`` table-markup tokens. Only token logits
    are produced. This module does not output row/cell coordinates.
    """

    def __init__(self, hidden_size: int, num_table_tokens: int = 128):
        super().__init__()
        hidden_layer = nn.Linear(hidden_size, hidden_size)
        activation = nn.GELU()
        output_layer = nn.Linear(hidden_size, num_table_tokens)
        # Sequential named `proj` keeps checkpoint keys proj.0.* / proj.2.*.
        self.proj = nn.Sequential(hidden_layer, activation, output_layer)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        table_logits = self.proj(hidden_states)
        return table_logits
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class OCRHead(nn.Module):
    """
    OCR text generation head.

    A two-layer MLP (Linear -> GELU -> Linear) that projects decoder hidden
    states to logits over a small ``vocab_size`` output vocabulary (256 by
    default, presumably a character/byte-level vocabulary - confirm with the
    training data) for text-extraction tasks.
    """

    def __init__(self, hidden_size: int, vocab_size: int = 256):
        super().__init__()
        modules = (
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, vocab_size),
        )
        # Attribute name and layer indices are part of the checkpoint format.
        self.proj = nn.Sequential(*modules)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        char_logits = self.proj(hidden_states)
        return char_logits
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class QAHead(nn.Module):
    """
    Question-answering head.

    A single bias-free linear map from decoder hidden states to full-vocabulary
    logits, i.e. a standard language-model head for free-form answers.
    """

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.proj = nn.Linear(in_features=hidden_size, out_features=vocab_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        vocab_logits = self.proj(hidden_states)
        return vocab_logits
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class MultiTaskOutputHeads(nn.Module):
    """
    Container for the structured output heads.

    Holds one head per task (json, kv, table, ocr, qa). Each ``forward`` call
    routes the hidden states through exactly one head, chosen by ``task``.
    Joint multi-task training therefore means calling it once per task and
    combining the losses outside this module.
    """

    def __init__(
        self,
        hidden_size: int,
        vocab_size: int,
        num_json_tokens: int = 256,
        num_keys: int = 128,
        num_table_tokens: int = 128,
        ocr_vocab_size: int = 256,
    ):
        super().__init__()
        # Construction order is kept stable so seeded initialisation and
        # state_dict ordering match existing checkpoints.
        self.json_head = JSONHead(hidden_size, num_json_tokens=num_json_tokens)
        self.kv_head = KVHead(hidden_size, num_keys=num_keys)
        self.table_head = TableHead(hidden_size, num_table_tokens=num_table_tokens)
        self.ocr_head = OCRHead(hidden_size, vocab_size=ocr_vocab_size)
        self.qa_head = QAHead(hidden_size, vocab_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        task: str = "qa",
    ) -> Dict[str, torch.Tensor]:
        """
        Apply the head selected by ``task``.

        Returns ``{"logits": ...}`` for json/table/ocr/qa, or the KV head's own
        dict for "kv". Raises ValueError for any other task name.
        """
        # The KV head already returns a dict of named outputs.
        if task == "kv":
            return self.kv_head(hidden_states)

        logit_heads = (
            ("json", self.json_head),
            ("table", self.table_head),
            ("ocr", self.ocr_head),
            ("qa", self.qa_head),
        )
        for name, head in logit_heads:
            if task == name:
                return {"logits": head(hidden_states)}

        raise ValueError(f"Unknown task: {task}")
|
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
"""
|
|
2
|
+
TinyDocVLMProcessor — standalone processor (does not inherit ProcessorMixin to avoid
|
|
3
|
+
strict type-checking issues in transformers<4.45). Provides the same public API.
|
|
4
|
+
"""
|
|
5
|
+
import os
|
|
6
|
+
import json
|
|
7
|
+
from PIL import Image
|
|
8
|
+
from typing import Dict, Any, Union, Optional, List
|
|
9
|
+
|
|
10
|
+
import torch
|
|
11
|
+
from transformers import AutoTokenizer
|
|
12
|
+
|
|
13
|
+
from .image_processing import TinyDocImageProcessor
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class TinyDocVLMProcessor:
    """
    Coordinates TinyDocImageProcessor (image tiling + normalisation) and the
    SmolLM2 tokenizer extended with document-special tokens.

    Usage:
        processor = TinyDocVLMProcessor()
        inputs = processor(text=["Extract fields: <image>"], images=[pil_img])
        # inputs → {"input_ids", "attention_mask", "pixel_values", "image_token_id"}

    Each text paired with an image may contain at most one ``<image>``
    placeholder. It is expanded into one placeholder token per compressed
    visual token of that image's real (non-padding) tiles.
    """

    # Class-level attrs used by some HF utilities (save_pretrained etc.)
    image_processor_class = "TinyDocImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor: Optional[TinyDocImageProcessor] = None,
        tokenizer=None,
        config=None,
        **kwargs,
    ):
        self.image_processor = image_processor or TinyDocImageProcessor()
        self.tokenizer = tokenizer or AutoTokenizer.from_pretrained(
            "HuggingFaceTB/SmolLM2-135M-Instruct"
        )
        self.config = config

        # Ensure <image> special token exists
        self.image_token = "<image>"
        if self.image_token not in self.tokenizer.get_vocab():
            self.tokenizer.add_special_tokens(
                {"additional_special_tokens": [self.image_token]}
            )
        self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)

    # ------------------------------------------------------------------
    # Core __call__
    # ------------------------------------------------------------------

    def __call__(
        self,
        text: Union[str, List[str]],
        images: Optional[Union[Image.Image, List[Image.Image]]] = None,
        padding: bool = True,
        truncation: bool = True,
        max_length: Optional[int] = None,
        return_tensors: str = "pt",
    ) -> Dict[str, Any]:
        """
        Preprocesses text and images into tensors for the model.

        Returns:
            dict with keys: input_ids, attention_mask, pixel_values (optional),
            image_token_id.

        Raises:
            ValueError: if a text paired with an image contains more than one
                ``<image>`` placeholder.
        """
        # ---- 1. Image processing ------------------------------------------------
        pixel_values = None
        num_tiles_list: List[int] = []

        if images is not None and not isinstance(images, list):
            images = [images]

        # An empty list is treated like "no images" (max() below needs >= 1 image).
        if images:
            processed: List[torch.Tensor] = []
            for img in images:
                tile_tensor = self.image_processor.preprocess(img)  # (T, 3, H, W)
                processed.append(tile_tensor)
                num_tiles_list.append(tile_tensor.shape[0])

            # Pad to max tiles with all-zero tiles so we can stack into a single
            # tensor; the placeholder expansion below only counts real tiles.
            max_tiles = max(num_tiles_list)
            padded: List[torch.Tensor] = []
            sz = self.image_processor.image_size
            for tile_tensor in processed:
                T = tile_tensor.shape[0]
                if T < max_tiles:
                    pad = torch.zeros(
                        (max_tiles - T, 3, sz, sz),
                        dtype=tile_tensor.dtype,
                        device=tile_tensor.device,
                    )
                    tile_tensor = torch.cat([tile_tensor, pad], dim=0)
                padded.append(tile_tensor)

            pixel_values = torch.stack(padded, dim=0)  # (B, max_tiles, 3, H, W)

        # ---- 2. Expand <image> tokens in text ----------------------------------
        scale = (
            getattr(self.config, "pixel_shuffle_scale", 3) if self.config else 3
        )
        patch_size = (
            getattr(self.config, "patch_size", 16) if self.config else 16
        )
        sz = self.image_processor.image_size
        tokens_per_tile = (sz // patch_size // scale) ** 2

        if isinstance(text, list):
            expanded: List[str] = []
            for idx, t in enumerate(text):
                if idx < len(num_tiles_list):
                    total_vis = num_tiles_list[idx] * tokens_per_tile
                else:
                    total_vis = 0
                expanded.append(self._expand_image_tokens(t, total_vis))
            processed_text = expanded
        else:
            total_vis = num_tiles_list[0] * tokens_per_tile if num_tiles_list else 0
            processed_text = self._expand_image_tokens(text, total_vis)

        # ---- 3. Tokenise -------------------------------------------------------
        # NOTE(review): with truncation enabled, a too-small max_length can cut
        # through the expanded placeholder run and desync it from pixel_values.
        enc = self.tokenizer(
            processed_text,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            return_tensors=return_tensors,
        )

        inputs: Dict[str, Any] = {
            "input_ids": enc["input_ids"],
            "attention_mask": enc["attention_mask"],
        }
        if pixel_values is not None:
            inputs["pixel_values"] = pixel_values
        inputs["image_token_id"] = self.image_token_id

        return inputs

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _expand_image_tokens(self, text: str, total_tokens: int) -> str:
        """
        Replace the single '<image>' placeholder with `total_tokens` copies.

        With ``total_tokens == 0`` every placeholder is removed. A text with more
        than one placeholder and a non-zero expansion is rejected. Otherwise each
        copy would be expanded, giving more placeholder tokens than the image has
        visual features.
        """
        if total_tokens > 0:
            occurrences = text.count(self.image_token)
            if occurrences > 1:
                raise ValueError(
                    f"Expected at most one '{self.image_token}' placeholder per text "
                    f"when an image is provided, found {occurrences}."
                )
        expansion = self.image_token * total_tokens
        return text.replace(self.image_token, expansion)

    # ------------------------------------------------------------------
    # save / from_pretrained stubs (for HF Hub compatibility)
    # ------------------------------------------------------------------

    def save_pretrained(self, save_directory: str, **kwargs):
        os.makedirs(save_directory, exist_ok=True)
        self.image_processor.save_pretrained(save_directory)
        self.tokenizer.save_pretrained(save_directory)
        # Write a minimal processor_config.json
        cfg = {
            "processor_class": "TinyDocVLMProcessor",
            "image_token": self.image_token,
            "image_token_id": self.image_token_id,
        }
        with open(
            os.path.join(save_directory, "processor_config.json"), "w", encoding="utf-8"
        ) as f:
            json.dump(cfg, f, indent=2)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        image_processor = TinyDocImageProcessor.from_pretrained(
            pretrained_model_name_or_path
        )
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
        return cls(image_processor=image_processor, tokenizer=tokenizer, **kwargs)
|
|
@@ -0,0 +1,82 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import torch
|
|
3
|
+
import torch.nn as nn
|
|
4
|
+
from .configuration import TinyDocVLMConfig
|
|
5
|
+
|
|
6
|
+
class RMSNorm(nn.Module):
    """
    Root-mean-square layer norm over the last dimension, with a learnable scale.

    y = x / sqrt(mean(x**2, dim=-1) + eps) * weight

    The statistics are computed in at least float32. In fp16/bf16, squaring
    moderately large activations overflows (e.g. 1000**2 > fp16 max), which
    would zero the output. For float32 inputs the result is identical to a
    direct computation, and float64 inputs keep full precision.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        input_dtype = x.dtype
        compute_dtype = torch.promote_types(input_dtype, torch.float32)
        x_compute = x.to(compute_dtype)
        variance = x_compute.pow(2).mean(-1, keepdim=True)
        normed = (x_compute * torch.rsqrt(variance + self.eps)).to(input_dtype)
        return normed * self.weight
|
|
15
|
+
|
|
16
|
+
class PixelShuffleTokenCompressor(nn.Module):
    """
    Performs space-to-depth token compression on Vision Transformer patch sequences.
    Groups scale_factor x scale_factor patches and projects to decoder hidden dimension.

    Each tile's ``grid_size x grid_size`` patch grid becomes a
    ``(grid_size // s) x (grid_size // s)`` grid of tokens with
    ``s * s * encoder_dim`` channels. An MLP (Linear -> GELU -> Linear) maps
    those to ``decoder_dim``, followed by RMSNorm.
    """

    def __init__(self, config: TinyDocVLMConfig, encoder_dim: int, decoder_dim: int):
        super().__init__()
        self.config = config
        self.scale_factor = config.pixel_shuffle_scale
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim

        # After space-to-depth, channel dimension becomes encoder_dim * scale_factor^2
        compressed_dim = encoder_dim * (self.scale_factor ** 2)

        # MLP projection to map visual tokens to language model dimension
        self.projection = nn.Sequential(
            nn.Linear(compressed_dim, decoder_dim),
            nn.GELU(),
            nn.Linear(decoder_dim, decoder_dim),
        )
        self.norm = RMSNorm(decoder_dim)

    def forward(self, visual_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            visual_features: shape (batch_size, num_tiles, num_patches, encoder_dim).
                Need not be contiguous.

        Returns:
            compressed_features: shape (batch_size, num_tiles, num_compressed_patches, decoder_dim)

        Raises:
            ValueError: if the channel size differs from ``encoder_dim``, if
                ``num_patches`` is not a perfect square, or if the grid side is not
                divisible by ``pixel_shuffle_scale``.
        """
        batch_size, num_tiles, num_patches, encoder_dim = visual_features.shape

        if encoder_dim != self.encoder_dim:
            raise ValueError(
                f"Expected visual features with {self.encoder_dim} channels, got {encoder_dim}."
            )

        # Determine spatial dimensions assuming a square grid of patches
        grid_size = math.isqrt(num_patches)
        if grid_size * grid_size != num_patches:
            raise ValueError(
                f"Number of patches ({num_patches}) must be a perfect square to apply 2D pixel shuffle."
            )

        if grid_size % self.scale_factor != 0:
            raise ValueError(
                f"Grid size ({grid_size}) must be divisible by pixel_shuffle_scale ({self.scale_factor})."
            )

        s = self.scale_factor
        num_images = batch_size * num_tiles

        # Reshape to 2D spatial grid: (batch_size * num_tiles, grid_size, grid_size, encoder_dim).
        # reshape (not view) so non-contiguous encoder outputs are accepted.
        x = visual_features.reshape(num_images, grid_size, grid_size, encoder_dim)

        # Apply space-to-depth: (batch_size * num_tiles, H//s, s, W//s, s, C)
        x = x.reshape(num_images, grid_size // s, s, grid_size // s, s, encoder_dim)

        # Permute: (batch_size * num_tiles, H//s, W//s, s, s, C)
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous()

        # Flatten each s x s group into the channel dimension:
        # (batch_size * num_tiles, (H//s) * (W//s), s * s * C)
        new_patches = (grid_size // s) ** 2
        x = x.view(num_images, new_patches, s * s * encoder_dim)

        # Project and normalize
        x = self.projection(x)
        x = self.norm(x)

        # Reshape back to batch: (batch_size, num_tiles, new_patches, decoder_dim)
        return x.reshape(batch_size, num_tiles, new_patches, self.decoder_dim)
|