tinydoc-vlm 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinydoc_vlm/__init__.py +40 -0
- tinydoc_vlm/attention.py +76 -0
- tinydoc_vlm/configuration.py +83 -0
- tinydoc_vlm/data.py +92 -0
- tinydoc_vlm/decoder.py +52 -0
- tinydoc_vlm/image_processing.py +117 -0
- tinydoc_vlm/losses.py +64 -0
- tinydoc_vlm/modeling.py +194 -0
- tinydoc_vlm/output_heads.py +135 -0
- tinydoc_vlm/processing.py +177 -0
- tinydoc_vlm/token_compressor.py +82 -0
- tinydoc_vlm/trainer.py +341 -0
- tinydoc_vlm/vision_encoder.py +64 -0
- tinydoc_vlm-0.2.0.dist-info/METADATA +255 -0
- tinydoc_vlm-0.2.0.dist-info/RECORD +18 -0
- tinydoc_vlm-0.2.0.dist-info/WHEEL +5 -0
- tinydoc_vlm-0.2.0.dist-info/licenses/LICENSE +201 -0
- tinydoc_vlm-0.2.0.dist-info/top_level.txt +1 -0
tinydoc_vlm/__init__.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from .configuration import TinyDocVLMConfig
|
|
2
|
+
from .vision_encoder import SigLIPVisionEncoder
|
|
3
|
+
from .token_compressor import PixelShuffleTokenCompressor
|
|
4
|
+
from .decoder import TinyDocDecoder
|
|
5
|
+
from .modeling import TinyDocVLMForConditionalGeneration, TinyDocVLMPreTrainedModel
|
|
6
|
+
from .image_processing import TinyDocImageProcessor
|
|
7
|
+
from .processing import TinyDocVLMProcessor
|
|
8
|
+
from .output_heads import MultiTaskOutputHeads, JSONHead, KVHead, TableHead, OCRHead, QAHead
|
|
9
|
+
from .data import DocumentDataset
|
|
10
|
+
from .losses import CombinedLoss
|
|
11
|
+
from .trainer import TinyDocVLMTrainer, TrainerConfig
|
|
12
|
+
|
|
13
|
+
__all__ = [
|
|
14
|
+
"TinyDocVLMConfig",
|
|
15
|
+
"SigLIPVisionEncoder",
|
|
16
|
+
"PixelShuffleTokenCompressor",
|
|
17
|
+
"TinyDocDecoder",
|
|
18
|
+
"TinyDocVLMForConditionalGeneration",
|
|
19
|
+
"TinyDocVLMPreTrainedModel",
|
|
20
|
+
"TinyDocImageProcessor",
|
|
21
|
+
"TinyDocVLMProcessor",
|
|
22
|
+
"MultiTaskOutputHeads",
|
|
23
|
+
"JSONHead",
|
|
24
|
+
"KVHead",
|
|
25
|
+
"TableHead",
|
|
26
|
+
"OCRHead",
|
|
27
|
+
"QAHead",
|
|
28
|
+
"DocumentDataset",
|
|
29
|
+
"CombinedLoss",
|
|
30
|
+
"TinyDocVLMTrainer",
|
|
31
|
+
"TrainerConfig",
|
|
32
|
+
]
|
|
33
|
+
|
|
34
|
+
from transformers import AutoConfig, AutoModelForCausalLM
|
|
35
|
+
|
|
36
|
+
# Register the custom config and model classes with the transformers Auto*
# factories. Each call is guarded on its own so that a ValueError raised by
# one registration (presumably "already registered" -- confirm against the
# transformers version in use) does not silently skip the other.
try:
    AutoConfig.register("tinydoc_vlm", TinyDocVLMConfig)
except ValueError:
    pass

try:
    AutoModelForCausalLM.register(TinyDocVLMConfig, TinyDocVLMForConditionalGeneration)
except ValueError:
    pass
