tinydoc-vlm 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,194 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Dict, Any, Optional, List, Tuple, Union
4
+ from transformers import PreTrainedModel, GenerationMixin
5
+ from transformers.modeling_outputs import CausalLMOutputWithPast
6
+
7
+ from .configuration import TinyDocVLMConfig
8
+ from .vision_encoder import SigLIPVisionEncoder
9
+ from .token_compressor import PixelShuffleTokenCompressor
10
+ from .decoder import TinyDocDecoder
11
+ from .output_heads import MultiTaskOutputHeads
12
+
13
class TinyDocVLMPreTrainedModel(PreTrainedModel):
    """Shared base class wiring TinyDocVLMConfig into the PreTrainedModel machinery.

    Declares the config class, the base-model prefix and gradient-checkpointing
    support, and provides the weight initialiser used for linear and embedding
    layers.
    """

    config_class = TinyDocVLMConfig
    base_model_prefix = "tinydoc_vlm"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialise ``module`` in place if it is a Linear or Embedding layer.

        Weights are drawn from N(0, std) where ``std`` is the config's
        ``initializer_range`` (0.02 when the config does not define one).
        Linear biases and the embedding padding row are zeroed. Any other
        module type is left untouched.
        """
        init_std = getattr(self.config, "initializer_range", 0.02)

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=init_std)
            bias = module.bias
            if bias is not None:
                bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=init_std)
            pad_row = module.padding_idx
            if pad_row is not None:
                # Keep the padding embedding at exactly zero.
                module.weight.data[pad_row].zero_()
28
+
29
class TinyDocVLMForConditionalGeneration(TinyDocVLMPreTrainedModel, GenerationMixin):
    """
    TinyDoc-VLM: The World's Smallest Document Understanding Model.

    Coordinates a vision encoder (SigLIPVisionEncoder), a pixel-shuffle token
    compressor, a language decoder (TinyDocDecoder) and a set of task-specific
    output heads. Image features are written into the text embedding sequence
    at the positions of the image placeholder token before the decoder runs.
    """

    def __init__(self, config: TinyDocVLMConfig):
        super().__init__(config)

        # 1. Vision Encoder
        self.vision_encoder = SigLIPVisionEncoder(config)

        # 2. Token Compressor / Connector
        self.compressor = PixelShuffleTokenCompressor(
            config,
            encoder_dim=config.vision_config.hidden_size,
            decoder_dim=config.decoder_config.hidden_size,
        )

        # 3. Decoder
        self.decoder = TinyDocDecoder(config.decoder_config)

        # Vocabulary id of the image placeholder token (a plain int, not a
        # parameter); falls back to 49152 when the config does not define it.
        self.image_token_id = getattr(config, "image_token_id", 49152)

        # Number of visual tokens per tile after pixel-shuffle compression.
        s = config.pixel_shuffle_scale
        compressed_grid_size = (config.image_size // config.patch_size) // s
        compressed_patches = compressed_grid_size ** 2

        # Learnable positional embeddings for the compressed visual tokens.
        # They are added to the compressor output (i.e. after projection to the
        # decoder width) and broadcast over the batch and tile dimensions.
        self.visual_pos_embed = nn.Parameter(
            torch.zeros(1, 1, compressed_patches, config.decoder_config.hidden_size)
        )

        # 4. Structured Output Heads (multi-task)
        self.output_heads = MultiTaskOutputHeads(
            hidden_size=config.decoder_config.hidden_size,
            vocab_size=config.decoder_config.vocab_size,
        )

        # Initialize weights
        self.post_init()

    def get_input_embeddings(self) -> nn.Module:
        """Return the decoder's token embedding module."""
        return self.decoder.get_input_embeddings()

    def set_input_embeddings(self, value):
        """Replace the token embedding module of the wrapped language model."""
        self.decoder.lm.set_input_embeddings(value)

    @staticmethod
    def _has_cached_tokens(past_key_values) -> bool:
        """Return True when ``past_key_values`` already holds at least one step.

        Supports cache objects exposing ``get_seq_length()`` as well as legacy
        tuple/list caches. A non-None cache of unknown type is treated as
        populated, matching the previous "not None means decoding" behaviour.
        """
        if past_key_values is None:
            return False
        get_seq_length = getattr(past_key_values, "get_seq_length", None)
        if callable(get_seq_length):
            return bool(get_seq_length() > 0)
        try:
            return len(past_key_values) > 0
        except TypeError:
            return True

    @staticmethod
    def _final_hidden_state(outputs):
        """Extract the last-layer hidden state tensor from decoder outputs."""
        hidden_states = getattr(outputs, "hidden_states", None)
        if hidden_states is None:
            # Tuple-style output: assumes hidden states are the trailing
            # element — TODO confirm against TinyDocDecoder's tuple layout.
            hidden_states = outputs[-1]
        if isinstance(hidden_states, (tuple, list)):
            # Per-layer collection: the heads consume the final layer only.
            return hidden_states[-1]
        return hidden_states

    def _merge_visual_features(self, input_ids, inputs_embeds, pixel_values):
        """Write compressed image features over the image-placeholder embeddings.

        Returns a new tensor; the ``inputs_embeds`` passed in is not modified.

        Raises:
            ValueError: if ``input_ids`` is missing (placeholder positions
                cannot be located) or a sample has more placeholder tokens
                than available visual features.
        """
        if input_ids is None:
            raise ValueError(
                "input_ids are required together with pixel_values to locate "
                "image placeholder tokens."
            )

        visual_features = self.vision_encoder(pixel_values)
        compressed_features = self.compressor(visual_features) + self.visual_pos_embed

        batch_size, num_tiles, compressed_patches, decoder_dim = compressed_features.shape
        # reshape (not view) so a non-contiguous compressor output is accepted;
        # align dtype/device with the text embeddings for the indexed assignment.
        flat_visual_features = compressed_features.reshape(
            batch_size, num_tiles * compressed_patches, decoder_dim
        ).to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)

        image_mask = input_ids == self.image_token_id
        merged = inputs_embeds.clone()
        available = flat_visual_features.shape[1]
        for b in range(batch_size):
            num_places = int(image_mask[b].sum().item())
            if num_places == 0:
                continue
            if num_places > available:
                raise ValueError(
                    f"Sample {b} has {num_places} image placeholder tokens but only "
                    f"{available} visual features are available."
                )
            # Only the first `num_places` features are used; trailing features
            # (e.g. from zero-padded tiles) are dropped.
            merged[b, image_mask[b]] = flat_visual_features[b, :num_places]
        return merged

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        task: Optional[str] = None,
    ) -> Union[Tuple, Dict, CausalLMOutputWithPast]:
        """Run the decoder, optionally merging image features and applying a task head.

        Two paths exist:

        * Decoding pass (``pixel_values is None`` and a cache is supplied):
          inputs are forwarded to the decoder unchanged.
        * Prefill pass: text tokens are embedded (unless ``inputs_embeds`` is
          given) and, when ``pixel_values`` is present, image features replace
          the embeddings at image-placeholder positions.

        Args:
            task: when truthy, hidden states are requested from the decoder and
                the last layer is routed through ``self.output_heads``.

        Returns:
            The decoder outputs, or ``{"lm_outputs": ..., "head_outputs": ...}``
            when ``task`` is set.

        Raises:
            ValueError: in the prefill pass when neither ``input_ids`` nor
                ``inputs_embeds`` is given, or when image features cannot be
                merged (see ``_merge_visual_features``).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if task:
            # The heads need hidden states. Asking for attribute-style outputs
            # makes their lookup unambiguous regardless of which optional
            # entries (loss, cache, attentions) the decoder emits.
            output_hidden_states = True
            return_dict = True

        is_cached_decode = pixel_values is None and past_key_values is not None
        if is_cached_decode:
            # Decoding pass (no new visual input, reuse cached states).
            decoder_input_ids = input_ids
        else:
            # Prefill pass: merge text and visual tokens into inputs_embeds.
            if inputs_embeds is None:
                if input_ids is None:
                    raise ValueError("Either input_ids or inputs_embeds must be provided.")
                inputs_embeds = self.decoder.get_input_embeddings()(input_ids)
            if pixel_values is not None:
                inputs_embeds = self._merge_visual_features(
                    input_ids, inputs_embeds, pixel_values
                )
            decoder_input_ids = None

        outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if not task:
            return outputs

        hidden = self._final_hidden_state(outputs)
        head_outputs = self.output_heads(hidden, task=task)
        return {"lm_outputs": outputs, "head_outputs": head_outputs}

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        pixel_values=None,
        **kwargs
    ) -> Dict[str, Any]:
        """
        Overridden to support KV caching during auto-regressive generation.

        Once the cache holds the prompt, only the newest token is fed and
        ``pixel_values`` is dropped. The image was already consumed during
        prefill, and a caller that keeps forwarding ``pixel_values`` on every
        step would otherwise re-encode it and re-feed the whole prompt on top
        of the existing cache. An empty cache object is treated as "not yet
        decoding" so the full prompt is still processed on the first step.
        """
        is_decoding = self._has_cached_tokens(past_key_values)

        if is_decoding:
            input_ids = input_ids[:, -1:]
            inputs_embeds = None
            pixel_values = None

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # Positions count attended tokens only; padded slots get a dummy 1.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if is_decoding:
                position_ids = position_ids[:, -input_ids.shape[-1]:]

        return {
            "input_ids": input_ids,
            "inputs_embeds": inputs_embeds,
            "past_key_values": past_key_values,
            "pixel_values": pixel_values,
            "attention_mask": attention_mask,
            "position_ids": position_ids,
            "use_cache": kwargs.get("use_cache"),
        }

    def _reorder_cache(self, past_key_values, beam_idx):
        """Delegate beam-search cache reordering to the wrapped language model."""
        return self.decoder.lm._reorder_cache(past_key_values, beam_idx)
@@ -0,0 +1,135 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Dict
4
+
5
class JSONHead(nn.Module):
    """
    Structured JSON generation head.

    A two-layer MLP (Linear -> GELU -> Linear) that maps decoder hidden states
    of width ``hidden_size`` to ``num_json_tokens`` logits per position.
    """

    def __init__(self, hidden_size: int, num_json_tokens: int = 256):
        super().__init__()
        layers = [
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, num_json_tokens),
        ]
        self.proj = nn.Sequential(*layers)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return JSON-token logits for every position in ``hidden_states``."""
        json_logits = self.proj(hidden_states)
        return json_logits
