llama-stack 0.4.4__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl

This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (155)
  1. llama_stack/cli/stack/_list_deps.py +11 -7
  2. llama_stack/cli/stack/run.py +3 -25
  3. llama_stack/core/access_control/datatypes.py +78 -0
  4. llama_stack/core/configure.py +2 -2
  5. llama_stack/{distributions/meta-reference-gpu → core/connectors}/__init__.py +3 -1
  6. llama_stack/core/connectors/connectors.py +162 -0
  7. llama_stack/core/conversations/conversations.py +61 -58
  8. llama_stack/core/datatypes.py +54 -8
  9. llama_stack/core/library_client.py +60 -13
  10. llama_stack/core/prompts/prompts.py +43 -42
  11. llama_stack/core/routers/datasets.py +20 -17
  12. llama_stack/core/routers/eval_scoring.py +143 -53
  13. llama_stack/core/routers/inference.py +20 -9
  14. llama_stack/core/routers/safety.py +30 -42
  15. llama_stack/core/routers/vector_io.py +15 -7
  16. llama_stack/core/routing_tables/models.py +42 -3
  17. llama_stack/core/routing_tables/scoring_functions.py +19 -19
  18. llama_stack/core/routing_tables/shields.py +20 -17
  19. llama_stack/core/routing_tables/vector_stores.py +8 -5
  20. llama_stack/core/server/auth.py +192 -17
  21. llama_stack/core/server/fastapi_router_registry.py +40 -5
  22. llama_stack/core/server/server.py +24 -5
  23. llama_stack/core/stack.py +54 -10
  24. llama_stack/core/storage/datatypes.py +9 -0
  25. llama_stack/core/store/registry.py +1 -1
  26. llama_stack/core/utils/exec.py +2 -2
  27. llama_stack/core/utils/type_inspection.py +16 -2
  28. llama_stack/distributions/dell/config.yaml +4 -1
  29. llama_stack/distributions/dell/run-with-safety.yaml +4 -1
  30. llama_stack/distributions/nvidia/config.yaml +4 -1
  31. llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
  32. llama_stack/distributions/oci/config.yaml +4 -1
  33. llama_stack/distributions/open-benchmark/config.yaml +9 -1
  34. llama_stack/distributions/postgres-demo/config.yaml +1 -1
  35. llama_stack/distributions/starter/build.yaml +62 -0
  36. llama_stack/distributions/starter/config.yaml +22 -3
  37. llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
  38. llama_stack/distributions/starter/starter.py +13 -1
  39. llama_stack/distributions/starter-gpu/build.yaml +62 -0
  40. llama_stack/distributions/starter-gpu/config.yaml +22 -3
  41. llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
  42. llama_stack/distributions/template.py +10 -2
  43. llama_stack/distributions/watsonx/config.yaml +4 -1
  44. llama_stack/log.py +1 -0
  45. llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
  46. llama_stack/providers/inline/agents/meta_reference/agents.py +57 -61
  47. llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +49 -51
  48. llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +94 -22
  49. llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
  50. llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
  51. llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
  52. llama_stack/providers/inline/batches/reference/batches.py +2 -1
  53. llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
  54. llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
  55. llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
  56. llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
  57. llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
  58. llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
  59. llama_stack/providers/inline/safety/llama_guard/llama_guard.py +15 -18
  60. llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
  61. llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
  62. llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
  63. llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
  64. llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
  65. llama_stack/providers/registry/agents.py +1 -0
  66. llama_stack/providers/registry/inference.py +1 -9
  67. llama_stack/providers/registry/vector_io.py +136 -16
  68. llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
  69. llama_stack/providers/remote/files/s3/config.py +5 -3
  70. llama_stack/providers/remote/files/s3/files.py +2 -2
  71. llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
  72. llama_stack/providers/remote/inference/openai/openai.py +2 -0
  73. llama_stack/providers/remote/inference/together/together.py +4 -0
  74. llama_stack/providers/remote/inference/vertexai/config.py +3 -3
  75. llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
  76. llama_stack/providers/remote/inference/vllm/config.py +37 -18
  77. llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
  78. llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
  79. llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
  80. llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
  81. llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
  82. llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
  83. llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
  84. llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
  85. llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
  86. llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
  87. llama_stack/providers/remote/vector_io/oci/config.py +41 -0
  88. llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
  89. llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
  90. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
  91. llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
  92. llama_stack/providers/utils/bedrock/client.py +3 -3
  93. llama_stack/providers/utils/bedrock/config.py +7 -7
  94. llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
  95. llama_stack/providers/utils/inference/http_client.py +239 -0
  96. llama_stack/providers/utils/inference/litellm_openai_mixin.py +5 -0
  97. llama_stack/providers/utils/inference/model_registry.py +148 -2
  98. llama_stack/providers/utils/inference/openai_compat.py +2 -1
  99. llama_stack/providers/utils/inference/openai_mixin.py +41 -2
  100. llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
  101. llama_stack/providers/utils/memory/vector_store.py +46 -19
  102. llama_stack/providers/utils/responses/responses_store.py +7 -7
  103. llama_stack/providers/utils/safety.py +114 -0
  104. llama_stack/providers/utils/tools/mcp.py +44 -3
  105. llama_stack/testing/api_recorder.py +9 -3
  106. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/METADATA +14 -2
  107. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/RECORD +111 -144
  108. llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
  109. llama_stack/distributions/meta-reference-gpu/doc_template.md +0 -119
  110. llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
  111. llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
  112. llama_stack/models/llama/hadamard_utils.py +0 -88
  113. llama_stack/models/llama/llama3/args.py +0 -74
  114. llama_stack/models/llama/llama3/dog.jpg +0 -0
  115. llama_stack/models/llama/llama3/generation.py +0 -378
  116. llama_stack/models/llama/llama3/model.py +0 -304
  117. llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
  118. llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
  119. llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
  120. llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
  121. llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
  122. llama_stack/models/llama/llama3/pasta.jpeg +0 -0
  123. llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
  124. llama_stack/models/llama/llama3/quantization/loader.py +0 -316
  125. llama_stack/models/llama/llama3_1/__init__.py +0 -12
  126. llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
  127. llama_stack/models/llama/llama3_1/prompts.py +0 -258
  128. llama_stack/models/llama/llama3_2/__init__.py +0 -5
  129. llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
  130. llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
  131. llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
  132. llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
  133. llama_stack/models/llama/llama3_3/__init__.py +0 -5
  134. llama_stack/models/llama/llama3_3/prompts.py +0 -259
  135. llama_stack/models/llama/llama4/args.py +0 -107
  136. llama_stack/models/llama/llama4/ffn.py +0 -58
  137. llama_stack/models/llama/llama4/moe.py +0 -214
  138. llama_stack/models/llama/llama4/preprocess.py +0 -435
  139. llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
  140. llama_stack/models/llama/llama4/quantization/loader.py +0 -226
  141. llama_stack/models/llama/llama4/vision/__init__.py +0 -5
  142. llama_stack/models/llama/llama4/vision/embedding.py +0 -210
  143. llama_stack/models/llama/llama4/vision/encoder.py +0 -412
  144. llama_stack/models/llama/quantize_impls.py +0 -316
  145. llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
  146. llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
  147. llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
  148. llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
  149. llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
  150. llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
  151. llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
  152. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/WHEEL +0 -0
  153. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/entry_points.txt +0 -0
  154. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
  155. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0rc1.dist-info}/top_level.txt +0 -0