|
tinydoc_vlm/attention.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
def get_2d_sincos_pos_embed(embed_dim: int, grid_size: int, cls_token: bool = False) -> torch.Tensor:
    """
    Build fixed 2D sine/cosine positional embeddings for a square grid.

    Args:
        embed_dim: Width of each embedding row. The helpers assert that it is
            even and that the half assigned to each axis is even as well, so
            in practice it must be a multiple of 4.
        grid_size: Number of positions along each side of the (square) grid.
        cls_token: When True, an all-zero row is placed before the grid rows.

    Returns:
        Tensor of shape (grid_size * grid_size, embed_dim), or
        (1 + grid_size * grid_size, embed_dim) when ``cls_token`` is True.
    """
    axis = torch.arange(grid_size, dtype=torch.float32)

    # "ij" indexing: the first mesh varies along rows (Y), the second along columns (X).
    rows, cols = torch.meshgrid(axis, axis, indexing="ij")
    coords = torch.stack((rows, cols), dim=0).reshape(2, 1, grid_size, grid_size)

    table = get_2d_sincos_pos_embed_from_grid(embed_dim, coords)
    if not cls_token:
        return table

    # The class token gets a zero positional embedding in front of the grid rows.
    cls_row = torch.zeros([1, embed_dim])
    return torch.cat([cls_row, table], dim=0)
|
|
30
|
+
|
|
31
|
+
def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor:
    """
    Turn a stacked coordinate grid into 2D sine/cosine embeddings.

    Each axis is encoded independently with half of the channels and the two
    halves are concatenated, first-axis (index 0) channels first.

    Args:
        embed_dim: Total embedding width; asserted to be even.
        grid: Coordinate tensor whose index 0 and index 1 along the first
            dimension hold the two coordinate planes (the caller in this file
            passes shape (2, 1, grid_h, grid_w) with Y at index 0, X at index 1).

    Returns:
        Tensor of shape (grid_h * grid_w, embed_dim).
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    # Encode axis 0 (Y) first, then axis 1 (X), each with half of the channels.
    per_axis = [get_1d_sincos_pos_embed_from_grid(half_dim, grid[axis]) for axis in (0, 1)]
    return torch.cat(per_axis, dim=1)
|
|
50
|
+
|
|
51
|
+
def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:
    """
    Generate 1D sinusoidal positional embeddings for a coordinate tensor.

    Args:
        embed_dim: Width of the embedding; asserted to be even. The first half
            of the channels holds sines, the second half cosines.
        pos: Coordinate tensor of any shape; it is flattened to 1D before use.

    Returns:
        Tensor of shape (pos.numel(), embed_dim), located on ``pos.device``.
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    # omega_i = 1 / 10000 ** (i / half_dim). Built on pos.device so the outer
    # product below never mixes devices when pos does not live on the CPU.
    omega = torch.arange(half_dim, dtype=torch.float32, device=pos.device)
    omega /= embed_dim / 2.0
    omega = 1.0 / (10000 ** omega)  # shape (half_dim,)

    # One row of angles per flattened position: shape (seq_len, half_dim).
    angles = torch.outer(pos.reshape(-1), omega)

    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