22
+
23
+
24
class KVHead(nn.Module):
    """
    Key-Value extraction head.

    Applies three independent linear projections to the decoder hidden states:
    a key classifier over ``num_keys`` classes, a value projection that keeps
    the hidden width, and a scalar confidence squashed through a sigmoid.
    """

    def __init__(self, hidden_size: int, num_keys: int = 128):
        super().__init__()
        self.key_classifier = nn.Linear(hidden_size, num_keys)
        self.value_proj = nn.Linear(hidden_size, hidden_size)
        self.confidence = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Return ``key_logits``, ``value_embeds`` and ``confidence`` tensors."""
        confidence_logits = self.confidence(hidden_states)
        return {
            "key_logits": self.key_classifier(hidden_states),
            "value_embeds": self.value_proj(hidden_states),
            "confidence": torch.sigmoid(confidence_logits),
        }
45
+
46
+
47
class TableHead(nn.Module):
    """
    Table generation head.

    A two-layer MLP (Linear -> GELU -> Linear) projecting decoder hidden states
    to ``num_table_tokens`` logits per position.
    """

    def __init__(self, hidden_size: int, num_table_tokens: int = 128):
        super().__init__()
        hidden_layer = nn.Linear(hidden_size, hidden_size)
        activation = nn.GELU()
        output_layer = nn.Linear(hidden_size, num_table_tokens)
        self.proj = nn.Sequential(hidden_layer, activation, output_layer)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return table-token logits for every position in ``hidden_states``."""
        table_logits = self.proj(hidden_states)
        return table_logits
