llama-stack 0.4.4__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (159)
  1. llama_stack/cli/stack/_list_deps.py +11 -7
  2. llama_stack/cli/stack/run.py +3 -25
  3. llama_stack/core/access_control/datatypes.py +78 -0
  4. llama_stack/core/configure.py +2 -2
  5. llama_stack/{distributions/meta-reference-gpu → core/connectors}/__init__.py +3 -1
  6. llama_stack/core/connectors/connectors.py +162 -0
  7. llama_stack/core/conversations/conversations.py +61 -58
  8. llama_stack/core/datatypes.py +54 -8
  9. llama_stack/core/library_client.py +60 -13
  10. llama_stack/core/prompts/prompts.py +43 -42
  11. llama_stack/core/routers/datasets.py +20 -17
  12. llama_stack/core/routers/eval_scoring.py +143 -53
  13. llama_stack/core/routers/inference.py +20 -9
  14. llama_stack/core/routers/safety.py +30 -42
  15. llama_stack/core/routers/vector_io.py +15 -7
  16. llama_stack/core/routing_tables/models.py +42 -3
  17. llama_stack/core/routing_tables/scoring_functions.py +19 -19
  18. llama_stack/core/routing_tables/shields.py +20 -17
  19. llama_stack/core/routing_tables/vector_stores.py +8 -5
  20. llama_stack/core/server/auth.py +192 -17
  21. llama_stack/core/server/fastapi_router_registry.py +40 -5
  22. llama_stack/core/server/server.py +24 -5
  23. llama_stack/core/stack.py +54 -10
  24. llama_stack/core/storage/datatypes.py +9 -0
  25. llama_stack/core/store/registry.py +1 -1
  26. llama_stack/core/utils/exec.py +2 -2
  27. llama_stack/core/utils/type_inspection.py +16 -2
  28. llama_stack/distributions/dell/config.yaml +4 -1
  29. llama_stack/distributions/dell/run-with-safety.yaml +4 -1
  30. llama_stack/distributions/nvidia/config.yaml +4 -1
  31. llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
  32. llama_stack/distributions/oci/config.yaml +4 -1
  33. llama_stack/distributions/open-benchmark/config.yaml +9 -1
  34. llama_stack/distributions/postgres-demo/config.yaml +1 -1
  35. llama_stack/distributions/starter/build.yaml +62 -0
  36. llama_stack/distributions/starter/config.yaml +22 -3
  37. llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
  38. llama_stack/distributions/starter/starter.py +13 -1
  39. llama_stack/distributions/starter-gpu/build.yaml +62 -0
  40. llama_stack/distributions/starter-gpu/config.yaml +22 -3
  41. llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
  42. llama_stack/distributions/template.py +10 -2
  43. llama_stack/distributions/watsonx/config.yaml +4 -1
  44. llama_stack/log.py +1 -0
  45. llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
  46. llama_stack/providers/inline/agents/meta_reference/agents.py +58 -61
  47. llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +53 -51
  48. llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +99 -22
  49. llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
  50. llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
  51. llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
  52. llama_stack/providers/inline/batches/reference/batches.py +2 -1
  53. llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
  54. llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
  55. llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
  56. llama_stack/providers/inline/post_training/torchtune/common/utils.py +5 -9
  57. llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
  58. llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
  59. llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
  60. llama_stack/providers/inline/safety/llama_guard/llama_guard.py +20 -24
  61. llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
  62. llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
  63. llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
  64. llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
  65. llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
  66. llama_stack/providers/registry/agents.py +1 -0
  67. llama_stack/providers/registry/inference.py +1 -9
  68. llama_stack/providers/registry/vector_io.py +136 -16
  69. llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
  70. llama_stack/providers/remote/files/s3/config.py +5 -3
  71. llama_stack/providers/remote/files/s3/files.py +2 -2
  72. llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
  73. llama_stack/providers/remote/inference/openai/openai.py +2 -0
  74. llama_stack/providers/remote/inference/together/together.py +4 -0
  75. llama_stack/providers/remote/inference/vertexai/config.py +3 -3
  76. llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
  77. llama_stack/providers/remote/inference/vllm/config.py +37 -18
  78. llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
  79. llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
  80. llama_stack/providers/remote/post_training/nvidia/models.py +3 -11
  81. llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
  82. llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
  83. llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
  84. llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
  85. llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
  86. llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
  87. llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
  88. llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
  89. llama_stack/providers/remote/vector_io/oci/config.py +41 -0
  90. llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
  91. llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
  92. llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
  93. llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
  94. llama_stack/providers/utils/bedrock/client.py +3 -3
  95. llama_stack/providers/utils/bedrock/config.py +7 -7
  96. llama_stack/providers/utils/inference/__init__.py +0 -25
  97. llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
  98. llama_stack/providers/utils/inference/http_client.py +239 -0
  99. llama_stack/providers/utils/inference/litellm_openai_mixin.py +6 -0
  100. llama_stack/providers/utils/inference/model_registry.py +148 -2
  101. llama_stack/providers/utils/inference/openai_compat.py +1 -158
  102. llama_stack/providers/utils/inference/openai_mixin.py +42 -2
  103. llama_stack/providers/utils/inference/prompt_adapter.py +0 -209
  104. llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
  105. llama_stack/providers/utils/memory/vector_store.py +46 -19
  106. llama_stack/providers/utils/responses/responses_store.py +7 -7
  107. llama_stack/providers/utils/safety.py +114 -0
  108. llama_stack/providers/utils/tools/mcp.py +44 -3
  109. llama_stack/testing/api_recorder.py +9 -3
  110. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/METADATA +14 -2
  111. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/RECORD +115 -148
  112. llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
  113. llama_stack/distributions/meta-reference-gpu/doc_template.md +0 -119
  114. llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
  115. llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
  116. llama_stack/models/llama/hadamard_utils.py +0 -88
  117. llama_stack/models/llama/llama3/args.py +0 -74
  118. llama_stack/models/llama/llama3/dog.jpg +0 -0
  119. llama_stack/models/llama/llama3/generation.py +0 -378
  120. llama_stack/models/llama/llama3/model.py +0 -304
  121. llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
  122. llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
  123. llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
  124. llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
  125. llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
  126. llama_stack/models/llama/llama3/pasta.jpeg +0 -0
  127. llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
  128. llama_stack/models/llama/llama3/quantization/loader.py +0 -316
  129. llama_stack/models/llama/llama3_1/__init__.py +0 -12
  130. llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
  131. llama_stack/models/llama/llama3_1/prompts.py +0 -258
  132. llama_stack/models/llama/llama3_2/__init__.py +0 -5
  133. llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
  134. llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
  135. llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
  136. llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
  137. llama_stack/models/llama/llama3_3/__init__.py +0 -5
  138. llama_stack/models/llama/llama3_3/prompts.py +0 -259
  139. llama_stack/models/llama/llama4/args.py +0 -107
  140. llama_stack/models/llama/llama4/ffn.py +0 -58
  141. llama_stack/models/llama/llama4/moe.py +0 -214
  142. llama_stack/models/llama/llama4/preprocess.py +0 -435
  143. llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
  144. llama_stack/models/llama/llama4/quantization/loader.py +0 -226
  145. llama_stack/models/llama/llama4/vision/__init__.py +0 -5
  146. llama_stack/models/llama/llama4/vision/embedding.py +0 -210
  147. llama_stack/models/llama/llama4/vision/encoder.py +0 -412
  148. llama_stack/models/llama/quantize_impls.py +0 -316
  149. llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
  150. llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
  151. llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
  152. llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
  153. llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
  154. llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
  155. llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
  156. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/WHEEL +0 -0
  157. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/entry_points.txt +0 -0
  158. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/licenses/LICENSE +0 -0
  159. {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/top_level.txt +0 -0
llama_stack/models/llama/llama4/vision/encoder.py
@@ -1,412 +0,0 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
- #
- # This source code is licensed under the terms described in the LICENSE file in
- # the root directory of this source tree.
-
- from collections.abc import Callable
- from typing import Any
-
- import fairscale.nn.model_parallel.initialize as fs_init
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
- from torch import einsum
-
- from ..args import ModelArgs
- from ..model import Attention
-
-
- class LayerNorm(nn.LayerNorm):
-     """Subclass torch's LayerNorm to handle fp16."""
-
-     def forward(self, x: torch.Tensor):
-         x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
-         return x
-
-
- class ColumnParallelConv2dPatch(torch.nn.Module):
-     """Conv2D Patching layer with model parallelism.
-     Column parallel over unfolded input.
-     Arguments:
-         in_channels: Input channels.
-         out_channels: Output channels.
-         kernel_size: Size of convolution kernel.
-         stride (default 1): Stride for convolution.
-         bias (default False): Use bias in Conv2d.
-     Input: (bsz, in_channels, height, width)
-     Output: (bsz, num_tokens, out_channels)
-     """
-
-     def __init__(
-         self,
-         in_channels: int,
-         out_channels: int,
-         kernel_size: int | tuple[int, int],
-         stride: int | tuple[int, int],
-         bias: bool | None = False,
-     ) -> None:
-         super().__init__()
-         if isinstance(kernel_size, int):
-             kernel_size = (kernel_size, kernel_size)
-         self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride)
-         self._linear = ColumnParallelLinear(
-             in_channels * kernel_size[0] * kernel_size[1],
-             out_channels,
-             bias=bias,
-         )
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         x = self._unfold(x)
-         x = x.permute(0, 2, 1)
-         x = self._linear(x)
-         return x
-
-
- class _FeedForward(torch.nn.Module):
-     def __init__(
-         self,
-         dim: int,
-         hidden_dim: int,
-         dropout: float,
-         act_layer: Callable = nn.GELU,
-     ):
-         super().__init__()
-         # layers
-         self.c_fc = ColumnParallelLinear(
-             dim,
-             hidden_dim,
-             bias=True,
-             gather_output=False,
-             init_method=lambda x: x,
-         )
-         self.c_proj = RowParallelLinear(
-             hidden_dim,
-             dim,
-             bias=True,
-             input_is_parallel=True,
-             init_method=lambda x: x,
-         )
-         self.non_linearity = act_layer()
-         self.dropout = dropout
-
-     def forward(self, x):
-         hidden = self.c_fc(x)
-         hidden = self.non_linearity(hidden)
-         hidden = F.dropout(hidden, p=self.dropout, training=self.training)
-         return self.c_proj(hidden)
-
-
- class _TransformerBlock(nn.Module):
-     def __init__(
-         self,
-         d_model: int,
-         n_head: int,
-         mlp_ratio: float = 4.0,
-         act_layer: Callable = nn.GELU,
-         gated: bool = False,
-     ):
-         super().__init__()
-         assert d_model % n_head == 0
-         self.n_heads = n_head
-         self.head_dim = d_model // self.n_heads
-
-         attn_args = ModelArgs(
-             dim=d_model,
-             head_dim=self.head_dim,
-             n_heads=self.n_heads,
-             n_kv_heads=self.n_heads,
-         )
-         self.attn = Attention(attn_args, use_rope=True, use_qk_norm=False, add_bias=True)
-         self.ln_1 = LayerNorm(d_model)
-         self.mlp = _FeedForward(
-             dim=d_model,
-             hidden_dim=int(mlp_ratio * d_model),
-             dropout=0.0,
-             act_layer=act_layer,
-         )
-         self.ln_2 = LayerNorm(d_model)
-         self.gated = gated
-         if gated:
-             self.gate_attn = nn.Parameter(torch.zeros(1))
-             self.gate_ffn = nn.Parameter(torch.zeros(1))
-
-     def attention(
-         self,
-         x: torch.Tensor,
-         freq_cis: torch.Tensor | None = None,
-     ):
-         return self.attn(x=x, start_pos=0, freqs_cis=freq_cis)
-
-     def forward(
-         self,
-         x: torch.Tensor,
-         mask: torch.Tensor | None = None,
-         freq_cis: torch.Tensor | None = None,
-     ):
-         _gate_attn = 1 if not self.gated else self.gate_attn.tanh()
-         _gate_ffn = 1 if not self.gated else self.gate_ffn.tanh()
-
-         x = x + _gate_attn * self.attention(self.ln_1(x), freq_cis=freq_cis)
-         x = x + _gate_ffn * self.mlp(self.ln_2(x))
-         return x
-
-
- class _Transformer(nn.Module):
-     def __init__(
-         self,
-         dim: int,
-         layers: int,
-         heads: int,
-         mlp_ratio: float = 4.0,
-         act_layer: Callable = nn.GELU,
-         gated: bool = False,
-     ):
-         super().__init__()
-         self.resblocks = nn.ModuleList(
-             [
-                 _TransformerBlock(
-                     d_model=dim,
-                     n_head=heads,
-                     mlp_ratio=mlp_ratio,
-                     act_layer=act_layer,
-                     gated=gated,
-                 )
-                 for _ in range(layers)
-             ]
-         )
-
-     def forward(self, x: torch.Tensor, return_intermediate=None, mask=None, freq_cis=None):
-         out = []
-         for idx, r in enumerate(self.resblocks):
-             if return_intermediate is not None and idx in return_intermediate:
-                 out.append(x)
-             x = r(x, mask=mask, freq_cis=freq_cis)
-         if return_intermediate is not None:
-             return x, torch.stack(out, dim=-1)
-         return x
-
-
- class PackingIndex:
-     Z = 0  # Z (time) coordinate of the token in the original sample
-     Y = 1  # Y (height) coordinate of the token in the original sample
-     X = 2  # X (width) coordinate of the token in the original sample
-     TIME = 3  # Total number of time units (frames) in the original sample
-     HEIGHT = 4  # Height of the original sample
-     WIDTH = 5  # Width of the original sample
-     # USE INDEX TO CHECK THE TYPE OF THE TOKEN (see ID fields below)
-     IDX = 6  # Full index of the token in the original sample (x + y * w + z * w * h)
-     BATCH_IDX = 7  # Which batch element this token belongs to. Note the batch idx of padding tokens is BATCH_SIZE
-
-     # Total size of the enum, remember to update this!
-     NUM_METADATA = 8
-
-     # Note: For padding tokens IDX = -1
-     # For cls tokens, IDX = -2
-     ID_CLS_TOKEN = -2
-     ID_PAD_TOKEN = -1
-
-
- class VisionEncoder(nn.Module):
-     def __init__(
-         self,
-         image_size: tuple[int, int],
-         patch_size: tuple[int, int],
-         dim: int,
-         layers: int,
-         heads: int,
-         mlp_ratio: float,
-         in_channels: int = 3,
-     ):
-         super().__init__()
-         self.image_size = image_size
-         self.patch_size = patch_size
-         self.grid_size = (
-             self.image_size[0] // self.patch_size[0],
-             self.image_size[1] // self.patch_size[1],
-         )
-         self.conv1 = ColumnParallelConv2dPatch(
-             in_channels=in_channels,
-             out_channels=dim,
-             kernel_size=patch_size,
-             stride=patch_size,
-             bias=False,
-         )
-         scale = dim**-0.5
-         self.class_embedding = nn.Parameter(scale * torch.randn(dim))
-
-         self.positional_embedding_vlm = nn.Parameter(
-             scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, dim)
-         )
-
-         self.ln_pre = LayerNorm(dim)
-         self.ln_post = LayerNorm(dim)
-         self.transformer = _Transformer(
-             dim,
-             layers,
-             heads,
-             mlp_ratio,
-             act_layer=nn.GELU,
-         )
-
-         # NOTE: hack for the fixed res
-         image_h, image_w = self.image_size
-         patch_h, patch_w = self.patch_size
-         idx_h, idx_w = image_h // patch_h, image_w // patch_w
-         img_idx = torch.arange(image_h * image_w // (patch_h * patch_w), dtype=torch.int32)
-         img_idx = img_idx.reshape(idx_h * idx_w, 1)
-         img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
-         img_idx[-1, -1] = PackingIndex.ID_CLS_TOKEN
-
-         packed_img_idx = torch.empty(
-             img_idx.shape[0],
-             img_idx.shape[1],
-             PackingIndex.NUM_METADATA - 1,
-             dtype=torch.int32,
-         )
-         packed_img_idx[:, :, PackingIndex.Y] = img_idx // idx_w
-         packed_img_idx[:, :, PackingIndex.X] = img_idx % idx_w
-         packed_img_idx[:, :, PackingIndex.HEIGHT].fill_(idx_h)
-         packed_img_idx[:, :, PackingIndex.WIDTH].fill_(idx_w)
-         packed_img_idx[:, :, PackingIndex.IDX] = img_idx
-         packed_img_idx = packed_img_idx.reshape(1, -1, PackingIndex.NUM_METADATA - 1)
-         self.packed_img_idx = packed_img_idx  # for positional embedding load hook
-
-         # compute rope freqs
-         rope_freq = self.get_rope_freqs(dim // heads // 2)
-         freqs_x = self.compute_rope_freqs(rope_freq, packed_img_idx[:, :, PackingIndex.X] + 1)
-         freqs_y = self.compute_rope_freqs(rope_freq, packed_img_idx[:, :, PackingIndex.Y] + 1)
-         freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
-         # disable RoPE for padding and cls tokens
-         freqs = freqs.masked_fill(packed_img_idx[:, :, PackingIndex.IDX, None] < 0, 0)
-         # compute complex freqs
-         self.freq_cis = torch.view_as_complex(torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
-         # xlf automatically broadcasts
-         self.freq_cis = self.freq_cis.squeeze(0)
-         self.n_heads = heads // fs_init.get_model_parallel_world_size()
-
-         self._register_load_state_dict_pre_hook(self.load_hook)
-
-     def get_rope_freqs(self, dim, theta=10000):
-         freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
-         return freqs
-
-     @torch.amp.autocast("cuda", enabled=False)
-     def compute_rope_freqs(self, freqs, t):
-         freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
-         freqs = freqs.repeat_interleave(2, dim=-1)
-         return freqs
-
-     def load_hook(
-         self,
-         state_dict: dict[str, Any],
-         prefix: str,
-         local_metadata: dict[str, Any],
-         strict: bool = True,
-         missing_keys: list[str] = None,
-         unexpected_keys: list[str] = None,
-         error_msgs: list[str] = None,
-         return_state_dict: bool = False,
-     ) -> None:
-         orig_pos_embed = state_dict.get(prefix + "positional_embedding")
-         if orig_pos_embed is not None and orig_pos_embed.shape[-2:] != self.positional_embedding_vlm.shape[-2:]:
-             raise ValueError(
-                 f"Positional embedding shape {orig_pos_embed.shape} does not match expected shape {self.positional_embedding_vlm.shape}"
-             )
-
-         batch_size, token_per_image, _ = self.packed_img_idx.shape
-         # Input points for idx are [x, y, w, h]
-         idx = self.packed_img_idx.reshape(batch_size * token_per_image, 1, -1)
-         total_windows, window_size, _ = idx.shape
-
-         # Grid values are [-1, 1] and coords are w, h
-         grid = (
-             (idx[:, :, [PackingIndex.X, PackingIndex.Y]] / idx[:, :, [PackingIndex.WIDTH, PackingIndex.HEIGHT]]) * 2 - 1
-         )[None, ...]
-
-         # In this mode, cls token has no position embedding
-         if orig_pos_embed is not None:
-             posemb = (
-                 orig_pos_embed[1:].view(1, self.grid_size[0], self.grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
-             )
-             posemb = posemb.to(device=grid.device, dtype=grid.dtype)
-             sample = F.grid_sample(
-                 posemb, grid, padding_mode="zeros"
-             )  # padding tokens / class token will get zero for posemb
-             sample = sample.view(-1, total_windows, window_size).permute(1, 2, 0).contiguous()
-             sample = torch.where(
-                 idx[:, :, PackingIndex.IDX, None] == PackingIndex.ID_CLS_TOKEN,
-                 orig_pos_embed[0].view(1, 1, -1).to(device=sample.device, dtype=sample.dtype),
-                 sample,
-             )
-
-             new_pos_embed = sample.reshape(batch_size, token_per_image, -1)
-
-             state_dict[prefix + "positional_embedding_vlm"] = new_pos_embed.squeeze(0)
-
-         if return_state_dict:
-             return state_dict
-
-     def apply_class_embedding(self, x):
-         x = torch.cat(
-             [
-                 x,
-                 self.class_embedding.to(x.dtype)
-                 + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
-             ],
-             dim=1,
-         )  # shape = [*, grid ** 2 + 1, width]
-         return x
-
-     def forward(self, images: torch.Tensor) -> torch.Tensor:
-         # NOTE: in Llama4 bsz=bsz*num_tiles, num_chunks=1
-         if images.ndim == 5:
-             num_concurrent_media = 1
-             bsz, num_chunks, nch, h, w = images.shape
-         else:
-             bsz, num_concurrent_media, num_chunks, nch, h, w = images.shape
-
-         images = images.reshape(bsz * num_concurrent_media * num_chunks, nch, h, w)
-         # patch embedding
-         x = images.reshape(bsz * num_concurrent_media * num_chunks, nch, h, w)
-         x = self.conv1(x)  # shape = [*, width, grid ** 2]
-         _, ntok, dim = x.shape
-         x = x.reshape(bsz * num_concurrent_media * num_chunks, ntok, dim)
-
-         # apply cls token
-         x = self.apply_class_embedding(x)
-         ntok += 1
-
-         # apply position embeddings
-         if self.positional_embedding_vlm is not None:
-             x = x + self.positional_embedding_vlm.to(x.dtype)
-
-         x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok, dim)
-
-         x = self.ln_pre(x)
-         x = x.view(bsz * num_concurrent_media, -1, dim)
-         freq_cis = self.freq_cis.to(images.device)
-
-         tf_output = self.transformer(
-             x,
-             freq_cis=freq_cis,
-         )
-
-         int_x = None
-         if isinstance(tf_output, tuple):
-             x, int_x = tf_output
-         else:
-             x = tf_output
-         x = self.ln_post(x)
-
-         # remove cls token output
-         x = x[:, :-1, :]
-
-         # add and output x + int_x features
-         if int_x is not None:
-             int_x = int_x[:, :-1, :, :]
-             int_x = int_x.reshape(bsz * num_concurrent_media, ntok - 1, -1)
-             x = torch.cat([x, int_x], dim=-1)
-
-         return x