tinydoc-vlm 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,40 @@
1
+ from .configuration import TinyDocVLMConfig
2
+ from .vision_encoder import SigLIPVisionEncoder
3
+ from .token_compressor import PixelShuffleTokenCompressor
4
+ from .decoder import TinyDocDecoder
5
+ from .modeling import TinyDocVLMForConditionalGeneration, TinyDocVLMPreTrainedModel
6
+ from .image_processing import TinyDocImageProcessor
7
+ from .processing import TinyDocVLMProcessor
8
+ from .output_heads import MultiTaskOutputHeads, JSONHead, KVHead, TableHead, OCRHead, QAHead
9
+ from .data import DocumentDataset
10
+ from .losses import CombinedLoss
11
+ from .trainer import TinyDocVLMTrainer, TrainerConfig
12
+
13
+ __all__ = [
14
+ "TinyDocVLMConfig",
15
+ "SigLIPVisionEncoder",
16
+ "PixelShuffleTokenCompressor",
17
+ "TinyDocDecoder",
18
+ "TinyDocVLMForConditionalGeneration",
19
+ "TinyDocVLMPreTrainedModel",
20
+ "TinyDocImageProcessor",
21
+ "TinyDocVLMProcessor",
22
+ "MultiTaskOutputHeads",
23
+ "JSONHead",
24
+ "KVHead",
25
+ "TableHead",
26
+ "OCRHead",
27
+ "QAHead",
28
+ "DocumentDataset",
29
+ "CombinedLoss",
30
+ "TinyDocVLMTrainer",
31
+ "TrainerConfig",
32
+ ]
33
+
34
+ from transformers import AutoConfig, AutoModelForCausalLM
35
+
36
+ try:
37
+ AutoConfig.register("tinydoc_vlm", TinyDocVLMConfig)
38
+ AutoModelForCausalLM.register(TinyDocVLMConfig, TinyDocVLMForConditionalGeneration)
39
+ except ValueError:
40
+ pass
@@ -0,0 +1,76 @@
1
+ import torch
2
+
3
+ def get_2d_sincos_pos_embed(embed_dim: int, grid_size: int, cls_token: bool = False) -> torch.Tensor:
4
+ """
5
+ Generate 2D sinusoidal positional embeddings.
6
+
7
+ Args:
8
+ embed_dim: Dimension of the embedding (must be even)
9
+ grid_size: Height/Width of the grid (assumed square)
10
+ cls_token: If True, prepends a zero embedding for the class token
11
+
12
+ Returns:
13
+ pos_embed: shape (grid_size * grid_size, embed_dim) or (1 + grid_size * grid_size, embed_dim)
14
+ """
15
+ grid_h = torch.arange(grid_size, dtype=torch.float32)
16
+ grid_w = torch.arange(grid_size, dtype=torch.float32)
17
+
18
+ # Create coordinate grid
19
+ grid = torch.meshgrid(grid_h, grid_w, indexing="ij")
20
+ grid = torch.stack(grid, dim=0) # shape (2, grid_size, grid_size)
21
+ grid = grid.reshape(2, 1, grid_size, grid_size)
22
+
23
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
24
+
25
+ if cls_token:
26
+ # Prepend a zero embedding for CLS token
27
+ pos_embed = torch.cat([torch.zeros([1, embed_dim]), pos_embed], dim=0)
28
+
29
+ return pos_embed
30
+
31
+ def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor:
32
+ """
33
+ Helper function to generate embeddings from a coordinate grid.
34
+
35
+ Args:
36
+ embed_dim: Dimension of the embedding
37
+ grid: shape (2, 1, grid_h, grid_w) where index 0 is Y, index 1 is X
38
+
39
+ Returns:
40
+ emb: shape (grid_h * grid_w, embed_dim)
41
+ """
42
+ assert embed_dim % 2 == 0
43
+
44
+ # Use half of dimensions for X, half for Y
45
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # Y coords
46
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # X coords
47
+
48
+ emb = torch.cat([emb_h, emb_w], dim=1) # shape (grid_h * grid_w, embed_dim)
49
+ return emb
50
+
51
+ def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:
52
+ """
53
+ Generate 1D sinusoidal positional embeddings for a coordinate tensor.
54
+
55
+ Args:
56
+ embed_dim: Dimension of the embedding
57
+ pos: Coordinate tensor of shape (1, H, W) or (L,)
58
+
59
+ Returns:
60
+ emb: shape (H * W, embed_dim) or (L, embed_dim)
61
+ """
62
+ assert embed_dim % 2 == 0
63
+
64
+ # omega = 1 / (10000 ** (2i / d))
65
+ omega = torch.arange(embed_dim // 2, dtype=torch.float32)
66
+ omega /= embed_dim / 2.0
67
+ omega = 1.0 / (10000 ** omega) # shape (embed_dim // 2,)
68
+
69
+ pos = pos.reshape(-1) # Flatten spatial dims to 1D sequence
70
+ out = torch.outer(pos, omega) # shape (seq_len, embed_dim // 2)
71
+
72
+ emb_sine = torch.sin(out)
73
+ emb_cosine = torch.cos(out)
74
+
75
+ emb = torch.cat([emb_sine, emb_cosine], dim=1) # shape (seq_len, embed_dim)
76
+ return emb
@@ -0,0 +1,83 @@
1
+ from typing import Dict, Any, Union
2
+ from transformers import PretrainedConfig, AutoConfig
3
+
4
+ class TinyDocVLMConfig(PretrainedConfig):
5
+ model_type = "tinydoc_vlm"
6
+ is_composition = True
7
+
8
+ def __init__(
9
+ self,
10
+ vision_config: Union[Dict[str, Any], PretrainedConfig] = None,
11
+ decoder_config: Union[Dict[str, Any], PretrainedConfig] = None,
12
+ pixel_shuffle_scale: int = 3,
13
+ image_size: int = 384,
14
+ patch_size: int = 16,
15
+ **kwargs,
16
+ ):
17
+ super().__init__(**kwargs)
18
+
19
+ # Set defaults if not provided
20
+ if vision_config is None:
21
+ # Default SigLIP-B/16-like configuration (approx 93M parameters)
22
+ vision_config = {
23
+ "model_type": "siglip_vision_model",
24
+ "hidden_size": 768,
25
+ "intermediate_size": 3072,
26
+ "num_hidden_layers": 12,
27
+ "num_attention_heads": 12,
28
+ "patch_size": 16,
29
+ "image_size": 384,
30
+ "num_channels": 3,
31
+ "layer_norm_eps": 1e-6,
32
+ }
33
+
34
+ if decoder_config is None:
35
+ # Default SmolLM2-135M-like configuration (approx 135M parameters)
36
+ decoder_config = {
37
+ "model_type": "llama",
38
+ "vocab_size": 49152,
39
+ "hidden_size": 576,
40
+ "intermediate_size": 1536,
41
+ "num_hidden_layers": 30,
42
+ "num_attention_heads": 9,
43
+ "num_key_value_heads": 3,
44
+ "max_position_embeddings": 8192,
45
+ "rms_norm_eps": 1e-5,
46
+ "rope_theta": 273000.0,
47
+ "attention_bias": False,
48
+ }
49
+
50
+ # Initialize config objects
51
+ if isinstance(vision_config, dict):
52
+ vision_config_copy = vision_config.copy()
53
+ vision_model_type = vision_config_copy.pop("model_type", "siglip_vision_model")
54
+ self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config_copy)
55
+ else:
56
+ self.vision_config = vision_config
57
+
58
+ if isinstance(decoder_config, dict):
59
+ decoder_config_copy = decoder_config.copy()
60
+ decoder_model_type = decoder_config_copy.pop("model_type", "llama")
61
+ self.decoder_config = AutoConfig.for_model(decoder_model_type, **decoder_config_copy)
62
+ else:
63
+ self.decoder_config = decoder_config
64
+
65
+ self.pixel_shuffle_scale = pixel_shuffle_scale
66
+ self.image_size = image_size
67
+ self.patch_size = patch_size
68
+
69
+ def __getattr__(self, name):
70
+ if name in ('decoder_config', 'vision_config'):
71
+ raise AttributeError(name)
72
+ if 'decoder_config' in self.__dict__:
73
+ try:
74
+ return getattr(self.decoder_config, name)
75
+ except AttributeError:
76
+ pass
77
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
78
+
79
+ def to_dict(self) -> Dict[str, Any]:
80
+ output = super().to_dict()
81
+ output["vision_config"] = self.vision_config.to_dict()
82
+ output["decoder_config"] = self.decoder_config.to_dict()
83
+ return output
tinydoc_vlm/data.py ADDED
@@ -0,0 +1,92 @@
1
+ import json
2
+ from pathlib import Path
3
+ from typing import Dict, List, Optional, Union
4
+
5
+ import torch
6
+ from torch.utils.data import Dataset
7
+ from PIL import Image
8
+
9
+ from .image_processing import TinyDocImageProcessor
10
+
11
+
12
+ class DocumentDataset(Dataset):
13
+ """
14
+ Dataset for document understanding training.
15
+ Supports loading from a JSON manifest file or from individual samples.
16
+
17
+ Manifest format (JSONL):
18
+ {"image_path": "path/to/image.png", "text": "Extract: <image>", "labels": {...}}
19
+ """
20
+ def __init__(
21
+ self,
22
+ data_root: Union[str, Path],
23
+ manifest_path: Optional[Union[str, Path]] = None,
24
+ image_processor: Optional[TinyDocImageProcessor] = None,
25
+ max_seq_length: int = 2048,
26
+ stage: int = 1,
27
+ samples: Optional[List[Dict]] = None,
28
+ ):
29
+ self.data_root = Path(data_root)
30
+ self.image_processor = image_processor or TinyDocImageProcessor()
31
+ self.max_seq_length = max_seq_length
32
+ self.stage = stage
33
+
34
+ if samples is not None:
35
+ self.samples = samples
36
+ elif manifest_path:
37
+ with open(manifest_path) as f:
38
+ self.samples = [json.loads(line) for line in f if line.strip()]
39
+ else:
40
+ self.samples = []
41
+
42
+ def __len__(self) -> int:
43
+ return len(self.samples)
44
+
45
+ def __getitem__(self, idx: int) -> Dict:
46
+ sample = self.samples[idx]
47
+ image_path = self.data_root / sample["image_path"]
48
+ image = Image.open(image_path).convert("RGB")
49
+ pixel_values = self.image_processor.preprocess(image)
50
+
51
+ text = sample.get("text", "<image>")
52
+ labels = sample.get("labels", {})
53
+
54
+ return {
55
+ "pixel_values": pixel_values,
56
+ "text": text,
57
+ "labels": labels,
58
+ "metadata": sample.get("metadata", {}),
59
+ }
60
+
61
+
62
+ def collate_fn(batch: List[Dict], tokenizer, image_token_id: int, max_length: int = 2048) -> Dict:
63
+ """
64
+ Collate function for DocumentDataset.
65
+ Handles variable-length text, variable-number tiles, and label padding.
66
+ """
67
+ texts = [item["text"] for item in batch]
68
+ images = [item.get("pixel_values") for item in batch]
69
+
70
+ max_tiles = max(pv.shape[0] for pv in images)
71
+ image_size = images[0].shape[-1]
72
+ padded_pixel_values = []
73
+ for pv in images:
74
+ num_tiles = pv.shape[0]
75
+ if num_tiles < max_tiles:
76
+ pad = torch.zeros(max_tiles - num_tiles, 3, image_size, image_size, dtype=pv.dtype)
77
+ pv = torch.cat([pv, pad], dim=0)
78
+ padded_pixel_values.append(pv)
79
+
80
+ pixel_values = torch.stack(padded_pixel_values, dim=0)
81
+
82
+ tokenized = tokenizer(texts, padding=True, truncation=True, max_length=max_length, return_tensors="pt")
83
+
84
+ labels = tokenized["input_ids"].clone()
85
+ labels[labels == tokenizer.pad_token_id] = -100
86
+
87
+ return {
88
+ "input_ids": tokenized["input_ids"],
89
+ "attention_mask": tokenized["attention_mask"],
90
+ "pixel_values": pixel_values,
91
+ "labels": labels,
92
+ }
tinydoc_vlm/decoder.py ADDED
@@ -0,0 +1,52 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Optional, List
4
+ from transformers import LlamaForCausalLM, LlamaConfig
5
+
6
+ class TinyDocDecoder(nn.Module):
7
+ """
8
+ Decoder wrapper around LlamaForCausalLM (used by SmolLM2).
9
+ Manages loading and vocabulary/embedding resizing for special tokens.
10
+ """
11
+ def __init__(self, config: LlamaConfig):
12
+ super().__init__()
13
+ self.config = config
14
+ self.lm = LlamaForCausalLM(config)
15
+ self.hidden_size = config.hidden_size
16
+
17
+ def resize_token_embeddings(self, new_num_tokens: int) -> nn.Embedding:
18
+ """
19
+ Resizes input token embeddings and output LM head of the decoder.
20
+ """
21
+ resized = self.lm.resize_token_embeddings(new_num_tokens)
22
+ self.config.vocab_size = new_num_tokens
23
+ return resized
24
+
25
+ def get_input_embeddings(self) -> nn.Module:
26
+ return self.lm.get_input_embeddings()
27
+
28
+ def forward(
29
+ self,
30
+ input_ids: Optional[torch.LongTensor] = None,
31
+ attention_mask: Optional[torch.Tensor] = None,
32
+ position_ids: Optional[torch.LongTensor] = None,
33
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
34
+ inputs_embeds: Optional[torch.FloatTensor] = None,
35
+ labels: Optional[torch.LongTensor] = None,
36
+ use_cache: Optional[bool] = None,
37
+ output_attentions: Optional[bool] = None,
38
+ output_hidden_states: Optional[bool] = None,
39
+ return_dict: Optional[bool] = None,
40
+ ):
41
+ return self.lm(
42
+ input_ids=input_ids,
43
+ attention_mask=attention_mask,
44
+ position_ids=position_ids,
45
+ past_key_values=past_key_values,
46
+ inputs_embeds=inputs_embeds,
47
+ labels=labels,
48
+ use_cache=use_cache,
49
+ output_attentions=output_attentions,
50
+ output_hidden_states=output_hidden_states,
51
+ return_dict=return_dict,
52
+ )
@@ -0,0 +1,117 @@
1
+ import numpy as np
2
+ from PIL import Image
3
+ from typing import Optional, List
4
+ import torch
5
+ import torchvision.transforms as T
6
+ from transformers.image_processing_base import ImageProcessingMixin
7
+
8
+ class TinyDocImageProcessor(ImageProcessingMixin):
9
+ """
10
+ Image processor for TinyDoc-VLM.
11
+ Handles resizing, normalization, and optional tiling (splitting) of document images.
12
+ """
13
+ def __init__(
14
+ self,
15
+ image_size: int = 384,
16
+ mean: Optional[List[float]] = None,
17
+ std: Optional[List[float]] = None,
18
+ tiling_mode: str = "auto", # "none", "auto" (split if large)
19
+ **kwargs,
20
+ ):
21
+ self.image_size = image_size
22
+ self.mean = mean or [0.5, 0.5, 0.5]
23
+ self.std = std or [0.5, 0.5, 0.5]
24
+ self.tiling_mode = tiling_mode
25
+
26
+ super().__init__(**kwargs)
27
+
28
+ # Base torchvision transforms for single tile
29
+ self.transform = T.Compose([
30
+ T.ToTensor(),
31
+ T.Normalize(mean=self.mean, std=self.std)
32
+ ])
33
+
34
+ def preprocess(
35
+ self,
36
+ image: Image.Image,
37
+ return_tensors: str = "pt"
38
+ ) -> torch.Tensor:
39
+ """
40
+ Preprocesses a PIL Image into a multi-tile float tensor.
41
+
42
+ Returns shape: (num_tiles, 3, image_size, image_size)
43
+ """
44
+ # Ensure RGB
45
+ if image.mode != "RGB":
46
+ image = image.convert("RGB")
47
+
48
+ w, h = image.size
49
+
50
+ if self.tiling_mode == "none" or (w <= self.image_size and h <= self.image_size):
51
+ # No tiling needed: resize to image_size x image_size and return single tile
52
+ resized = image.resize((self.image_size, self.image_size), Image.Resampling.BILINEAR)
53
+ tile_tensor = self.transform(resized) # shape (3, image_size, image_size)
54
+ # Add tile batch dimension: shape (1, 3, image_size, image_size)
55
+ return tile_tensor.unsqueeze(0)
56
+
57
+ # Tiling mode 'auto': split high-res image into a grid of image_size x image_size tiles,
58
+ # plus a downscaled overview thumbnail.
59
+
60
+ # Calculate how many tiles we need
61
+ cols = int(np.ceil(w / self.image_size))
62
+ rows = int(np.ceil(h / self.image_size))
63
+
64
+ # Limit grid size to prevent excessive memory usage (max 2x2 grid = 4 tiles)
65
+ cols = min(cols, 2)
66
+ rows = min(rows, 2)
67
+
68
+ # Target size for the tiling grid
69
+ target_w = cols * self.image_size
70
+ target_h = rows * self.image_size
71
+
72
+ # Resize original image to fit the target grid shape (maintaining proportions via padding)
73
+ resized_full = self._resize_and_pad(image, target_w, target_h)
74
+
75
+ tiles = []
76
+
77
+ # 1. Generate thumbnail/overview of the full image
78
+ thumbnail = image.resize((self.image_size, self.image_size), Image.Resampling.BILINEAR)
79
+ tiles.append(self.transform(thumbnail))
80
+
81
+ # 2. Extract tiles from the grid
82
+ for r in range(rows):
83
+ for c in range(cols):
84
+ box = (
85
+ c * self.image_size,
86
+ r * self.image_size,
87
+ (c + 1) * self.image_size,
88
+ (r + 1) * self.image_size
89
+ )
90
+ tile = resized_full.crop(box)
91
+ tiles.append(self.transform(tile))
92
+
93
+ # Stack tiles along a new dimension: shape (num_tiles, 3, image_size, image_size)
94
+ # where num_tiles = 1 (overview) + rows * cols
95
+ stacked_tiles = torch.stack(tiles, dim=0)
96
+ return stacked_tiles
97
+
98
+ def _resize_and_pad(self, img: Image.Image, target_w: int, target_h: int) -> Image.Image:
99
+ """
100
+ Resizes and pads an image to target dimensions while maintaining aspect ratio.
101
+ """
102
+ # Calculate aspect ratio
103
+ w, h = img.size
104
+ ratio = min(target_w / w, target_h / h)
105
+ new_w = int(w * ratio)
106
+ new_h = int(h * ratio)
107
+
108
+ resized = img.resize((new_w, new_h), Image.Resampling.BILINEAR)
109
+
110
+ # Create a new padded background image
111
+ padded = Image.new("RGB", (target_w, target_h), (255, 255, 255))
112
+ # Center the resized image
113
+ x_offset = (target_w - new_w) // 2
114
+ y_offset = (target_h - new_h) // 2
115
+ padded.paste(resized, (x_offset, y_offset))
116
+
117
+ return padded
tinydoc_vlm/losses.py ADDED
@@ -0,0 +1,64 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from typing import Dict, Optional
5
+
6
+
7
+ def ce_loss(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> torch.Tensor:
8
+ return F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=ignore_index)
9
+
10
+
11
+ def kv_loss(kv_outputs: Dict[str, torch.Tensor], key_labels: torch.Tensor, confidence_labels: torch.Tensor) -> torch.Tensor:
12
+ key_ce = F.cross_entropy(kv_outputs["key_logits"].view(-1, kv_outputs["key_logits"].size(-1)), key_labels.view(-1))
13
+ conf_bce = F.binary_cross_entropy(kv_outputs["confidence"].view(-1), confidence_labels.view(-1))
14
+ return key_ce + conf_bce
15
+
16
+
17
+ class CombinedLoss(nn.Module):
18
+ """
19
+ Combines multiple task-specific losses for multi-stage training.
20
+ Stage 1: layout + OCR + region
21
+ Stage 2: QA + JSON + KV + table
22
+ Stage 3: standard LM
23
+ """
24
+ def __init__(self, stage: int = 1):
25
+ super().__init__()
26
+ self.stage = stage
27
+
28
+ def forward(
29
+ self,
30
+ lm_logits: torch.Tensor,
31
+ lm_labels: torch.Tensor,
32
+ head_outputs: Optional[Dict[str, torch.Tensor]] = None,
33
+ head_labels: Optional[Dict[str, torch.Tensor]] = None,
34
+ ) -> Dict[str, torch.Tensor]:
35
+ losses = {}
36
+ total_loss = torch.tensor(0.0, device=lm_logits.device)
37
+
38
+ lm_loss = ce_loss(lm_logits, lm_labels)
39
+ losses["lm_loss"] = lm_loss
40
+
41
+ if self.stage == 1:
42
+ total_loss = lm_loss
43
+ elif self.stage == 2:
44
+ total_loss = lm_loss
45
+ if head_outputs and head_labels:
46
+ if "json_logits" in head_outputs and "json_labels" in head_labels:
47
+ json_loss = ce_loss(head_outputs["json_logits"], head_labels["json_labels"])
48
+ losses["json_loss"] = json_loss
49
+ total_loss = total_loss + json_loss
50
+
51
+ if "kv" in head_outputs and "kv_key_labels" in head_labels:
52
+ kv = kv_loss(head_outputs["kv"], head_labels["kv_key_labels"], head_labels["kv_conf_labels"])
53
+ losses["kv_loss"] = kv
54
+ total_loss = total_loss + kv
55
+
56
+ if "table_logits" in head_outputs and "table_labels" in head_labels:
57
+ table_loss = ce_loss(head_outputs["table_logits"], head_labels["table_labels"])
58
+ losses["table_loss"] = table_loss
59
+ total_loss = total_loss + table_loss
60
+ else:
61
+ total_loss = lm_loss
62
+
63
+ losses["loss"] = total_loss
64
+ return losses