63
+
64
+
65
class OCRHead(nn.Module):
    """
    OCR text generation head.

    A two-layer MLP (Linear -> GELU -> Linear) projecting decoder hidden states
    to ``vocab_size`` logits per position (256 by default).
    """

    def __init__(self, hidden_size: int, vocab_size: int = 256):
        super().__init__()
        stages = (
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, vocab_size),
        )
        self.proj = nn.Sequential(*stages)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return OCR vocabulary logits for every position in ``hidden_states``."""
        ocr_logits = self.proj(hidden_states)
        return ocr_logits
81
+
82
+
83
class QAHead(nn.Module):
    """
    Question-answering head.

    A single bias-free linear projection from the decoder hidden width to
    ``vocab_size`` logits, i.e. a standard LM head.
    """

    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.proj = nn.Linear(
            in_features=hidden_size,
            out_features=vocab_size,
            bias=False,
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return vocabulary logits for every position in ``hidden_states``."""
        answer_logits = self.proj(hidden_states)
        return answer_logits
94
+
95
+
96
class MultiTaskOutputHeads(nn.Module):
    """
    Container for all structured output heads.

    Owns one head per task ("json", "kv", "table", "ocr", "qa") and routes
    decoder hidden states to the head selected by ``task``. Only the selected
    head is evaluated on a given call.
    """

    def __init__(
        self,
        hidden_size: int,
        vocab_size: int,
        num_json_tokens: int = 256,
        num_keys: int = 128,
        num_table_tokens: int = 128,
        ocr_vocab_size: int = 256,
    ):
        super().__init__()
        self.json_head = JSONHead(hidden_size, num_json_tokens)
        self.kv_head = KVHead(hidden_size, num_keys)
        self.table_head = TableHead(hidden_size, num_table_tokens)
        self.ocr_head = OCRHead(hidden_size, ocr_vocab_size)
        self.qa_head = QAHead(hidden_size, vocab_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        task: str = "qa",
    ) -> Dict[str, torch.Tensor]:
        """Apply the head registered for ``task`` to ``hidden_states``.

        Returns:
            For "kv", whatever dict the KV head produces. For every other
            known task, ``{"logits": <head output>}``.

        Raises:
            ValueError: if ``task`` is not one of the known task names.
        """
        # The KV head already returns a dict, so it is passed through as-is.
        if task == "kv":
            return self.kv_head(hidden_states)

        # Remaining heads all produce a single logits tensor.
        logit_heads = (
            ("json", self.json_head),
            ("table", self.table_head),
            ("ocr", self.ocr_head),
            ("qa", self.qa_head),
        )
        for task_name, head in logit_heads:
            if task == task_name:
                return {"logits": head(hidden_states)}

        raise ValueError(f"Unknown task: {task}")