--- llama_stack/models/llama/llama4/vision/encoder.py
+++ /dev/null
@@ -1,412 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- from collections.abc import Callable
- from typing import Any
-
- import fairscale.nn.model_parallel.initialize as fs_init
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
- from torch import einsum
-
- from ..args import ModelArgs
- from ..model import Attention
-
-
- class LayerNorm(nn.LayerNorm):
-     """Subclass torch's LayerNorm to handle fp16."""
-
-     def forward(self, x: torch.Tensor):
-         x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
-         return x
-
-
- class ColumnParallelConv2dPatch(torch.nn.Module):
-     """Conv2D Patching layer with model parallelism.
-     Column parallel over unfolded input.
-     Arguments:
-         in_channels: Input channels.
-         out_channels: Output channels.
-         kernel_size: Size of convolution kernel.
-         stride (default 1): Stride for convolution.
-         bias (default False): Use bias in Conv2d.
-     Input: (bsz, in_channels, height, width)
-     Output: (bsz, num_tokens, out_channels)
-     """
-
-     def __init__(
-         self,
-         in_channels: int,
-         out_channels: int,
-         kernel_size: int | tuple[int, int],
-         stride: int | tuple[int, int],
-         bias: bool | None = False,
-     ) -> None:
-         super().__init__()
-         if isinstance(kernel_size, int):
-             kernel_size = (kernel_size, kernel_size)
-         self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride)
-         self._linear = ColumnParallelLinear(
-             in_channels * kernel_size[0] * kernel_size[1],
-             out_channels,
-             bias=bias,
-         )
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         x = self._unfold(x)
-         x = x.permute(0, 2, 1)
-         x = self._linear(x)
-         return x
-
-
- class _FeedForward(torch.nn.Module):
-     def __init__(
-         self,
-         dim: int,
-         hidden_dim: int,
-         dropout: float,
-         act_layer: Callable = nn.GELU,
-     ):
-         super().__init__()
-         # layers
-         self.c_fc = ColumnParallelLinear(
-             dim,
-             hidden_dim,
-             bias=True,
-             gather_output=False,
-             init_method=lambda x: x,
-         )
-         self.c_proj = RowParallelLinear(
-             hidden_dim,
-             dim,
-             bias=True,
-             input_is_parallel=True,
-             init_method=lambda x: x,
-         )
-         self.non_linearity = act_layer()
-         self.dropout = dropout
-
-     def forward(self, x):
-         hidden = self.c_fc(x)
-         hidden = self.non_linearity(hidden)
-         hidden = F.dropout(hidden, p=self.dropout, training=self.training)
-         return self.c_proj(hidden)
-
-
- class _TransformerBlock(nn.Module):
-     def __init__(
-         self,
-         d_model: int,
-         n_head: int,
-         mlp_ratio: float = 4.0,
-         act_layer: Callable = nn.GELU,
-         gated: bool = False,
-     ):
-         super().__init__()
-         assert d_model % n_head == 0
-         self.n_heads = n_head
-         self.head_dim = d_model // self.n_heads
-
-         attn_args = ModelArgs(
-             dim=d_model,
-             head_dim=self.head_dim,
-             n_heads=self.n_heads,
-             n_kv_heads=self.n_heads,
-         )
-         self.attn = Attention(attn_args, use_rope=True, use_qk_norm=False, add_bias=True)
-         self.ln_1 = LayerNorm(d_model)
-         self.mlp = _FeedForward(
-             dim=d_model,
-             hidden_dim=int(mlp_ratio * d_model),
-             dropout=0.0,
-             act_layer=act_layer,
-         )
-         self.ln_2 = LayerNorm(d_model)
-         self.gated = gated
-         if gated:
-             self.gate_attn = nn.Parameter(torch.zeros(1))
-             self.gate_ffn = nn.Parameter(torch.zeros(1))
-
-     def attention(
-         self,
-         x: torch.Tensor,
-         freq_cis: torch.Tensor | None = None,
-     ):
-         return self.attn(x=x, start_pos=0, freqs_cis=freq_cis)
-
-     def forward(
-         self,
-         x: torch.Tensor,
-         mask: torch.Tensor | None = None,
-         freq_cis: torch.Tensor | None = None,
-     ):
-         _gate_attn = 1 if not self.gated else self.gate_attn.tanh()
-         _gate_ffn = 1 if not self.gated else self.gate_ffn.tanh()
-
-         x = x + _gate_attn * self.attention(self.ln_1(x), freq_cis=freq_cis)
-         x = x + _gate_ffn * self.mlp(self.ln_2(x))
-         return x
-
-
- class _Transformer(nn.Module):
-     def __init__(
-         self,
-         dim: int,
-         layers: int,
-         heads: int,
-         mlp_ratio: float = 4.0,
-         act_layer: Callable = nn.GELU,
-         gated: bool = False,
-     ):
-         super().__init__()
-         self.resblocks = nn.ModuleList(
-             [
-                 _TransformerBlock(
-                     d_model=dim,
-                     n_head=heads,
-                     mlp_ratio=mlp_ratio,
-                     act_layer=act_layer,
-                     gated=gated,
-                 )
-                 for _ in range(layers)
-             ]
-         )
-
-     def forward(self, x: torch.Tensor, return_intermediate=None, mask=None, freq_cis=None):
-         out = []
-         for idx, r in enumerate(self.resblocks):
-             if return_intermediate is not None and idx in return_intermediate:
-                 out.append(x)
-             x = r(x, mask=mask, freq_cis=freq_cis)
-         if return_intermediate is not None:
-             return x, torch.stack(out, dim=-1)
-         return x
-
-
- class PackingIndex:
-     Z = 0  # Z (time) coordinate of the token in the original sample
-     Y = 1  # Y (height) coordinate of the token in the original sample
-     X = 2  # X (width) coordinate of the token in the original sample
-     TIME = 3  # Total number of time units (frames) in the original sample
-     HEIGHT = 4  # Height of the original sample
-     WIDTH = 5  # Width of the original sample
-     # USE INDEX TO CHECK THE TYPE OF THE TOKEN (see ID fields below)
-     IDX = 6  # Full index of the token in the original sample (x + y * w + z * w * h)
-     BATCH_IDX = 7  # Which batch element this token belongs to. Note the batch idx of padding tokens is BATCH_SIZE
-
-     # Total size of the enum, remember to update this!
-     NUM_METADATA = 8
-
-     # Note: For padding tokens IDX = -1
-     # For cls tokens, IDX = -2
-     ID_CLS_TOKEN = -2
-     ID_PAD_TOKEN = -1
-
-
- class VisionEncoder(nn.Module):
-     def __init__(
-         self,
-         image_size: tuple[int, int],
-         patch_size: tuple[int, int],
-         dim: int,
-         layers: int,
-         heads: int,
-         mlp_ratio: float,
-         in_channels: int = 3,
-     ):
-         super().__init__()
-         self.image_size = image_size
-         self.patch_size = patch_size
-         self.grid_size = (
-             self.image_size[0] // self.patch_size[0],
-             self.image_size[1] // self.patch_size[1],
-         )
-         self.conv1 = ColumnParallelConv2dPatch(
-             in_channels=in_channels,
-             out_channels=dim,
-             kernel_size=patch_size,
-             stride=patch_size,
-             bias=False,
-         )
-         scale = dim**-0.5
-         self.class_embedding = nn.Parameter(scale * torch.randn(dim))
-
-         self.positional_embedding_vlm = nn.Parameter(
-             scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, dim)
-         )
-
-         self.ln_pre = LayerNorm(dim)
-         self.ln_post = LayerNorm(dim)
-         self.transformer = _Transformer(
-             dim,
-             layers,
-             heads,
-             mlp_ratio,
-             act_layer=nn.GELU,
-         )
-
-         # NOTE: hack for the fixed res
-         image_h, image_w = self.image_size
-         patch_h, patch_w = self.patch_size
-         idx_h, idx_w = image_h // patch_h, image_w // patch_w
-         img_idx = torch.arange(image_h * image_w // (patch_h * patch_w), dtype=torch.int32)
-         img_idx = img_idx.reshape(idx_h * idx_w, 1)
-         img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
-         img_idx[-1, -1] = PackingIndex.ID_CLS_TOKEN
-
-         packed_img_idx = torch.empty(
-             img_idx.shape[0],
-             img_idx.shape[1],
-             PackingIndex.NUM_METADATA - 1,
-             dtype=torch.int32,
-         )
-         packed_img_idx[:, :, PackingIndex.Y] = img_idx // idx_w
-         packed_img_idx[:, :, PackingIndex.X] = img_idx % idx_w
-         packed_img_idx[:, :, PackingIndex.HEIGHT].fill_(idx_h)
-         packed_img_idx[:, :, PackingIndex.WIDTH].fill_(idx_w)
-         packed_img_idx[:, :, PackingIndex.IDX] = img_idx
-         packed_img_idx = packed_img_idx.reshape(1, -1, PackingIndex.NUM_METADATA - 1)
-         self.packed_img_idx = packed_img_idx  # for positional embedding load hook
-
-         # compute rope freqs
-         rope_freq = self.get_rope_freqs(dim // heads // 2)
-         freqs_x = self.compute_rope_freqs(rope_freq, packed_img_idx[:, :, PackingIndex.X] + 1)
-         freqs_y = self.compute_rope_freqs(rope_freq, packed_img_idx[:, :, PackingIndex.Y] + 1)
-         freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
-         # disable RoPE for padding and cls tokens
-         freqs = freqs.masked_fill(packed_img_idx[:, :, PackingIndex.IDX, None] < 0, 0)
-         # compute complex freqs
-         self.freq_cis = torch.view_as_complex(torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
-         # xlf automatically broadcasts
-         self.freq_cis = self.freq_cis.squeeze(0)
-         self.n_heads = heads // fs_init.get_model_parallel_world_size()
-
-         self._register_load_state_dict_pre_hook(self.load_hook)
-
-     def get_rope_freqs(self, dim, theta=10000):
-         freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
-         return freqs
-
-     @torch.amp.autocast("cuda", enabled=False)
-     def compute_rope_freqs(self, freqs, t):
-         freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
-         freqs = freqs.repeat_interleave(2, dim=-1)
-         return freqs
-
-     def load_hook(
-         self,
-         state_dict: dict[str, Any],
-         prefix: str,
-         local_metadata: dict[str, Any],
-         strict: bool = True,
-         missing_keys: list[str] = None,
-         unexpected_keys: list[str] = None,
-         error_msgs: list[str] = None,
-         return_state_dict: bool = False,
-     ) -> None:
-         orig_pos_embed = state_dict.get(prefix + "positional_embedding")
-         if orig_pos_embed is not None and orig_pos_embed.shape[-2:] != self.positional_embedding_vlm.shape[-2:]:
-             raise ValueError(
-                 f"Positional embedding shape {orig_pos_embed.shape} does not match expected shape {self.positional_embedding_vlm.shape}"
-             )
-
-         batch_size, token_per_image, _ = self.packed_img_idx.shape
-         # Input points for idx are [x, y, w, h]
-         idx = self.packed_img_idx.reshape(batch_size * token_per_image, 1, -1)
-         total_windows, window_size, _ = idx.shape
-
-         # Grid values are [-1, 1] and coords are w, h
-         grid = (
-             (idx[:, :, [PackingIndex.X, PackingIndex.Y]] / idx[:, :, [PackingIndex.WIDTH, PackingIndex.HEIGHT]]) * 2 - 1
-         )[None, ...]
-
-         # In this mode, cls token has no position embedding
-         if orig_pos_embed is not None:
-             posemb = (
-                 orig_pos_embed[1:].view(1, self.grid_size[0], self.grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
-             )
-             posemb = posemb.to(device=grid.device, dtype=grid.dtype)
-             sample = F.grid_sample(
-                 posemb, grid, padding_mode="zeros"
-             )  # padding tokens / class token will get zero for posemb
-             sample = sample.view(-1, total_windows, window_size).permute(1, 2, 0).contiguous()
-             sample = torch.where(
-                 idx[:, :, PackingIndex.IDX, None] == PackingIndex.ID_CLS_TOKEN,
-                 orig_pos_embed[0].view(1, 1, -1).to(device=sample.device, dtype=sample.dtype),
-                 sample,
-             )
-
-             new_pos_embed = sample.reshape(batch_size, token_per_image, -1)
-
-             state_dict[prefix + "positional_embedding_vlm"] = new_pos_embed.squeeze(0)
-
-         if return_state_dict:
-             return state_dict
-
-     def apply_class_embedding(self, x):
-         x = torch.cat(
-             [
-                 x,
-                 self.class_embedding.to(x.dtype)
-                 + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
-             ],
-             dim=1,
-         )  # shape = [*, grid ** 2 + 1, width]
-         return x
-
-     def forward(self, images: torch.Tensor) -> torch.Tensor:
-         # NOTE: in Llama4 bsz=bsz*num_tiles, num_chunks=1
-         if images.ndim == 5:
-             num_concurrent_media = 1
-             bsz, num_chunks, nch, h, w = images.shape
-         else:
-             bsz, num_concurrent_media, num_chunks, nch, h, w = images.shape
-
-         images = images.reshape(bsz * num_concurrent_media * num_chunks, nch, h, w)
-         # patch embedding
-         x = images.reshape(bsz * num_concurrent_media * num_chunks, nch, h, w)
-         x = self.conv1(x)  # shape = [*, width, grid ** 2]
-         _, ntok, dim = x.shape
-         x = x.reshape(bsz * num_concurrent_media * num_chunks, ntok, dim)
-
-         # apply cls token
-         x = self.apply_class_embedding(x)
-         ntok += 1
-
-         # apply position embeddings
-         if self.positional_embedding_vlm is not None:
-             x = x + self.positional_embedding_vlm.to(x.dtype)
-
-         x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok, dim)
-
-         x = self.ln_pre(x)
-         x = x.view(bsz * num_concurrent_media, -1, dim)
-         freq_cis = self.freq_cis.to(images.device)
-
-         tf_output = self.transformer(
-             x,
-             freq_cis=freq_cis,
-         )
-
-         int_x = None
-         if isinstance(tf_output, tuple):
-             x, int_x = tf_output
-         else:
-             x = tf_output
-         x = self.ln_post(x)
-
-         # remove cls token output
-         x = x[:, :-1, :]
-
-         # add and output x + int_x features
-         if int_x is not None:
-             int_x = int_x[:, :-1, :, :]
-             int_x = int_x.reshape(bsz * num_concurrent_media, ntok - 1, -1)
-             x = torch.cat([x, int_x], dim=-1)
-
-         return x