llama-stack 0.4.4__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llama_stack/cli/stack/_list_deps.py +11 -7
- llama_stack/cli/stack/run.py +3 -25
- llama_stack/core/access_control/datatypes.py +78 -0
- llama_stack/core/configure.py +2 -2
- llama_stack/{distributions/meta-reference-gpu → core/connectors}/__init__.py +3 -1
- llama_stack/core/connectors/connectors.py +162 -0
- llama_stack/core/conversations/conversations.py +61 -58
- llama_stack/core/datatypes.py +54 -8
- llama_stack/core/library_client.py +60 -13
- llama_stack/core/prompts/prompts.py +43 -42
- llama_stack/core/routers/datasets.py +20 -17
- llama_stack/core/routers/eval_scoring.py +143 -53
- llama_stack/core/routers/inference.py +20 -9
- llama_stack/core/routers/safety.py +30 -42
- llama_stack/core/routers/vector_io.py +15 -7
- llama_stack/core/routing_tables/models.py +42 -3
- llama_stack/core/routing_tables/scoring_functions.py +19 -19
- llama_stack/core/routing_tables/shields.py +20 -17
- llama_stack/core/routing_tables/vector_stores.py +8 -5
- llama_stack/core/server/auth.py +192 -17
- llama_stack/core/server/fastapi_router_registry.py +40 -5
- llama_stack/core/server/server.py +24 -5
- llama_stack/core/stack.py +54 -10
- llama_stack/core/storage/datatypes.py +9 -0
- llama_stack/core/store/registry.py +1 -1
- llama_stack/core/utils/exec.py +2 -2
- llama_stack/core/utils/type_inspection.py +16 -2
- llama_stack/distributions/dell/config.yaml +4 -1
- llama_stack/distributions/dell/run-with-safety.yaml +4 -1
- llama_stack/distributions/nvidia/config.yaml +4 -1
- llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
- llama_stack/distributions/oci/config.yaml +4 -1
- llama_stack/distributions/open-benchmark/config.yaml +9 -1
- llama_stack/distributions/postgres-demo/config.yaml +1 -1
- llama_stack/distributions/starter/build.yaml +62 -0
- llama_stack/distributions/starter/config.yaml +22 -3
- llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/starter/starter.py +13 -1
- llama_stack/distributions/starter-gpu/build.yaml +62 -0
- llama_stack/distributions/starter-gpu/config.yaml +22 -3
- llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/template.py +10 -2
- llama_stack/distributions/watsonx/config.yaml +4 -1
- llama_stack/log.py +1 -0
- llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
- llama_stack/providers/inline/agents/meta_reference/agents.py +58 -61
- llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +53 -51
- llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +99 -22
- llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
- llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
- llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
- llama_stack/providers/inline/batches/reference/batches.py +2 -1
- llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
- llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
- llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
- llama_stack/providers/inline/post_training/torchtune/common/utils.py +5 -9
- llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
- llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
- llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
- llama_stack/providers/inline/safety/llama_guard/llama_guard.py +20 -24
- llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
- llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
- llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
- llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
- llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
- llama_stack/providers/registry/agents.py +1 -0
- llama_stack/providers/registry/inference.py +1 -9
- llama_stack/providers/registry/vector_io.py +136 -16
- llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
- llama_stack/providers/remote/files/s3/config.py +5 -3
- llama_stack/providers/remote/files/s3/files.py +2 -2
- llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
- llama_stack/providers/remote/inference/openai/openai.py +2 -0
- llama_stack/providers/remote/inference/together/together.py +4 -0
- llama_stack/providers/remote/inference/vertexai/config.py +3 -3
- llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
- llama_stack/providers/remote/inference/vllm/config.py +37 -18
- llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
- llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
- llama_stack/providers/remote/post_training/nvidia/models.py +3 -11
- llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
- llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
- llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
- llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
- llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
- llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
- llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
- llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
- llama_stack/providers/remote/vector_io/oci/config.py +41 -0
- llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
- llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
- llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
- llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
- llama_stack/providers/utils/bedrock/client.py +3 -3
- llama_stack/providers/utils/bedrock/config.py +7 -7
- llama_stack/providers/utils/inference/__init__.py +0 -25
- llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
- llama_stack/providers/utils/inference/http_client.py +239 -0
- llama_stack/providers/utils/inference/litellm_openai_mixin.py +6 -0
- llama_stack/providers/utils/inference/model_registry.py +148 -2
- llama_stack/providers/utils/inference/openai_compat.py +1 -158
- llama_stack/providers/utils/inference/openai_mixin.py +42 -2
- llama_stack/providers/utils/inference/prompt_adapter.py +0 -209
- llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
- llama_stack/providers/utils/memory/vector_store.py +46 -19
- llama_stack/providers/utils/responses/responses_store.py +7 -7
- llama_stack/providers/utils/safety.py +114 -0
- llama_stack/providers/utils/tools/mcp.py +44 -3
- llama_stack/testing/api_recorder.py +9 -3
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/METADATA +14 -2
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/RECORD +115 -148
- llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
- llama_stack/distributions/meta-reference-gpu/doc_template.md +0 -119
- llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
- llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
- llama_stack/models/llama/hadamard_utils.py +0 -88
- llama_stack/models/llama/llama3/args.py +0 -74
- llama_stack/models/llama/llama3/dog.jpg +0 -0
- llama_stack/models/llama/llama3/generation.py +0 -378
- llama_stack/models/llama/llama3/model.py +0 -304
- llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
- llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
- llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
- llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
- llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
- llama_stack/models/llama/llama3/pasta.jpeg +0 -0
- llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama3/quantization/loader.py +0 -316
- llama_stack/models/llama/llama3_1/__init__.py +0 -12
- llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
- llama_stack/models/llama/llama3_1/prompts.py +0 -258
- llama_stack/models/llama/llama3_2/__init__.py +0 -5
- llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
- llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
- llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
- llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
- llama_stack/models/llama/llama3_3/__init__.py +0 -5
- llama_stack/models/llama/llama3_3/prompts.py +0 -259
- llama_stack/models/llama/llama4/args.py +0 -107
- llama_stack/models/llama/llama4/ffn.py +0 -58
- llama_stack/models/llama/llama4/moe.py +0 -214
- llama_stack/models/llama/llama4/preprocess.py +0 -435
- llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama4/quantization/loader.py +0 -226
- llama_stack/models/llama/llama4/vision/__init__.py +0 -5
- llama_stack/models/llama/llama4/vision/embedding.py +0 -210
- llama_stack/models/llama/llama4/vision/encoder.py +0 -412
- llama_stack/models/llama/quantize_impls.py +0 -316
- llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
- llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
- llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
- llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
- llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
- llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
- llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/WHEEL +0 -0
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/entry_points.txt +0 -0
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/licenses/LICENSE +0 -0
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/top_level.txt +0 -0
|
@@ -1,412 +0,0 @@
|
|
|
1
|
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
-
# All rights reserved.
|
|
3
|
-
#
|
|
4
|
-
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
|
-
# the root directory of this source tree.
|
|
6
|
-
|
|
7
|
-
from collections.abc import Callable
|
|
8
|
-
from typing import Any
|
|
9
|
-
|
|
10
|
-
import fairscale.nn.model_parallel.initialize as fs_init
|
|
11
|
-
import torch
|
|
12
|
-
import torch.nn as nn
|
|
13
|
-
import torch.nn.functional as F
|
|
14
|
-
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
|
|
15
|
-
from torch import einsum
|
|
16
|
-
|
|
17
|
-
from ..args import ModelArgs
|
|
18
|
-
from ..model import Attention
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
class LayerNorm(nn.LayerNorm):
    """Thin ``nn.LayerNorm`` subclass (original note: "handle fp16").

    NOTE(review): ``forward`` hands ``x`` to ``F.layer_norm`` without any
    dtype cast, so it presumably behaves like the parent class -- confirm
    whether an explicit upcast was intended.
    """

    def forward(self, x: torch.Tensor):
        # Normalize with this module's own shape, affine parameters and eps.
        normalized = F.layer_norm(
            x,
            self.normalized_shape,
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
        )
        return normalized
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
class ColumnParallelConv2dPatch(torch.nn.Module):
    """Patch embedding built from ``torch.nn.Unfold`` plus a column-parallel linear.

    The image is unfolded into flattened patches, and every patch is then
    projected by a fairscale ``ColumnParallelLinear``.

    Arguments:
        in_channels: Input channels.
        out_channels: Output channels.
        kernel_size: Patch size; a single int is expanded to a square patch.
        stride: Step between consecutive patches (required, no default).
        bias (default False): Whether the linear projection carries a bias.
    Input: (bsz, in_channels, height, width)
    Output: (bsz, num_tokens, out_channels)
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int | tuple[int, int],
        stride: int | tuple[int, int],
        bias: bool | None = False,
    ) -> None:
        super().__init__()
        # Normalise the patch size to an explicit (height, width) pair.
        patch_hw = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
        self._unfold = torch.nn.Unfold(kernel_size=patch_hw, stride=stride)
        # Each unfolded column holds one flattened patch across all input channels.
        self._linear = ColumnParallelLinear(
            in_channels * patch_hw[0] * patch_hw[1],
            out_channels,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Unfold yields (bsz, patch_features, num_patches); swap the last two
        # axes so the linear layer acts on the per-patch feature vector.
        patches = self._unfold(x).permute(0, 2, 1)
        return self._linear(patches)
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
class _FeedForward(torch.nn.Module):
    """MLP: column-parallel projection, activation, dropout, row-parallel projection.

    Arguments:
        dim: Input and output feature size.
        hidden_dim: Feature size between the two projections.
        dropout: Dropout probability applied after the activation.
        act_layer: Zero-argument factory for the activation module.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        dropout: float,
        act_layer: Callable = nn.GELU,
    ):
        super().__init__()

        def _passthrough_init(weight):
            # init_method that simply returns its argument unchanged.
            return weight

        # ``c_fc`` keeps its output sharded (gather_output=False) and ``c_proj``
        # is told its input is already sharded (input_is_parallel=True).
        self.c_fc = ColumnParallelLinear(
            dim,
            hidden_dim,
            bias=True,
            gather_output=False,
            init_method=_passthrough_init,
        )
        self.c_proj = RowParallelLinear(
            hidden_dim,
            dim,
            bias=True,
            input_is_parallel=True,
            init_method=_passthrough_init,
        )
        self.non_linearity = act_layer()
        self.dropout = dropout

    def forward(self, x):
        activated = self.non_linearity(self.c_fc(x))
        # Dropout is only active while the module is in training mode.
        dropped = F.dropout(activated, p=self.dropout, training=self.training)
        return self.c_proj(dropped)
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
class _TransformerBlock(nn.Module):
    """Residual block: layer-normed attention, then a layer-normed MLP.

    When ``gated`` is true, each residual branch is scaled by the tanh of a
    learnable scalar that starts at zero.

    Arguments:
        d_model: Feature size; must be divisible by ``n_head``.
        n_head: Number of attention heads.
        mlp_ratio: MLP hidden size as a multiple of ``d_model``.
        act_layer: Zero-argument factory for the MLP activation.
        gated: Whether to create and apply the two branch gates.
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        act_layer: Callable = nn.GELU,
        gated: bool = False,
    ):
        super().__init__()
        assert d_model % n_head == 0
        self.n_heads = n_head
        self.head_dim = d_model // self.n_heads

        # Key/value heads are set equal to the query heads.
        self.attn = Attention(
            ModelArgs(
                dim=d_model,
                head_dim=self.head_dim,
                n_heads=self.n_heads,
                n_kv_heads=self.n_heads,
            ),
            use_rope=True,
            use_qk_norm=False,
            add_bias=True,
        )
        self.ln_1 = LayerNorm(d_model)
        mlp_width = int(mlp_ratio * d_model)
        self.mlp = _FeedForward(
            dim=d_model,
            hidden_dim=mlp_width,
            dropout=0.0,
            act_layer=act_layer,
        )
        self.ln_2 = LayerNorm(d_model)
        self.gated = gated
        if gated:
            self.gate_attn = nn.Parameter(torch.zeros(1))
            self.gate_ffn = nn.Parameter(torch.zeros(1))

    def attention(
        self,
        x: torch.Tensor,
        freq_cis: torch.Tensor | None = None,
    ):
        """Run the attention module on ``x`` with ``start_pos`` fixed at 0."""
        return self.attn(x=x, start_pos=0, freqs_cis=freq_cis)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor | None = None,
        freq_cis: torch.Tensor | None = None,
    ):
        """Apply both residual branches; ``mask`` is accepted but not used here."""
        if self.gated:
            attn_scale = self.gate_attn.tanh()
            ffn_scale = self.gate_ffn.tanh()
        else:
            attn_scale = ffn_scale = 1

        attn_branch = self.attention(self.ln_1(x), freq_cis=freq_cis)
        x = x + attn_scale * attn_branch
        mlp_branch = self.mlp(self.ln_2(x))
        x = x + ffn_scale * mlp_branch
        return x
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
class _Transformer(nn.Module):
    """Sequential stack of ``_TransformerBlock`` modules.

    Arguments:
        dim: Feature size handed to every block.
        layers: Number of blocks in the stack.
        heads: Attention heads per block.
        mlp_ratio: MLP hidden size as a multiple of ``dim``.
        act_layer: Zero-argument factory for the MLP activation.
        gated: Forwarded to every block.
    """

    def __init__(
        self,
        dim: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        act_layer: Callable = nn.GELU,
        gated: bool = False,
    ):
        super().__init__()
        stack = []
        for _ in range(layers):
            stack.append(
                _TransformerBlock(
                    d_model=dim,
                    n_head=heads,
                    mlp_ratio=mlp_ratio,
                    act_layer=act_layer,
                    gated=gated,
                )
            )
        self.resblocks = nn.ModuleList(stack)

    def forward(self, x: torch.Tensor, return_intermediate=None, mask=None, freq_cis=None):
        """Run every block in order.

        When ``return_intermediate`` is given, the input of each block whose
        index it contains is captured, and the result is a tuple of the final
        output and those captures stacked along a new last axis. Otherwise
        only the final output is returned.
        """
        collecting = return_intermediate is not None
        captured = []
        for layer_no, block in enumerate(self.resblocks):
            # Capture happens before the block runs, i.e. on the block's input.
            if collecting and layer_no in return_intermediate:
                captured.append(x)
            x = block(x, mask=mask, freq_cis=freq_cis)
        if not collecting:
            return x
        return x, torch.stack(captured, dim=-1)
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
class PackingIndex:
    """Column positions and sentinel ids for per-token packing metadata."""

    # --- per-token coordinates in the original sample ---
    Z = 0  # time coordinate
    Y = 1  # height coordinate
    X = 2  # width coordinate

    # --- extent of the original sample ---
    TIME = 3  # total number of time units (frames)
    HEIGHT = 4  # sample height
    WIDTH = 5  # sample width

    # Use IDX to tell what kind of token this is (see the ID_* sentinels below).
    IDX = 6  # flat token index in the original sample: x + y * w + z * w * h
    # Batch element the token belongs to; padding tokens carry BATCH_SIZE here.
    BATCH_IDX = 7

    # Number of metadata columns above -- keep in sync when adding fields.
    NUM_METADATA = 8

    # Sentinel IDX values: padding tokens use -1, cls tokens use -2.
    ID_CLS_TOKEN = -2
    ID_PAD_TOKEN = -1
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
class VisionEncoder(nn.Module):
|
|
212
|
-
def __init__(
|
|
213
|
-
self,
|
|
214
|
-
image_size: tuple[int, int],
|
|
215
|
-
patch_size: tuple[int, int],
|
|
216
|
-
dim: int,
|
|
217
|
-
layers: int,
|
|
218
|
-
heads: int,
|
|
219
|
-
mlp_ratio: float,
|
|
220
|
-
in_channels: int = 3,
|
|
221
|
-
):
|
|
222
|
-
super().__init__()
|
|
223
|
-
self.image_size = image_size
|
|
224
|
-
self.patch_size = patch_size
|
|
225
|
-
self.grid_size = (
|
|
226
|
-
self.image_size[0] // self.patch_size[0],
|
|
227
|
-
self.image_size[1] // self.patch_size[1],
|
|
228
|
-
)
|
|
229
|
-
self.conv1 = ColumnParallelConv2dPatch(
|
|
230
|
-
in_channels=in_channels,
|
|
231
|
-
out_channels=dim,
|
|
232
|
-
kernel_size=patch_size,
|
|
233
|
-
stride=patch_size,
|
|
234
|
-
bias=False,
|
|
235
|
-
)
|
|
236
|
-
scale = dim**-0.5
|
|
237
|
-
self.class_embedding = nn.Parameter(scale * torch.randn(dim))
|
|
238
|
-
|
|
239
|
-
self.positional_embedding_vlm = nn.Parameter(
|
|
240
|
-
scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, dim)
|
|
241
|
-
)
|
|
242
|
-
|
|
243
|
-
self.ln_pre = LayerNorm(dim)
|
|
244
|
-
self.ln_post = LayerNorm(dim)
|
|
245
|
-
self.transformer = _Transformer(
|
|
246
|
-
dim,
|
|
247
|
-
layers,
|
|
248
|
-
heads,
|
|
249
|
-
mlp_ratio,
|
|
250
|
-
act_layer=nn.GELU,
|
|
251
|
-
)
|
|
252
|
-
|
|
253
|
-
# NOTE: hack for the fixed res
|
|
254
|
-
image_h, image_w = self.image_size
|
|
255
|
-
patch_h, patch_w = self.patch_size
|
|
256
|
-
idx_h, idx_w = image_h // patch_h, image_w // patch_w
|
|
257
|
-
img_idx = torch.arange(image_h * image_w // (patch_h * patch_w), dtype=torch.int32)
|
|
258
|
-
img_idx = img_idx.reshape(idx_h * idx_w, 1)
|
|
259
|
-
img_idx = torch.cat([img_idx, img_idx[:1]], dim=0)
|
|
260
|
-
img_idx[-1, -1] = PackingIndex.ID_CLS_TOKEN
|
|
261
|
-
|
|
262
|
-
packed_img_idx = torch.empty(
|
|
263
|
-
img_idx.shape[0],
|
|
264
|
-
img_idx.shape[1],
|
|
265
|
-
PackingIndex.NUM_METADATA - 1,
|
|
266
|
-
dtype=torch.int32,
|
|
267
|
-
)
|
|
268
|
-
packed_img_idx[:, :, PackingIndex.Y] = img_idx // idx_w
|
|
269
|
-
packed_img_idx[:, :, PackingIndex.X] = img_idx % idx_w
|
|
270
|
-
packed_img_idx[:, :, PackingIndex.HEIGHT].fill_(idx_h)
|
|
271
|
-
packed_img_idx[:, :, PackingIndex.WIDTH].fill_(idx_w)
|
|
272
|
-
packed_img_idx[:, :, PackingIndex.IDX] = img_idx
|
|
273
|
-
packed_img_idx = packed_img_idx.reshape(1, -1, PackingIndex.NUM_METADATA - 1)
|
|
274
|
-
self.packed_img_idx = packed_img_idx # for positional embedding load hook
|
|
275
|
-
|
|
276
|
-
# compute rope freqs
|
|
277
|
-
rope_freq = self.get_rope_freqs(dim // heads // 2)
|
|
278
|
-
freqs_x = self.compute_rope_freqs(rope_freq, packed_img_idx[:, :, PackingIndex.X] + 1)
|
|
279
|
-
freqs_y = self.compute_rope_freqs(rope_freq, packed_img_idx[:, :, PackingIndex.Y] + 1)
|
|
280
|
-
freqs = torch.cat([freqs_x, freqs_y], dim=-1).float().contiguous()[..., ::2]
|
|
281
|
-
# disable RoPE for padding and cls tokens
|
|
282
|
-
freqs = freqs.masked_fill(packed_img_idx[:, :, PackingIndex.IDX, None] < 0, 0)
|
|
283
|
-
# compute complex freqs
|
|
284
|
-
self.freq_cis = torch.view_as_complex(torch.stack([torch.cos(freqs), torch.sin(freqs)], dim=-1))
|
|
285
|
-
# xlf automatically broadcasts
|
|
286
|
-
self.freq_cis = self.freq_cis.squeeze(0)
|
|
287
|
-
self.n_heads = heads // fs_init.get_model_parallel_world_size()
|
|
288
|
-
|
|
289
|
-
self._register_load_state_dict_pre_hook(self.load_hook)
|
|
290
|
-
|
|
291
|
-
def get_rope_freqs(self, dim, theta=10000):
|
|
292
|
-
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
|
293
|
-
return freqs
|
|
294
|
-
|
|
295
|
-
@torch.amp.autocast("cuda", enabled=False)
|
|
296
|
-
def compute_rope_freqs(self, freqs, t):
|
|
297
|
-
freqs = einsum("..., f -> ... f", t.type(freqs.dtype), freqs)
|
|
298
|
-
freqs = freqs.repeat_interleave(2, dim=-1)
|
|
299
|
-
return freqs
|
|
300
|
-
|
|
301
|
-
def load_hook(
|
|
302
|
-
self,
|
|
303
|
-
state_dict: dict[str, Any],
|
|
304
|
-
prefix: str,
|
|
305
|
-
local_metadata: dict[str, Any],
|
|
306
|
-
strict: bool = True,
|
|
307
|
-
missing_keys: list[str] = None,
|
|
308
|
-
unexpected_keys: list[str] = None,
|
|
309
|
-
error_msgs: list[str] = None,
|
|
310
|
-
return_state_dict: bool = False,
|
|
311
|
-
) -> None:
|
|
312
|
-
orig_pos_embed = state_dict.get(prefix + "positional_embedding")
|
|
313
|
-
if orig_pos_embed is not None and orig_pos_embed.shape[-2:] != self.positional_embedding_vlm.shape[-2:]:
|
|
314
|
-
raise ValueError(
|
|
315
|
-
f"Positional embedding shape {orig_pos_embed.shape} does not match expected shape {self.positional_embedding_vlm.shape}"
|
|
316
|
-
)
|
|
317
|
-
|
|
318
|
-
batch_size, token_per_image, _ = self.packed_img_idx.shape
|
|
319
|
-
# Input points for idx are [x, y, w, h]
|
|
320
|
-
idx = self.packed_img_idx.reshape(batch_size * token_per_image, 1, -1)
|
|
321
|
-
total_windows, window_size, _ = idx.shape
|
|
322
|
-
|
|
323
|
-
# Grid values are [-1, 1] and coords are w, h
|
|
324
|
-
grid = (
|
|
325
|
-
(idx[:, :, [PackingIndex.X, PackingIndex.Y]] / idx[:, :, [PackingIndex.WIDTH, PackingIndex.HEIGHT]]) * 2 - 1
|
|
326
|
-
)[None, ...]
|
|
327
|
-
|
|
328
|
-
# In this mode, cls token has no position embedding
|
|
329
|
-
if orig_pos_embed is not None:
|
|
330
|
-
posemb = (
|
|
331
|
-
orig_pos_embed[1:].view(1, self.grid_size[0], self.grid_size[1], -1).permute(0, 3, 1, 2).contiguous()
|
|
332
|
-
)
|
|
333
|
-
posemb = posemb.to(device=grid.device, dtype=grid.dtype)
|
|
334
|
-
sample = F.grid_sample(
|
|
335
|
-
posemb, grid, padding_mode="zeros"
|
|
336
|
-
) # padding tokens / class token will get zero for posemb
|
|
337
|
-
sample = sample.view(-1, total_windows, window_size).permute(1, 2, 0).contiguous()
|
|
338
|
-
sample = torch.where(
|
|
339
|
-
idx[:, :, PackingIndex.IDX, None] == PackingIndex.ID_CLS_TOKEN,
|
|
340
|
-
orig_pos_embed[0].view(1, 1, -1).to(device=sample.device, dtype=sample.dtype),
|
|
341
|
-
sample,
|
|
342
|
-
)
|
|
343
|
-
|
|
344
|
-
new_pos_embed = sample.reshape(batch_size, token_per_image, -1)
|
|
345
|
-
|
|
346
|
-
state_dict[prefix + "positional_embedding_vlm"] = new_pos_embed.squeeze(0)
|
|
347
|
-
|
|
348
|
-
if return_state_dict:
|
|
349
|
-
return state_dict
|
|
350
|
-
|
|
351
|
-
def apply_class_embedding(self, x):
|
|
352
|
-
x = torch.cat(
|
|
353
|
-
[
|
|
354
|
-
x,
|
|
355
|
-
self.class_embedding.to(x.dtype)
|
|
356
|
-
+ torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
|
|
357
|
-
],
|
|
358
|
-
dim=1,
|
|
359
|
-
) # shape = [*, grid ** 2 + 1, width]
|
|
360
|
-
return x
|
|
361
|
-
|
|
362
|
-
def forward(self, images: torch.Tensor) -> torch.Tensor:
|
|
363
|
-
# NOTE: in Llama4 bsz=bsz*num_tiles, num_chunks=1
|
|
364
|
-
if images.ndim == 5:
|
|
365
|
-
num_concurrent_media = 1
|
|
366
|
-
bsz, num_chunks, nch, h, w = images.shape
|
|
367
|
-
else:
|
|
368
|
-
bsz, num_concurrent_media, num_chunks, nch, h, w = images.shape
|
|
369
|
-
|
|
370
|
-
images = images.reshape(bsz * num_concurrent_media * num_chunks, nch, h, w)
|
|
371
|
-
# patch embedding
|
|
372
|
-
x = images.reshape(bsz * num_concurrent_media * num_chunks, nch, h, w)
|
|
373
|
-
x = self.conv1(x) # shape = [*, width, grid ** 2]
|
|
374
|
-
_, ntok, dim = x.shape
|
|
375
|
-
x = x.reshape(bsz * num_concurrent_media * num_chunks, ntok, dim)
|
|
376
|
-
|
|
377
|
-
# apply cls token
|
|
378
|
-
x = self.apply_class_embedding(x)
|
|
379
|
-
ntok += 1
|
|
380
|
-
|
|
381
|
-
# apply position embeddings
|
|
382
|
-
if self.positional_embedding_vlm is not None:
|
|
383
|
-
x = x + self.positional_embedding_vlm.to(x.dtype)
|
|
384
|
-
|
|
385
|
-
x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok, dim)
|
|
386
|
-
|
|
387
|
-
x = self.ln_pre(x)
|
|
388
|
-
x = x.view(bsz * num_concurrent_media, -1, dim)
|
|
389
|
-
freq_cis = self.freq_cis.to(images.device)
|
|
390
|
-
|
|
391
|
-
tf_output = self.transformer(
|
|
392
|
-
x,
|
|
393
|
-
freq_cis=freq_cis,
|
|
394
|
-
)
|
|
395
|
-
|
|
396
|
-
int_x = None
|
|
397
|
-
if isinstance(tf_output, tuple):
|
|
398
|
-
x, int_x = tf_output
|
|
399
|
-
else:
|
|
400
|
-
x = tf_output
|
|
401
|
-
x = self.ln_post(x)
|
|
402
|
-
|
|
403
|
-
# remove cls token output
|
|
404
|
-
x = x[:, :-1, :]
|
|
405
|
-
|
|
406
|
-
# add and output x + int_x features
|
|
407
|
-
if int_x is not None:
|
|
408
|
-
int_x = int_x[:, :-1, :, :]
|
|
409
|
-
int_x = int_x.reshape(bsz * num_concurrent_media, ntok - 1, -1)
|
|
410
|
-
x = torch.cat([x, int_x], dim=-1)
|
|
411
|
-
|
|
412
|
-
return x
|