@@ -0,0 +1,177 @@
1
+ """
2
+ TinyDocVLMProcessor — standalone processor (does not inherit ProcessorMixin to avoid
3
+ strict type-checking issues in transformers<4.45). Provides the same public API.
4
+ """
5
+ import os
6
+ import json
7
+ from PIL import Image
8
+ from typing import Dict, Any, Union, Optional, List
9
+
10
+ import torch
11
+ from transformers import AutoTokenizer
12
+
13
+ from .image_processing import TinyDocImageProcessor
14
+
15
+
16
class TinyDocVLMProcessor:
    """
    Coordinates TinyDocImageProcessor (image tiling + normalisation) and a
    tokenizer extended with the ``<image>`` special token.

    Usage:
        processor = TinyDocVLMProcessor()
        inputs = processor(text=["Extract fields: <image>"], images=[pil_img])
        # inputs -> {"input_ids", "attention_mask", "pixel_values", "image_token_id"}

    NOTE(review): ``image_token_id`` is returned next to the tensors; callers
    that splat the dict into a model call may need to pop it first — verify
    against the model's forward signature.
    """

    # Class-level attrs used by some HF utilities (save_pretrained etc.)
    image_processor_class = "TinyDocImageProcessor"
    tokenizer_class = "AutoTokenizer"

    def __init__(
        self,
        image_processor: Optional["TinyDocImageProcessor"] = None,
        tokenizer=None,
        config=None,
        **kwargs,
    ):
        """Store the sub-processors and make sure ``<image>`` is a known token.

        Args:
            image_processor: tile/normalise images; a default
                TinyDocImageProcessor is created when omitted.
            tokenizer: text tokenizer; the SmolLM2-135M-Instruct tokenizer is
                loaded when omitted.
            config: optional model config; only ``pixel_shuffle_scale`` and
                ``patch_size`` are read from it (defaults 3 and 16).
            **kwargs: accepted for API compatibility and ignored.
        """
        self.image_processor = image_processor or TinyDocImageProcessor()
        self.tokenizer = tokenizer or AutoTokenizer.from_pretrained(
            "HuggingFaceTB/SmolLM2-135M-Instruct"
        )
        self.config = config

        # Ensure <image> special token exists
        self.image_token = "<image>"
        if self.image_token not in self.tokenizer.get_vocab():
            self.tokenizer.add_special_tokens(
                {"additional_special_tokens": [self.image_token]}
            )
        self.image_token_id = self.tokenizer.convert_tokens_to_ids(self.image_token)

    # ------------------------------------------------------------------
    # Core __call__
    # ------------------------------------------------------------------

    def __call__(
        self,
        text: Union[str, List[str]],
        images: Optional[Union["Image.Image", List["Image.Image"]]] = None,
        padding: bool = True,
        truncation: bool = True,
        max_length: Optional[int] = None,
        return_tensors: str = "pt",
    ) -> Dict[str, Any]:
        """
        Preprocesses text and images into tensors for the model.

        Every ``<image>`` occurrence in a text is replaced by as many copies of
        the token as the matching image yields visual tokens (tiles times
        tokens per tile). Texts without a matching image have their
        ``<image>`` placeholders removed. An empty image list is treated the
        same as ``images=None``.

        Returns:
            dict with keys: input_ids, attention_mask, and — only when images
            were given — pixel_values of shape (B, max_tiles, C, H, W) plus
            image_token_id.
        """
        # ---- 1. Image processing ------------------------------------------------
        pixel_values = None
        num_tiles_list: List[int] = []

        if images is not None:
            if isinstance(images, tuple):
                images = list(images)
            elif not isinstance(images, list):
                images = [images]

            # An empty list carries no images; skipping it avoids max([]).
            if images:
                tiles = [self.image_processor.preprocess(img) for img in images]
                num_tiles_list = [tile.shape[0] for tile in tiles]
                pixel_values = self._stack_tiles(tiles)

        # ---- 2. Expand <image> tokens in text ----------------------------------
        scale = (
            getattr(self.config, "pixel_shuffle_scale", 3) if self.config else 3
        )
        patch_size = (
            getattr(self.config, "patch_size", 16) if self.config else 16
        )
        sz = self.image_processor.image_size
        tokens_per_tile = (sz // patch_size // scale) ** 2

        if isinstance(text, list):
            # Texts and images are paired by position; extra texts get no image.
            processed_text = [
                self._expand_image_tokens(
                    t,
                    num_tiles_list[idx] * tokens_per_tile
                    if idx < len(num_tiles_list)
                    else 0,
                )
                for idx, t in enumerate(text)
            ]
        else:
            total_vis = num_tiles_list[0] * tokens_per_tile if num_tiles_list else 0
            processed_text = self._expand_image_tokens(text, total_vis)

        # ---- 3. Tokenise -------------------------------------------------------
        enc = self.tokenizer(
            processed_text,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            return_tensors=return_tensors,
        )

        inputs: Dict[str, Any] = {
            "input_ids": enc["input_ids"],
            "attention_mask": enc["attention_mask"],
        }
        if pixel_values is not None:
            inputs["pixel_values"] = pixel_values
            inputs["image_token_id"] = self.image_token_id

        return inputs

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _stack_tiles(tiles: List[torch.Tensor]) -> torch.Tensor:
        """Zero-pad each (T, C, H, W) tile tensor to the max T and stack them."""
        max_tiles = max(tile.shape[0] for tile in tiles)
        padded: List[torch.Tensor] = []
        for tile in tiles:
            missing = max_tiles - tile.shape[0]
            if missing > 0:
                # new_zeros keeps dtype/device; the pad copies the tile's own
                # (C, H, W) so it always concatenates cleanly.
                pad = tile.new_zeros((missing, *tile.shape[1:]))
                tile = torch.cat([tile, pad], dim=0)
            padded.append(tile)
        return torch.stack(padded, dim=0)  # (B, max_tiles, C, H, W)

    def _expand_image_tokens(self, text: str, total_tokens: int) -> str:
        """Replace every '<image>' occurrence with `total_tokens` copies of it."""
        expansion = self.image_token * total_tokens
        return text.replace(self.image_token, expansion)

    # ------------------------------------------------------------------
    # save / from_pretrained stubs (for HF Hub compatibility)
    # ------------------------------------------------------------------

    def save_pretrained(self, save_directory: str, **kwargs):
        """Save both sub-processors and a minimal ``processor_config.json``."""
        os.makedirs(save_directory, exist_ok=True)
        self.image_processor.save_pretrained(save_directory)
        self.tokenizer.save_pretrained(save_directory)
        # Write a minimal processor_config.json
        cfg = {
            "processor_class": "TinyDocVLMProcessor",
            "image_token": self.image_token,
            "image_token_id": self.image_token_id,
        }
        config_path = os.path.join(save_directory, "processor_config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(cfg, f, indent=2)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """Load the image processor and tokenizer from one location.

        Extra ``kwargs`` (e.g. ``config``) are forwarded to the constructor.
        """
        image_processor = TinyDocImageProcessor.from_pretrained(
            pretrained_model_name_or_path
        )
        tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path)
        return cls(image_processor=image_processor, tokenizer=tokenizer, **kwargs)
@@ -0,0 +1,82 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from .configuration import TinyDocVLMConfig
5
+
6
class RMSNorm(nn.Module):
    """Root-mean-square layer normalisation over the last dimension.

    Each vector is divided by ``sqrt(mean(x^2) + eps)`` and scaled by a
    learnable per-feature ``weight`` (initialised to ones).
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Normalise ``x`` along its last dimension and apply the gain."""
        # Statistics are computed in float32: squaring float16 activations
        # overflows to inf above ~255, which would zero the output.
        x_fp32 = x.float()
        variance = x_fp32.pow(2).mean(-1, keepdim=True)
        normed = (x_fp32 * torch.rsqrt(variance + self.eps)).to(x.dtype)
        return normed * self.weight
15
+
16
class PixelShuffleTokenCompressor(nn.Module):
    """
    Performs space-to-depth token compression on Vision Transformer patch sequences.

    Groups scale_factor x scale_factor neighbouring patches into one token
    (concatenating their channels), then projects the result to the decoder
    hidden dimension with a two-layer MLP followed by RMSNorm.
    """

    def __init__(self, config: TinyDocVLMConfig, encoder_dim: int, decoder_dim: int):
        super().__init__()
        self.config = config
        self.scale_factor = config.pixel_shuffle_scale
        self.encoder_dim = encoder_dim
        self.decoder_dim = decoder_dim

        # After space-to-depth, channel dimension becomes encoder_dim * scale_factor^2
        compressed_dim = encoder_dim * (self.scale_factor ** 2)

        # MLP projection to map visual tokens to language model dimension
        self.projection = nn.Sequential(
            nn.Linear(compressed_dim, decoder_dim),
            nn.GELU(),
            nn.Linear(decoder_dim, decoder_dim),
        )
        self.norm = RMSNorm(decoder_dim)

    def forward(self, visual_features: torch.Tensor) -> torch.Tensor:
        """
        Args:
            visual_features: shape (batch_size, num_tiles, num_patches, encoder_dim)

        Returns:
            compressed_features: shape (batch_size, num_tiles, num_compressed_patches, decoder_dim)

        Raises:
            ValueError: if num_patches is not a perfect square, or the patch
                grid side is not divisible by the pixel-shuffle scale.
        """
        batch_size, num_tiles, num_patches, encoder_dim = visual_features.shape

        # Determine spatial dimensions assuming a square grid of patches
        grid_size = int(math.sqrt(num_patches))
        if grid_size * grid_size != num_patches:
            raise ValueError(
                f"Number of patches ({num_patches}) must be a perfect square to apply 2D pixel shuffle."
            )

        if grid_size % self.scale_factor != 0:
            raise ValueError(
                f"Grid size ({grid_size}) must be divisible by pixel_shuffle_scale ({self.scale_factor})."
            )

        s = self.scale_factor
        cells = grid_size // s
        flat_batch = batch_size * num_tiles

        # Split the patch grid into s x s groups:
        # (batch_size * num_tiles, H//s, s, W//s, s, C).
        # reshape (rather than view) also accepts non-contiguous encoder output.
        x = visual_features.reshape(flat_batch, cells, s, cells, s, encoder_dim)

        # Bring each group's s x s patches next to the channel axis:
        # (batch_size * num_tiles, H//s, W//s, s, s, C)
        x = x.permute(0, 1, 3, 2, 4, 5)

        # Flatten each group into the channel dimension:
        # (batch_size * num_tiles, (H//s) * (W//s), s * s * C)
        new_patches = cells ** 2
        x = x.reshape(flat_batch, new_patches, s * s * encoder_dim)

        # Project and normalize
        x = self.projection(x)
        x = self.norm(x)

        # Reshape back to batch: (batch_size, num_tiles, new_patches, decoder_dim)
        return x.reshape(batch_size, num_tiles, new_patches, self.decoder_dim)