|
|
tinydoc_vlm/configuration.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
from typing import Dict, Any, Union
|
|
2
|
+
from transformers import PretrainedConfig, AutoConfig
|
|
3
|
+
|
|
4
|
+
class TinyDocVLMConfig(PretrainedConfig):
    """
    Composite configuration bundling a vision sub-config and a decoder sub-config.

    Sub-configs may be passed as ready-made config objects or as plain dicts;
    dicts are converted with ``AutoConfig.for_model`` using their
    ``"model_type"`` entry. Attributes that are not found on this object are
    looked up on the decoder sub-config.
    """
    model_type = "tinydoc_vlm"
    is_composition = True

    def __init__(
        self,
        vision_config: Union[Dict[str, Any], PretrainedConfig] = None,
        decoder_config: Union[Dict[str, Any], PretrainedConfig] = None,
        pixel_shuffle_scale: int = 3,
        image_size: int = 384,
        patch_size: int = 16,
        **kwargs,
    ):
        """
        Args:
            vision_config: Vision sub-config (dict or config object). A built-in
                default dict is used when None.
            decoder_config: Decoder sub-config (dict or config object). A
                built-in default dict is used when None.
            pixel_shuffle_scale: Stored as-is on the config.
            image_size: Stored as-is on the config.
            patch_size: Stored as-is on the config.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.
        """
        super().__init__(**kwargs)

        def _materialize(cfg, fallback_type):
            # Dicts become config objects; anything else is stored unchanged.
            if not isinstance(cfg, dict):
                return cfg
            params = cfg.copy()  # never mutate the caller's dict
            kind = params.pop("model_type", fallback_type)
            return AutoConfig.for_model(kind, **params)

        if vision_config is None:
            # Default SigLIP-B/16-like configuration (approx 93M parameters)
            vision_config = {
                "model_type": "siglip_vision_model",
                "hidden_size": 768,
                "intermediate_size": 3072,
                "num_hidden_layers": 12,
                "num_attention_heads": 12,
                "patch_size": 16,
                "image_size": 384,
                "num_channels": 3,
                "layer_norm_eps": 1e-6,
            }

        if decoder_config is None:
            # Default SmolLM2-135M-like configuration (approx 135M parameters)
            decoder_config = {
                "model_type": "llama",
                "vocab_size": 49152,
                "hidden_size": 576,
                "intermediate_size": 1536,
                "num_hidden_layers": 30,
                "num_attention_heads": 9,
                "num_key_value_heads": 3,
                "max_position_embeddings": 8192,
                "rms_norm_eps": 1e-5,
                "rope_theta": 273000.0,
                "attention_bias": False,
            }

        self.vision_config = _materialize(vision_config, "siglip_vision_model")
        self.decoder_config = _materialize(decoder_config, "llama")

        self.pixel_shuffle_scale = pixel_shuffle_scale
        self.image_size = image_size
        self.patch_size = patch_size

    def __getattr__(self, name):
        """Fall back to the decoder sub-config for attributes missing here."""
        # The sub-config names themselves must never be delegated, otherwise a
        # lookup before they are assigned would recurse.
        if name in ('decoder_config', 'vision_config'):
            raise AttributeError(name)

        not_found = object()
        value = not_found
        if 'decoder_config' in self.__dict__:
            value = getattr(self.decoder_config, name, not_found)
        if value is not not_found:
            return value

        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a dict, with both sub-configs expanded via their own ``to_dict``."""
        serialized = super().to_dict()
        serialized["vision_config"] = self.vision_config.to_dict()
        serialized["decoder_config"] = self.decoder_config.to_dict()
        return serialized
|
tinydoc_vlm/data.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Dict, List, Optional, Union
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from torch.utils.data import Dataset
|
|
7
|
+
from PIL import Image
|
|
8
|
+
|
|
9
|
+
from .image_processing import TinyDocImageProcessor
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class DocumentDataset(Dataset):
    """
    Dataset for document understanding training.

    Samples come either from an in-memory list or from a manifest file that
    holds one JSON object per non-blank line (JSONL), e.g.:
        {"image_path": "path/to/image.png", "text": "Extract: <image>", "labels": {...}}
    """
    def __init__(
        self,
        data_root: Union[str, Path],
        manifest_path: Optional[Union[str, Path]] = None,
        image_processor: Optional[TinyDocImageProcessor] = None,
        max_seq_length: int = 2048,
        stage: int = 1,
        samples: Optional[List[Dict]] = None,
    ):
        """
        Args:
            data_root: Directory that each sample's "image_path" is joined onto.
            manifest_path: Optional JSONL manifest; ignored when ``samples`` is given.
            image_processor: Processor used on every image; a default
                ``TinyDocImageProcessor()`` is created when omitted.
            max_seq_length: Stored on the instance (not used inside this class).
            stage: Stored on the instance (not used inside this class).
            samples: Optional pre-loaded list of sample dicts; takes priority
                over ``manifest_path``.
        """
        self.data_root = Path(data_root)
        self.image_processor = image_processor or TinyDocImageProcessor()
        self.max_seq_length = max_seq_length
        self.stage = stage

        if samples is not None:
            self.samples = samples
        elif manifest_path:
            # Read as UTF-8 explicitly so the result does not depend on the
            # platform's default encoding.
            with open(manifest_path, encoding="utf-8") as f:
                self.samples = [json.loads(line) for line in f if line.strip()]
        else:
            self.samples = []

    def __len__(self) -> int:
        """Return the number of loaded samples."""
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict:
        """
        Load and preprocess the sample at ``idx``.

        Returns:
            Dict with "pixel_values" (output of the image processor), "text"
            (defaults to "<image>"), "labels" (defaults to {}) and "metadata"
            (defaults to {}).
        """
        sample = self.samples[idx]
        image_path = self.data_root / sample["image_path"]

        # Close the underlying file as soon as the pixels are converted, so
        # long runs do not accumulate open file handles.
        with Image.open(image_path) as raw:
            image = raw.convert("RGB")
        pixel_values = self.image_processor.preprocess(image)

        text = sample.get("text", "<image>")
        labels = sample.get("labels", {})

        return {
            "pixel_values": pixel_values,
            "text": text,
            "labels": labels,
            "metadata": sample.get("metadata", {}),
        }
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def collate_fn(batch: List[Dict], tokenizer, image_token_id: int, max_length: int = 2048) -> Dict:
    """
    Collate function for DocumentDataset.

    Pads every sample's tile stack with all-zero tiles up to the largest tile
    count in the batch, tokenizes the texts with padding/truncation, and
    derives LM labels from the token ids with padded positions set to -100.

    Args:
        batch: Items with a "text" string and a "pixel_values" tensor whose
            first dimension is the tile count.
        tokenizer: Callable returning "input_ids" and "attention_mask" tensors.
        image_token_id: Accepted for interface compatibility; not used here.
        max_length: Maximum token length passed to the tokenizer.

    Returns:
        Dict with "input_ids", "attention_mask", "pixel_values"
        (batch, max_tiles, ...) and "labels".
    """
    texts = [item["text"] for item in batch]
    images = [item.get("pixel_values") for item in batch]

    max_tiles = max(pv.shape[0] for pv in images)
    padded_pixel_values = []
    for pv in images:
        missing = max_tiles - pv.shape[0]
        if missing > 0:
            # Match the sample's own trailing shape, dtype and device instead
            # of assuming 3 square channels.
            pad = pv.new_zeros((missing, *pv.shape[1:]))
            pv = torch.cat([pv, pad], dim=0)
        padded_pixel_values.append(pv)

    pixel_values = torch.stack(padded_pixel_values, dim=0)

    tokenized = tokenizer(texts, padding=True, truncation=True, max_length=max_length, return_tensors="pt")

    # Mask by attention mask rather than by pad id: when the pad token shares
    # its id with a real token (commonly EOS), comparing ids would also hide
    # the real occurrences from the loss.
    labels = tokenized["input_ids"].clone()
    labels[tokenized["attention_mask"] == 0] = -100

    return {
        "input_ids": tokenized["input_ids"],
        "attention_mask": tokenized["attention_mask"],
        "pixel_values": pixel_values,
        "labels": labels,
    }
|
tinydoc_vlm/decoder.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
from typing import Optional, List
|
|
4
|
+
from transformers import LlamaForCausalLM, LlamaConfig
|
|
5
|
+
|
|
6
|
+
class TinyDocDecoder(nn.Module):
    """
    Thin wrapper that owns a ``LlamaForCausalLM`` built from the given config.

    Exposes embedding access/resizing and forwards every call to the wrapped
    language model.
    """
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.lm = LlamaForCausalLM(config)
        self.hidden_size = config.hidden_size

    def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
        """
        Resize the wrapped model's token embeddings and record the new vocabulary
        size on the stored config. Returns whatever the wrapped model returns.
        """
        embeddings = self.lm.resize_token_embeddings(new_num_tokens)
        self.config.vocab_size = new_num_tokens
        return embeddings

    def get_input_embeddings(self) -> nn.Module:
        """Return the wrapped model's input embedding module."""
        return self.lm.get_input_embeddings()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Pass all arguments through, by keyword, to the wrapped language model."""
        lm_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "past_key_values": past_key_values,
            "inputs_embeds": inputs_embeds,
            "labels": labels,
            "use_cache": use_cache,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
        }
        return self.lm(**lm_kwargs)
|
|
tinydoc_vlm/image_processing.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
from PIL import Image
|
|
3
|
+
from typing import Optional, List
|
|
4
|
+
import torch
|
|
5
|
+
import torchvision.transforms as T
|
|
6
|
+
from transformers.image_processing_base import ImageProcessingMixin
|
|
7
|
+
|
|
8
|
+
class TinyDocImageProcessor(ImageProcessingMixin):
    """
    Image processor for TinyDoc-VLM.

    Converts a PIL image into a stack of normalized square tiles. Small images
    (or ``tiling_mode="none"``) yield one resized tile; larger images yield an
    overview thumbnail followed by up to a 2x2 grid of tiles.
    """
    def __init__(
        self,
        image_size: int = 384,
        mean: Optional[List[float]] = None,
        std: Optional[List[float]] = None,
        tiling_mode: str = "auto",  # "none", "auto" (split if large)
        **kwargs,
    ):
        """
        Args:
            image_size: Side length of every output tile.
            mean: Per-channel normalization mean; defaults to [0.5, 0.5, 0.5].
            std: Per-channel normalization std; defaults to [0.5, 0.5, 0.5].
            tiling_mode: "none" disables tiling; any other value tiles images
                that exceed ``image_size`` in either dimension.
            **kwargs: Forwarded to ``ImageProcessingMixin.__init__``.
        """
        self.image_size = image_size
        self.mean = mean or [0.5, 0.5, 0.5]
        self.std = std or [0.5, 0.5, 0.5]
        self.tiling_mode = tiling_mode

        super().__init__(**kwargs)

        # Base torchvision transforms for single tile
        self.transform = T.Compose([
            T.ToTensor(),
            T.Normalize(mean=self.mean, std=self.std)
        ])

    def preprocess(
        self,
        image: Image.Image,
        return_tensors: str = "pt"
    ) -> torch.Tensor:
        """
        Preprocess a PIL image into a multi-tile float tensor.

        Args:
            image: Input image; converted to RGB when it is in another mode.
            return_tensors: Accepted for interface compatibility; not used here.

        Returns:
            Tensor of shape (num_tiles, 3, image_size, image_size), where
            num_tiles is 1 without tiling and 1 + rows * cols with tiling.
        """
        # Ensure RGB
        if image.mode != "RGB":
            image = image.convert("RGB")

        w, h = image.size

        if self.tiling_mode == "none" or (w <= self.image_size and h <= self.image_size):
            # No tiling needed: stretch to image_size x image_size and return a single tile.
            resized = image.resize((self.image_size, self.image_size), Image.Resampling.BILINEAR)
            tile_tensor = self.transform(resized)  # shape (3, image_size, image_size)
            return tile_tensor.unsqueeze(0)  # shape (1, 3, image_size, image_size)

        # Tiling: split the image into a grid of image_size x image_size tiles,
        # plus a downscaled overview thumbnail.
        cols = int(np.ceil(w / self.image_size))
        rows = int(np.ceil(h / self.image_size))

        # Limit grid size to prevent excessive memory usage (max 2x2 grid = 4 tiles)
        cols = min(cols, 2)
        rows = min(rows, 2)

        target_w = cols * self.image_size
        target_h = rows * self.image_size

        # Fit the image into the grid canvas, preserving aspect ratio via padding.
        resized_full = self._resize_and_pad(image, target_w, target_h)

        tiles = []

        # 1. Overview of the full image (stretched to a square, not padded).
        thumbnail = image.resize((self.image_size, self.image_size), Image.Resampling.BILINEAR)
        tiles.append(self.transform(thumbnail))

        # 2. Grid tiles in row-major order.
        for r in range(rows):
            for c in range(cols):
                box = (
                    c * self.image_size,
                    r * self.image_size,
                    (c + 1) * self.image_size,
                    (r + 1) * self.image_size
                )
                tile = resized_full.crop(box)
                tiles.append(self.transform(tile))

        # shape (num_tiles, 3, image_size, image_size), num_tiles = 1 + rows * cols
        return torch.stack(tiles, dim=0)

    def _resize_and_pad(self, img: Image.Image, target_w: int, target_h: int) -> Image.Image:
        """
        Resize ``img`` to fit inside (target_w, target_h) while keeping its aspect
        ratio, then center it on a white canvas of exactly that size.
        """
        w, h = img.size
        ratio = min(target_w / w, target_h / h)
        # Clamp to at least one pixel: for extreme aspect ratios the truncated
        # size can reach 0, which Image.resize rejects.
        new_w = max(1, int(w * ratio))
        new_h = max(1, int(h * ratio))

        resized = img.resize((new_w, new_h), Image.Resampling.BILINEAR)

        # White background, resized image centered on it.
        padded = Image.new("RGB", (target_w, target_h), (255, 255, 255))
        x_offset = (target_w - new_w) // 2
        y_offset = (target_h - new_h) // 2
        padded.paste(resized, (x_offset, y_offset))

        return padded
|
tinydoc_vlm/losses.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
from typing import Dict, Optional
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def ce_loss(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
    """
    Cross-entropy over all positions, flattening every leading dimension.

    Args:
        logits: Tensor whose last dimension is the class dimension.
        labels: Integer class ids with one entry per logits row after flattening.
        ignore_index: Label value excluded from the loss.

    Returns:
        Scalar loss tensor (mean over the non-ignored positions).
    """
    # reshape instead of view: view raises on non-contiguous inputs such as
    # transposed or sliced tensors, reshape copies only when it has to.
    flat_logits = logits.reshape(-1, logits.size(-1))
    flat_labels = labels.reshape(-1)
    return F.cross_entropy(flat_logits, flat_labels, ignore_index=ignore_index)
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def kv_loss(kv_outputs: Dict[str, torch.Tensor], key_labels: torch.Tensor, confidence_labels: torch.Tensor) -> torch.Tensor:
    """
    Key/value head loss: key cross-entropy plus confidence binary cross-entropy.

    Args:
        kv_outputs: Dict with "key_logits" (last dimension = classes) and
            "confidence" (fed to ``binary_cross_entropy``, so its values are
            expected to be probabilities in [0, 1]).
        key_labels: Integer class ids, one per flattened key_logits row.
        confidence_labels: Targets for the confidence values; integer or bool
            targets are cast to the confidence dtype.

    Returns:
        Scalar tensor: key cross-entropy + confidence BCE.
    """
    key_logits = kv_outputs["key_logits"]
    # reshape (not view) so non-contiguous inputs are accepted.
    key_ce = F.cross_entropy(key_logits.reshape(-1, key_logits.size(-1)), key_labels.reshape(-1))

    confidence = kv_outputs["confidence"].reshape(-1)
    # binary_cross_entropy needs floating-point targets in the input's dtype.
    targets = confidence_labels.reshape(-1).to(confidence.dtype)
    conf_bce = F.binary_cross_entropy(confidence, targets)

    return key_ce + conf_bce
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class CombinedLoss(nn.Module):
    """
    Stage-dependent combination of the LM loss with auxiliary head losses.

    The LM cross-entropy is always computed. In stage 2 the JSON, KV and table
    head losses are added for every head that has both outputs and labels; in
    every other stage the total is the LM loss alone.
    """
    def __init__(self, stage: int = 1):
        super().__init__()
        self.stage = stage

    def forward(
        self,
        lm_logits: torch.Tensor,
        lm_labels: torch.Tensor,
        head_outputs: Optional[Dict[str, torch.Tensor]] = None,
        head_labels: Optional[Dict[str, torch.Tensor]] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Returns:
            Dict containing "lm_loss", any of "json_loss" / "kv_loss" /
            "table_loss" that were computed, and "loss" (their sum).
        """
        lm_loss = ce_loss(lm_logits, lm_labels)
        report = {"lm_loss": lm_loss}
        running = lm_loss

        # Auxiliary heads only contribute in stage 2, and only when both
        # outputs and labels were supplied (and are non-empty).
        if self.stage == 2 and head_outputs and head_labels:
            if "json_logits" in head_outputs and "json_labels" in head_labels:
                json_term = ce_loss(head_outputs["json_logits"], head_labels["json_labels"])
                report["json_loss"] = json_term
                running = running + json_term

            if "kv" in head_outputs and "kv_key_labels" in head_labels:
                kv_term = kv_loss(head_outputs["kv"], head_labels["kv_key_labels"], head_labels["kv_conf_labels"])
                report["kv_loss"] = kv_term
                running = running + kv_term

            if "table_logits" in head_outputs and "table_labels" in head_labels:
                table_term = ce_loss(head_outputs["table_logits"], head_labels["table_labels"])
                report["table_loss"] = table_term
                running = running + table_term

        report["loss"] = running
        return report
|