llama-stack 0.4.4__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff compares publicly released package versions as they appear in their respective public registries and is provided for informational purposes only.
- llama_stack/cli/stack/_list_deps.py +11 -7
- llama_stack/cli/stack/run.py +3 -25
- llama_stack/core/access_control/datatypes.py +78 -0
- llama_stack/core/configure.py +2 -2
- llama_stack/{distributions/meta-reference-gpu → core/connectors}/__init__.py +3 -1
- llama_stack/core/connectors/connectors.py +162 -0
- llama_stack/core/conversations/conversations.py +61 -58
- llama_stack/core/datatypes.py +54 -8
- llama_stack/core/library_client.py +60 -13
- llama_stack/core/prompts/prompts.py +43 -42
- llama_stack/core/routers/datasets.py +20 -17
- llama_stack/core/routers/eval_scoring.py +143 -53
- llama_stack/core/routers/inference.py +20 -9
- llama_stack/core/routers/safety.py +30 -42
- llama_stack/core/routers/vector_io.py +15 -7
- llama_stack/core/routing_tables/models.py +42 -3
- llama_stack/core/routing_tables/scoring_functions.py +19 -19
- llama_stack/core/routing_tables/shields.py +20 -17
- llama_stack/core/routing_tables/vector_stores.py +8 -5
- llama_stack/core/server/auth.py +192 -17
- llama_stack/core/server/fastapi_router_registry.py +40 -5
- llama_stack/core/server/server.py +24 -5
- llama_stack/core/stack.py +54 -10
- llama_stack/core/storage/datatypes.py +9 -0
- llama_stack/core/store/registry.py +1 -1
- llama_stack/core/utils/exec.py +2 -2
- llama_stack/core/utils/type_inspection.py +16 -2
- llama_stack/distributions/dell/config.yaml +4 -1
- llama_stack/distributions/dell/run-with-safety.yaml +4 -1
- llama_stack/distributions/nvidia/config.yaml +4 -1
- llama_stack/distributions/nvidia/run-with-safety.yaml +4 -1
- llama_stack/distributions/oci/config.yaml +4 -1
- llama_stack/distributions/open-benchmark/config.yaml +9 -1
- llama_stack/distributions/postgres-demo/config.yaml +1 -1
- llama_stack/distributions/starter/build.yaml +62 -0
- llama_stack/distributions/starter/config.yaml +22 -3
- llama_stack/distributions/starter/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/starter/starter.py +13 -1
- llama_stack/distributions/starter-gpu/build.yaml +62 -0
- llama_stack/distributions/starter-gpu/config.yaml +22 -3
- llama_stack/distributions/starter-gpu/run-with-postgres-store.yaml +22 -3
- llama_stack/distributions/template.py +10 -2
- llama_stack/distributions/watsonx/config.yaml +4 -1
- llama_stack/log.py +1 -0
- llama_stack/providers/inline/agents/meta_reference/__init__.py +1 -0
- llama_stack/providers/inline/agents/meta_reference/agents.py +58 -61
- llama_stack/providers/inline/agents/meta_reference/responses/openai_responses.py +53 -51
- llama_stack/providers/inline/agents/meta_reference/responses/streaming.py +99 -22
- llama_stack/providers/inline/agents/meta_reference/responses/types.py +2 -1
- llama_stack/providers/inline/agents/meta_reference/responses/utils.py +4 -1
- llama_stack/providers/inline/agents/meta_reference/safety.py +2 -2
- llama_stack/providers/inline/batches/reference/batches.py +2 -1
- llama_stack/providers/inline/eval/meta_reference/eval.py +40 -32
- llama_stack/providers/inline/post_training/huggingface/post_training.py +33 -38
- llama_stack/providers/inline/post_training/huggingface/utils.py +2 -5
- llama_stack/providers/inline/post_training/torchtune/common/utils.py +5 -9
- llama_stack/providers/inline/post_training/torchtune/post_training.py +28 -33
- llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +2 -4
- llama_stack/providers/inline/safety/code_scanner/code_scanner.py +12 -15
- llama_stack/providers/inline/safety/llama_guard/llama_guard.py +20 -24
- llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +11 -17
- llama_stack/providers/inline/scoring/basic/scoring.py +13 -17
- llama_stack/providers/inline/scoring/braintrust/braintrust.py +15 -15
- llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +13 -17
- llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +1 -1
- llama_stack/providers/registry/agents.py +1 -0
- llama_stack/providers/registry/inference.py +1 -9
- llama_stack/providers/registry/vector_io.py +136 -16
- llama_stack/providers/remote/eval/nvidia/eval.py +22 -21
- llama_stack/providers/remote/files/s3/config.py +5 -3
- llama_stack/providers/remote/files/s3/files.py +2 -2
- llama_stack/providers/remote/inference/gemini/gemini.py +4 -0
- llama_stack/providers/remote/inference/openai/openai.py +2 -0
- llama_stack/providers/remote/inference/together/together.py +4 -0
- llama_stack/providers/remote/inference/vertexai/config.py +3 -3
- llama_stack/providers/remote/inference/vertexai/vertexai.py +5 -2
- llama_stack/providers/remote/inference/vllm/config.py +37 -18
- llama_stack/providers/remote/inference/vllm/vllm.py +0 -3
- llama_stack/providers/remote/inference/watsonx/watsonx.py +4 -0
- llama_stack/providers/remote/post_training/nvidia/models.py +3 -11
- llama_stack/providers/remote/post_training/nvidia/post_training.py +31 -33
- llama_stack/providers/remote/safety/bedrock/bedrock.py +10 -27
- llama_stack/providers/remote/safety/nvidia/nvidia.py +9 -25
- llama_stack/providers/remote/safety/sambanova/sambanova.py +13 -11
- llama_stack/providers/remote/vector_io/elasticsearch/__init__.py +17 -0
- llama_stack/providers/remote/vector_io/elasticsearch/config.py +32 -0
- llama_stack/providers/remote/vector_io/elasticsearch/elasticsearch.py +463 -0
- llama_stack/providers/remote/vector_io/oci/__init__.py +22 -0
- llama_stack/providers/remote/vector_io/oci/config.py +41 -0
- llama_stack/providers/remote/vector_io/oci/oci26ai.py +595 -0
- llama_stack/providers/remote/vector_io/pgvector/config.py +69 -2
- llama_stack/providers/remote/vector_io/pgvector/pgvector.py +255 -6
- llama_stack/providers/remote/vector_io/qdrant/qdrant.py +62 -38
- llama_stack/providers/utils/bedrock/client.py +3 -3
- llama_stack/providers/utils/bedrock/config.py +7 -7
- llama_stack/providers/utils/inference/__init__.py +0 -25
- llama_stack/providers/utils/inference/embedding_mixin.py +4 -0
- llama_stack/providers/utils/inference/http_client.py +239 -0
- llama_stack/providers/utils/inference/litellm_openai_mixin.py +6 -0
- llama_stack/providers/utils/inference/model_registry.py +148 -2
- llama_stack/providers/utils/inference/openai_compat.py +1 -158
- llama_stack/providers/utils/inference/openai_mixin.py +42 -2
- llama_stack/providers/utils/inference/prompt_adapter.py +0 -209
- llama_stack/providers/utils/memory/openai_vector_store_mixin.py +92 -5
- llama_stack/providers/utils/memory/vector_store.py +46 -19
- llama_stack/providers/utils/responses/responses_store.py +7 -7
- llama_stack/providers/utils/safety.py +114 -0
- llama_stack/providers/utils/tools/mcp.py +44 -3
- llama_stack/testing/api_recorder.py +9 -3
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/METADATA +14 -2
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/RECORD +115 -148
- llama_stack/distributions/meta-reference-gpu/config.yaml +0 -140
- llama_stack/distributions/meta-reference-gpu/doc_template.md +0 -119
- llama_stack/distributions/meta-reference-gpu/meta_reference.py +0 -163
- llama_stack/distributions/meta-reference-gpu/run-with-safety.yaml +0 -155
- llama_stack/models/llama/hadamard_utils.py +0 -88
- llama_stack/models/llama/llama3/args.py +0 -74
- llama_stack/models/llama/llama3/dog.jpg +0 -0
- llama_stack/models/llama/llama3/generation.py +0 -378
- llama_stack/models/llama/llama3/model.py +0 -304
- llama_stack/models/llama/llama3/multimodal/__init__.py +0 -12
- llama_stack/models/llama/llama3/multimodal/encoder_utils.py +0 -180
- llama_stack/models/llama/llama3/multimodal/image_transform.py +0 -409
- llama_stack/models/llama/llama3/multimodal/model.py +0 -1430
- llama_stack/models/llama/llama3/multimodal/utils.py +0 -26
- llama_stack/models/llama/llama3/pasta.jpeg +0 -0
- llama_stack/models/llama/llama3/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama3/quantization/loader.py +0 -316
- llama_stack/models/llama/llama3_1/__init__.py +0 -12
- llama_stack/models/llama/llama3_1/prompt_format.md +0 -358
- llama_stack/models/llama/llama3_1/prompts.py +0 -258
- llama_stack/models/llama/llama3_2/__init__.py +0 -5
- llama_stack/models/llama/llama3_2/prompts_text.py +0 -229
- llama_stack/models/llama/llama3_2/prompts_vision.py +0 -126
- llama_stack/models/llama/llama3_2/text_prompt_format.md +0 -286
- llama_stack/models/llama/llama3_2/vision_prompt_format.md +0 -141
- llama_stack/models/llama/llama3_3/__init__.py +0 -5
- llama_stack/models/llama/llama3_3/prompts.py +0 -259
- llama_stack/models/llama/llama4/args.py +0 -107
- llama_stack/models/llama/llama4/ffn.py +0 -58
- llama_stack/models/llama/llama4/moe.py +0 -214
- llama_stack/models/llama/llama4/preprocess.py +0 -435
- llama_stack/models/llama/llama4/quantization/__init__.py +0 -5
- llama_stack/models/llama/llama4/quantization/loader.py +0 -226
- llama_stack/models/llama/llama4/vision/__init__.py +0 -5
- llama_stack/models/llama/llama4/vision/embedding.py +0 -210
- llama_stack/models/llama/llama4/vision/encoder.py +0 -412
- llama_stack/models/llama/quantize_impls.py +0 -316
- llama_stack/providers/inline/inference/meta_reference/__init__.py +0 -20
- llama_stack/providers/inline/inference/meta_reference/common.py +0 -24
- llama_stack/providers/inline/inference/meta_reference/config.py +0 -68
- llama_stack/providers/inline/inference/meta_reference/generators.py +0 -201
- llama_stack/providers/inline/inference/meta_reference/inference.py +0 -542
- llama_stack/providers/inline/inference/meta_reference/model_parallel.py +0 -77
- llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +0 -353
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/WHEEL +0 -0
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/entry_points.txt +0 -0
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/licenses/LICENSE +0 -0
- {llama_stack-0.4.4.dist-info → llama_stack-0.5.0.dist-info}/top_level.txt +0 -0
|
@@ -1,1430 +0,0 @@
|
|
|
1
|
-
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
-
# All rights reserved.
|
|
3
|
-
#
|
|
4
|
-
# This source code is licensed under the terms described in the LICENSE file in
|
|
5
|
-
# the root directory of this source tree.
|
|
6
|
-
import math
|
|
7
|
-
from collections.abc import Callable
|
|
8
|
-
from functools import partial
|
|
9
|
-
from typing import Any
|
|
10
|
-
|
|
11
|
-
import fairscale.nn.model_parallel.initialize as fs_init
|
|
12
|
-
import torch
|
|
13
|
-
import torch.nn.functional as F
|
|
14
|
-
from fairscale.nn.model_parallel.layers import (
|
|
15
|
-
ColumnParallelLinear,
|
|
16
|
-
RowParallelLinear,
|
|
17
|
-
VocabParallelEmbedding,
|
|
18
|
-
)
|
|
19
|
-
from PIL import Image as PIL_Image
|
|
20
|
-
from torch import Tensor, nn
|
|
21
|
-
from torch.distributed import _functional_collectives as funcol
|
|
22
|
-
|
|
23
|
-
from llama_stack.log import get_logger
|
|
24
|
-
|
|
25
|
-
from ..model import ModelArgs, RMSNorm, apply_rotary_emb, precompute_freqs_cis
|
|
26
|
-
from .encoder_utils import (
|
|
27
|
-
build_encoder_attention_mask,
|
|
28
|
-
contract_num_tokens_from_mult8,
|
|
29
|
-
expand_num_tokens_to_mult8,
|
|
30
|
-
initialize_global_position_embedding_from_local,
|
|
31
|
-
resize_global_position_embedding,
|
|
32
|
-
resize_local_position_embedding,
|
|
33
|
-
)
|
|
34
|
-
from .image_transform import VariableSizeImageTransform
|
|
35
|
-
from .utils import get_negative_inf_value, to_2tuple
|
|
36
|
-
|
|
37
|
-
MP_SCALE = 8
|
|
38
|
-
|
|
39
|
-
logger = get_logger(name=__name__, category="models::llama")
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
def reduce_from_tensor_model_parallel_region(input_):
|
|
43
|
-
"""All-reduce the input tensor across model parallel group."""
|
|
44
|
-
output = funcol.all_reduce(input_, "sum", group=fs_init.get_model_parallel_group())
|
|
45
|
-
output = funcol.wait_tensor(output)
|
|
46
|
-
return output
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
def gather_from_tensor_model_parallel_region(input_):
|
|
50
|
-
"""Gather tensors and concatenate along the last dimension."""
|
|
51
|
-
|
|
52
|
-
world_size = fs_init.get_model_parallel_world_size()
|
|
53
|
-
# Size and dimension.
|
|
54
|
-
last_dim = input_.dim() - 1
|
|
55
|
-
rank = fs_init.get_model_parallel_rank()
|
|
56
|
-
|
|
57
|
-
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
|
|
58
|
-
tensor_list[rank] = input_
|
|
59
|
-
output = funcol.all_gather_tensor(
|
|
60
|
-
input_,
|
|
61
|
-
gather_dim=last_dim,
|
|
62
|
-
group=fs_init.get_model_parallel_group(),
|
|
63
|
-
)
|
|
64
|
-
output = funcol.wait_tensor(output)
|
|
65
|
-
return output
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
def _get_full_row_masked_out_mask(
|
|
69
|
-
attn_bias,
|
|
70
|
-
negative_inf_value,
|
|
71
|
-
):
|
|
72
|
-
"""
|
|
73
|
-
attn_bias should be a 4D tensor of shape [B, H, S1, S2]
|
|
74
|
-
where B is the batch size, H is the number of heads,
|
|
75
|
-
and S1/S2 are the sequence lengths. This returns
|
|
76
|
-
a 4D tensor of shape [B, H, S1, 1] which stores boolean
|
|
77
|
-
values which are 0 if the a full row in the last dimension
|
|
78
|
-
contains negative infinity values, otherwise it's 1.
|
|
79
|
-
"""
|
|
80
|
-
return (attn_bias != negative_inf_value).any(dim=-1).type_as(attn_bias)[..., None]
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
# Image encoder for inference
|
|
84
|
-
class LayerNorm(nn.LayerNorm):
|
|
85
|
-
"""Subclass torch's LayerNorm to handle fp16."""
|
|
86
|
-
|
|
87
|
-
def forward(self, x: torch.Tensor):
|
|
88
|
-
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
|
|
89
|
-
return x
|
|
90
|
-
|
|
91
|
-
|
|
92
|
-
class ColumnParallelConv2dPatch(torch.nn.Module):
|
|
93
|
-
"""Conv2D Patching layer with model parallelism.
|
|
94
|
-
Column parallel over unfolded input.
|
|
95
|
-
Arguments:
|
|
96
|
-
in_channels: Input channels.
|
|
97
|
-
out_channels: Output channels.
|
|
98
|
-
kernel_size: Size of convolution kernel.
|
|
99
|
-
stride (default 1): Stride for convolution.
|
|
100
|
-
bias (default False): Use bias in Conv2d.
|
|
101
|
-
Input: (bsz, in_channels, width, height)
|
|
102
|
-
Output: (bsz, num_tokens, out_channels)
|
|
103
|
-
"""
|
|
104
|
-
|
|
105
|
-
def __init__(
|
|
106
|
-
self,
|
|
107
|
-
in_channels: int,
|
|
108
|
-
out_channels: int,
|
|
109
|
-
kernel_size: int | tuple[int, int],
|
|
110
|
-
stride: int | tuple[int, int],
|
|
111
|
-
bias: bool | None = False,
|
|
112
|
-
) -> None:
|
|
113
|
-
super().__init__()
|
|
114
|
-
if isinstance(kernel_size, int):
|
|
115
|
-
kernel_size = (kernel_size, kernel_size)
|
|
116
|
-
self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride)
|
|
117
|
-
self._linear = ColumnParallelLinear(
|
|
118
|
-
in_channels * kernel_size[0] * kernel_size[1],
|
|
119
|
-
out_channels,
|
|
120
|
-
bias=bias,
|
|
121
|
-
)
|
|
122
|
-
|
|
123
|
-
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
|
124
|
-
x = self._unfold(x)
|
|
125
|
-
x = x.permute(0, 2, 1)
|
|
126
|
-
x = F.linear(x, self._linear.weight)
|
|
127
|
-
x = gather_from_tensor_model_parallel_region(x)
|
|
128
|
-
return x
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
class ImageFeedForward(torch.nn.Module):
|
|
132
|
-
def __init__(
|
|
133
|
-
self,
|
|
134
|
-
dim: int,
|
|
135
|
-
hidden_dim: int,
|
|
136
|
-
dropout: float,
|
|
137
|
-
act_layer: Callable = nn.GELU,
|
|
138
|
-
):
|
|
139
|
-
super().__init__()
|
|
140
|
-
# layers
|
|
141
|
-
self.c_fc = ColumnParallelLinear(
|
|
142
|
-
dim,
|
|
143
|
-
hidden_dim,
|
|
144
|
-
bias=True,
|
|
145
|
-
gather_output=False,
|
|
146
|
-
init_method=lambda x: x,
|
|
147
|
-
)
|
|
148
|
-
self.c_proj = RowParallelLinear(
|
|
149
|
-
hidden_dim,
|
|
150
|
-
dim,
|
|
151
|
-
bias=True,
|
|
152
|
-
input_is_parallel=True,
|
|
153
|
-
init_method=lambda x: x,
|
|
154
|
-
)
|
|
155
|
-
self.non_linearity = act_layer()
|
|
156
|
-
self.dropout = dropout
|
|
157
|
-
|
|
158
|
-
def forward(self, x):
|
|
159
|
-
hidden = F.linear(x, self.c_fc.weight, self.c_fc.bias)
|
|
160
|
-
hidden = self.non_linearity(hidden)
|
|
161
|
-
hidden = F.linear(hidden, self.c_proj.weight)
|
|
162
|
-
hidden = reduce_from_tensor_model_parallel_region(hidden)
|
|
163
|
-
hidden += self.c_proj.bias
|
|
164
|
-
return hidden
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
class ImageAttention(nn.Module):
|
|
168
|
-
def __init__(
|
|
169
|
-
self,
|
|
170
|
-
dim,
|
|
171
|
-
head_dim,
|
|
172
|
-
n_heads,
|
|
173
|
-
):
|
|
174
|
-
super().__init__()
|
|
175
|
-
world_size = fs_init.get_model_parallel_world_size()
|
|
176
|
-
qkvo_replication = 1
|
|
177
|
-
if world_size > 16:
|
|
178
|
-
qkvo_replication = world_size // 8
|
|
179
|
-
|
|
180
|
-
self.n_kv_heads = n_heads
|
|
181
|
-
self.n_local_heads = n_heads * qkvo_replication // world_size
|
|
182
|
-
self.n_local_kv_heads = self.n_kv_heads * qkvo_replication // world_size
|
|
183
|
-
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
|
184
|
-
self.head_dim = dim // n_heads
|
|
185
|
-
|
|
186
|
-
self.wq = ColumnParallelLinear(
|
|
187
|
-
dim,
|
|
188
|
-
qkvo_replication * n_heads * self.head_dim,
|
|
189
|
-
bias=False,
|
|
190
|
-
gather_output=False,
|
|
191
|
-
init_method=lambda x: x,
|
|
192
|
-
)
|
|
193
|
-
self.wk = ColumnParallelLinear(
|
|
194
|
-
dim,
|
|
195
|
-
qkvo_replication * self.n_kv_heads * self.head_dim,
|
|
196
|
-
bias=False,
|
|
197
|
-
gather_output=False,
|
|
198
|
-
init_method=lambda x: x,
|
|
199
|
-
)
|
|
200
|
-
self.wv = ColumnParallelLinear(
|
|
201
|
-
dim,
|
|
202
|
-
qkvo_replication * self.n_kv_heads * self.head_dim,
|
|
203
|
-
bias=False,
|
|
204
|
-
gather_output=False,
|
|
205
|
-
init_method=lambda x: x,
|
|
206
|
-
)
|
|
207
|
-
self.wo = RowParallelLinear(
|
|
208
|
-
qkvo_replication * n_heads * self.head_dim,
|
|
209
|
-
dim,
|
|
210
|
-
bias=False,
|
|
211
|
-
input_is_parallel=True,
|
|
212
|
-
init_method=lambda x: x,
|
|
213
|
-
)
|
|
214
|
-
self.qkvo_replication = qkvo_replication
|
|
215
|
-
|
|
216
|
-
def forward(
|
|
217
|
-
self,
|
|
218
|
-
x: torch.Tensor,
|
|
219
|
-
mask: torch.Tensor = None,
|
|
220
|
-
):
|
|
221
|
-
xq, xk, xv = [F.linear(x, w) for w in [self.wq.weight, self.wk.weight, self.wv.weight]]
|
|
222
|
-
|
|
223
|
-
bs, slen, _ = xq.shape
|
|
224
|
-
|
|
225
|
-
xq = xq.view(bs, slen, self.n_local_heads, self.head_dim)
|
|
226
|
-
xk = xk.view(bs, xk.shape[1], self.n_local_kv_heads, self.head_dim)
|
|
227
|
-
xv = xv.view(bs, xv.shape[1], self.n_local_kv_heads, self.head_dim)
|
|
228
|
-
|
|
229
|
-
xq, xk, xv = [tensor.transpose(1, 2) for tensor in (xq, xk, xv)]
|
|
230
|
-
|
|
231
|
-
xk = xk.repeat_interleave(self.n_rep, dim=1)
|
|
232
|
-
xv = xv.repeat_interleave(self.n_rep, dim=1)
|
|
233
|
-
|
|
234
|
-
attn_output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=mask, dropout_p=0.0)
|
|
235
|
-
|
|
236
|
-
attn_output = attn_output.transpose(1, 2).contiguous().reshape(bs, slen, -1)
|
|
237
|
-
|
|
238
|
-
out = F.linear(attn_output, self.wo.weight)
|
|
239
|
-
out = reduce_from_tensor_model_parallel_region(out)
|
|
240
|
-
out = out / self.qkvo_replication
|
|
241
|
-
return out
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
class ImageTransformerBlock(nn.Module):
|
|
245
|
-
def __init__(
|
|
246
|
-
self,
|
|
247
|
-
d_model: int,
|
|
248
|
-
n_head: int,
|
|
249
|
-
mlp_ratio: float = 4.0,
|
|
250
|
-
act_layer: Callable = nn.GELU,
|
|
251
|
-
gated: bool = False,
|
|
252
|
-
):
|
|
253
|
-
super().__init__()
|
|
254
|
-
assert d_model % n_head == 0
|
|
255
|
-
self.n_heads = n_head
|
|
256
|
-
self.head_dim = d_model // self.n_heads
|
|
257
|
-
self.attn = ImageAttention(
|
|
258
|
-
dim=d_model,
|
|
259
|
-
head_dim=self.head_dim,
|
|
260
|
-
n_heads=self.n_heads,
|
|
261
|
-
)
|
|
262
|
-
self.ln_1 = LayerNorm(d_model)
|
|
263
|
-
self.mlp = ImageFeedForward(
|
|
264
|
-
dim=d_model,
|
|
265
|
-
hidden_dim=int(mlp_ratio * d_model),
|
|
266
|
-
dropout=0.0,
|
|
267
|
-
act_layer=act_layer,
|
|
268
|
-
)
|
|
269
|
-
self.ln_2 = LayerNorm(d_model)
|
|
270
|
-
self.gated = gated
|
|
271
|
-
if gated:
|
|
272
|
-
self.gate_attn = nn.Parameter(torch.zeros(1))
|
|
273
|
-
self.gate_ffn = nn.Parameter(torch.zeros(1))
|
|
274
|
-
|
|
275
|
-
def forward(
|
|
276
|
-
self,
|
|
277
|
-
x: torch.Tensor,
|
|
278
|
-
mask: torch.Tensor = None,
|
|
279
|
-
):
|
|
280
|
-
_gate_attn = 1 if not self.gated else self.gate_attn.tanh()
|
|
281
|
-
_gate_ffn = 1 if not self.gated else self.gate_ffn.tanh()
|
|
282
|
-
x = x + _gate_attn * self.attn(self.ln_1(x), mask=mask)
|
|
283
|
-
x = x + _gate_ffn * self.mlp(self.ln_2(x))
|
|
284
|
-
return x
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
class ImageTransformer(nn.Module):
|
|
288
|
-
def __init__(
|
|
289
|
-
self,
|
|
290
|
-
width: int,
|
|
291
|
-
layers: int,
|
|
292
|
-
heads: int,
|
|
293
|
-
mlp_ratio: float = 4.0,
|
|
294
|
-
act_layer: Callable = nn.GELU,
|
|
295
|
-
gated: bool = False,
|
|
296
|
-
):
|
|
297
|
-
super().__init__()
|
|
298
|
-
self.width = width
|
|
299
|
-
self.layers = layers
|
|
300
|
-
self.resblocks = nn.ModuleList(
|
|
301
|
-
[
|
|
302
|
-
ImageTransformerBlock(
|
|
303
|
-
d_model=width,
|
|
304
|
-
n_head=heads,
|
|
305
|
-
mlp_ratio=mlp_ratio,
|
|
306
|
-
act_layer=act_layer,
|
|
307
|
-
gated=gated,
|
|
308
|
-
)
|
|
309
|
-
for _ in range(self.layers)
|
|
310
|
-
]
|
|
311
|
-
)
|
|
312
|
-
|
|
313
|
-
def forward(self, x: torch.Tensor, return_intermediate=None, mask=None):
|
|
314
|
-
out = []
|
|
315
|
-
for idx, r in enumerate(self.resblocks):
|
|
316
|
-
if return_intermediate is not None and idx in return_intermediate:
|
|
317
|
-
out.append(x)
|
|
318
|
-
x = r(x, mask=mask)
|
|
319
|
-
if return_intermediate is not None:
|
|
320
|
-
return x, torch.stack(out, dim=-1)
|
|
321
|
-
return x
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
class VisionEncoder(nn.Module):
|
|
325
|
-
def __init__(
|
|
326
|
-
self,
|
|
327
|
-
max_num_tiles: int,
|
|
328
|
-
ckpt_path: str = None,
|
|
329
|
-
image_size: int = 224,
|
|
330
|
-
patch_size: int = 14,
|
|
331
|
-
width: int = 1280,
|
|
332
|
-
layers: int = 32,
|
|
333
|
-
heads: int = 16,
|
|
334
|
-
mlp_ratio: float = 4.0,
|
|
335
|
-
act_layer: Callable = nn.GELU,
|
|
336
|
-
in_channels: int = 3,
|
|
337
|
-
load_ckpt: bool = False,
|
|
338
|
-
n_global_layers: int = 2,
|
|
339
|
-
global_model: bool = False,
|
|
340
|
-
return_intermediate=None,
|
|
341
|
-
):
|
|
342
|
-
super().__init__()
|
|
343
|
-
self.global_model = global_model
|
|
344
|
-
self.return_intermediate = return_intermediate
|
|
345
|
-
self.max_num_tiles = max_num_tiles
|
|
346
|
-
self.image_size = to_2tuple(image_size)
|
|
347
|
-
self.patch_size = to_2tuple(patch_size)
|
|
348
|
-
self.grid_size = (
|
|
349
|
-
self.image_size[0] // self.patch_size[0],
|
|
350
|
-
self.image_size[1] // self.patch_size[1],
|
|
351
|
-
)
|
|
352
|
-
self.conv1 = ColumnParallelConv2dPatch(
|
|
353
|
-
in_channels=in_channels,
|
|
354
|
-
out_channels=width,
|
|
355
|
-
kernel_size=patch_size,
|
|
356
|
-
stride=patch_size,
|
|
357
|
-
bias=False,
|
|
358
|
-
)
|
|
359
|
-
scale = width**-0.5
|
|
360
|
-
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
|
361
|
-
self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size[0] * self.grid_size[1] + 1, width))
|
|
362
|
-
self.ln_post = LayerNorm(width)
|
|
363
|
-
self.ln_pre = LayerNorm(width)
|
|
364
|
-
self.transformer = ImageTransformer(width, layers, heads, mlp_ratio, act_layer=act_layer)
|
|
365
|
-
# pre and post tile position embedding
|
|
366
|
-
self.global_transformer = ImageTransformer(
|
|
367
|
-
width, n_global_layers, heads, mlp_ratio, act_layer=act_layer, gated=True
|
|
368
|
-
)
|
|
369
|
-
# pre and post tile position embedding
|
|
370
|
-
self.pre_tile_pos_embed = TilePositionEmbedding(
|
|
371
|
-
num_tiles=max_num_tiles,
|
|
372
|
-
width=width,
|
|
373
|
-
gated=True,
|
|
374
|
-
)
|
|
375
|
-
self.post_tile_pos_embed = TilePositionEmbedding(
|
|
376
|
-
num_tiles=max_num_tiles,
|
|
377
|
-
width=width,
|
|
378
|
-
gated=True,
|
|
379
|
-
)
|
|
380
|
-
self.gated_positional_embedding = nn.Parameter(
|
|
381
|
-
scale
|
|
382
|
-
* torch.randn(
|
|
383
|
-
max_num_tiles,
|
|
384
|
-
max_num_tiles,
|
|
385
|
-
self.grid_size[0] * self.grid_size[1] + 1,
|
|
386
|
-
width,
|
|
387
|
-
)
|
|
388
|
-
)
|
|
389
|
-
self.gated_positional_embedding_gate = nn.Parameter(torch.zeros(1))
|
|
390
|
-
|
|
391
|
-
self._register_load_state_dict_pre_hook(self.load_hook)
|
|
392
|
-
|
|
393
|
-
def load_hook(
|
|
394
|
-
self,
|
|
395
|
-
state_dict: dict[str, Any],
|
|
396
|
-
prefix: str,
|
|
397
|
-
local_metadata: dict[str, Any],
|
|
398
|
-
strict: bool = True,
|
|
399
|
-
missing_keys: list[str] = None,
|
|
400
|
-
unexpected_keys: list[str] = None,
|
|
401
|
-
error_msgs: list[str] = None,
|
|
402
|
-
return_state_dict: bool = False,
|
|
403
|
-
) -> None:
|
|
404
|
-
orig_pos_embed = state_dict.get(prefix + "positional_embedding")
|
|
405
|
-
if orig_pos_embed is not None:
|
|
406
|
-
new_pos_embed = resize_local_position_embedding(orig_pos_embed, self.grid_size)
|
|
407
|
-
state_dict[prefix + "positional_embedding"] = new_pos_embed
|
|
408
|
-
if hasattr(self, "gated_positional_embedding"):
|
|
409
|
-
if prefix + "gated_positional_embedding" not in state_dict:
|
|
410
|
-
# resize positional_embedding to fit the new grid size
|
|
411
|
-
global_pos_embed = initialize_global_position_embedding_from_local(
|
|
412
|
-
new_pos_embed,
|
|
413
|
-
self.grid_size,
|
|
414
|
-
self.max_num_tiles,
|
|
415
|
-
self.max_num_tiles,
|
|
416
|
-
)
|
|
417
|
-
state_dict[prefix + "gated_positional_embedding"] = global_pos_embed
|
|
418
|
-
state_dict[prefix + "gated_positional_embedding_gate"] = torch.zeros(1, dtype=global_pos_embed.dtype)
|
|
419
|
-
logger.info(f"Initialized global positional embedding with size {global_pos_embed.size()}")
|
|
420
|
-
else:
|
|
421
|
-
global_pos_embed = resize_global_position_embedding(
|
|
422
|
-
state_dict[prefix + "gated_positional_embedding"],
|
|
423
|
-
self.grid_size,
|
|
424
|
-
self.max_num_tiles,
|
|
425
|
-
self.max_num_tiles,
|
|
426
|
-
)
|
|
427
|
-
logger.info(
|
|
428
|
-
f"Resized global positional embedding from {state_dict[prefix + 'gated_positional_embedding'].size()} to {global_pos_embed.size()}"
|
|
429
|
-
)
|
|
430
|
-
state_dict[prefix + "gated_positional_embedding"] = global_pos_embed
|
|
431
|
-
if return_state_dict:
|
|
432
|
-
return state_dict
|
|
433
|
-
|
|
434
|
-
def apply_positional_embedding(self, x, ar):
|
|
435
|
-
# apply regular position embedding
|
|
436
|
-
bsz, num_chunks, num_tokens, dim = x.shape
|
|
437
|
-
x = x.view(bsz * num_chunks, num_tokens, dim)
|
|
438
|
-
x = x + self.positional_embedding * (1 - self.gated_positional_embedding_gate.tanh())
|
|
439
|
-
x = x.view(bsz, num_chunks, num_tokens, dim)
|
|
440
|
-
for idx, arx in enumerate(ar):
|
|
441
|
-
_pos_embed = self.gated_positional_embedding[: arx[0], : arx[1]]
|
|
442
|
-
_pos_embed = _pos_embed.reshape(arx[0] * arx[1], *_pos_embed.shape[2:])
|
|
443
|
-
x[idx, : arx[0] * arx[1]] += _pos_embed * self.gated_positional_embedding_gate.tanh()
|
|
444
|
-
return x
|
|
445
|
-
|
|
446
|
-
def apply_class_embedding(self, x):
|
|
447
|
-
x = torch.cat(
|
|
448
|
-
[
|
|
449
|
-
self.class_embedding.to(x.dtype)
|
|
450
|
-
+ torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
|
|
451
|
-
x,
|
|
452
|
-
],
|
|
453
|
-
dim=1,
|
|
454
|
-
) # shape = [*, grid ** 2 + 1, width]
|
|
455
|
-
return x
|
|
456
|
-
|
|
457
|
-
def forward(self, images: torch.Tensor, ar: torch.Tensor) -> torch.Tensor:
|
|
458
|
-
if images.ndim == 5:
|
|
459
|
-
num_concurrent_media = 1
|
|
460
|
-
bsz, num_chunks, nch, w, h = images.shape
|
|
461
|
-
else:
|
|
462
|
-
bsz, num_concurrent_media, num_chunks, nch, w, h = images.shape
|
|
463
|
-
|
|
464
|
-
images = images.reshape(bsz * num_concurrent_media * num_chunks, nch, w, h)
|
|
465
|
-
ar = ar.reshape(bsz * num_concurrent_media, 2)
|
|
466
|
-
|
|
467
|
-
# patch embedding
|
|
468
|
-
x = images.reshape(bsz * num_concurrent_media * num_chunks, nch, w, h)
|
|
469
|
-
x = self.conv1(x) # shape = [*, width, grid ** 2]
|
|
470
|
-
_, ntok, dim = x.shape
|
|
471
|
-
x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok, dim)
|
|
472
|
-
|
|
473
|
-
# tile embeddings
|
|
474
|
-
x = self.pre_tile_pos_embed(x, ar)
|
|
475
|
-
x = x.reshape(bsz * num_concurrent_media * num_chunks, ntok, dim)
|
|
476
|
-
|
|
477
|
-
# apply cls token
|
|
478
|
-
x = self.apply_class_embedding(x)
|
|
479
|
-
ntok += 1
|
|
480
|
-
|
|
481
|
-
# apply position embeddings
|
|
482
|
-
x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok, dim)
|
|
483
|
-
x = self.apply_positional_embedding(x, ar)
|
|
484
|
-
|
|
485
|
-
x = self.ln_pre(x)
|
|
486
|
-
npad, attn_mask = 0, None
|
|
487
|
-
x, npad = expand_num_tokens_to_mult8(x)
|
|
488
|
-
attn_mask = build_encoder_attention_mask(x, ar, ntok, num_chunks, 1)
|
|
489
|
-
x = x.view(bsz * num_concurrent_media, -1, dim)
|
|
490
|
-
x, int_x = self.transformer(x, return_intermediate=self.return_intermediate, mask=attn_mask)
|
|
491
|
-
|
|
492
|
-
x = self.ln_post(x)
|
|
493
|
-
x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok + npad, dim)
|
|
494
|
-
x = self.post_tile_pos_embed(x, ar)
|
|
495
|
-
x = x.reshape(bsz * num_concurrent_media, num_chunks * (ntok + npad), dim)
|
|
496
|
-
x = self.global_transformer(x, mask=attn_mask)
|
|
497
|
-
x = x.reshape(bsz * num_concurrent_media, num_chunks, ntok + npad, dim)
|
|
498
|
-
x = contract_num_tokens_from_mult8(x, npad)
|
|
499
|
-
|
|
500
|
-
# adding back intermediate layer outputs
|
|
501
|
-
x = x.reshape(bsz, num_concurrent_media, num_chunks, ntok, dim)
|
|
502
|
-
int_x = int_x.reshape(bsz * num_concurrent_media, num_chunks, ntok + npad, -1)
|
|
503
|
-
int_x = contract_num_tokens_from_mult8(int_x, npad)
|
|
504
|
-
int_x = int_x.reshape(bsz, num_concurrent_media, num_chunks, ntok, -1)
|
|
505
|
-
x = torch.cat([x, int_x], dim=-1)
|
|
506
|
-
return x
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
class Attention(nn.Module):
|
|
510
|
-
"""Multi-head attention module."""
|
|
511
|
-
|
|
512
|
-
def __init__(self, args: ModelArgs):
|
|
513
|
-
"""
|
|
514
|
-
Initialize the Attention module.
|
|
515
|
-
Args:
|
|
516
|
-
args (ModelArgs): Model configuration parameters.
|
|
517
|
-
Attributes:
|
|
518
|
-
n_kv_heads (int): Number of key and value heads.
|
|
519
|
-
n_local_heads (int): Number of local query heads.
|
|
520
|
-
n_local_kv_heads (int): Number of local key and value heads.
|
|
521
|
-
n_rep (int): Number of repetitions for local heads.
|
|
522
|
-
head_dim (int): Dimension size of each attention head.
|
|
523
|
-
wq (ColumnParallelLinear): Linear transformation for queries.
|
|
524
|
-
wk (ColumnParallelLinear): Linear transformation for keys.
|
|
525
|
-
wv (ColumnParallelLinear): Linear transformation for values.
|
|
526
|
-
wo (RowParallelLinear): Linear transformation for output.
|
|
527
|
-
cache_k (torch.Tensor): Cached keys for attention.
|
|
528
|
-
cache_v (torch.Tensor): Cached values for attention.
|
|
529
|
-
"""
|
|
530
|
-
super().__init__()
|
|
531
|
-
world_size = fs_init.get_model_parallel_world_size()
|
|
532
|
-
replication_factor = 1
|
|
533
|
-
if world_size > 8:
|
|
534
|
-
replication_factor = world_size // MP_SCALE
|
|
535
|
-
|
|
536
|
-
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
|
537
|
-
self.n_kv_heads *= replication_factor
|
|
538
|
-
|
|
539
|
-
self.n_local_heads = args.n_heads // world_size
|
|
540
|
-
self.n_local_kv_heads = self.n_kv_heads // world_size
|
|
541
|
-
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
|
542
|
-
self.head_dim = args.dim // args.n_heads
|
|
543
|
-
self.max_seq_len = args.max_seq_len
|
|
544
|
-
|
|
545
|
-
self.wq = ColumnParallelLinear(
|
|
546
|
-
args.dim,
|
|
547
|
-
args.n_heads * self.head_dim,
|
|
548
|
-
bias=False,
|
|
549
|
-
gather_output=False,
|
|
550
|
-
init_method=lambda x: x,
|
|
551
|
-
)
|
|
552
|
-
self.wk = ColumnParallelLinear(
|
|
553
|
-
args.dim,
|
|
554
|
-
self.n_kv_heads * self.head_dim,
|
|
555
|
-
bias=False,
|
|
556
|
-
gather_output=False,
|
|
557
|
-
init_method=lambda x: x,
|
|
558
|
-
)
|
|
559
|
-
self.wv = ColumnParallelLinear(
|
|
560
|
-
args.dim,
|
|
561
|
-
self.n_kv_heads * self.head_dim,
|
|
562
|
-
bias=False,
|
|
563
|
-
gather_output=False,
|
|
564
|
-
init_method=lambda x: x,
|
|
565
|
-
)
|
|
566
|
-
self.wo = RowParallelLinear(
|
|
567
|
-
args.n_heads * self.head_dim,
|
|
568
|
-
args.dim,
|
|
569
|
-
bias=False,
|
|
570
|
-
input_is_parallel=True,
|
|
571
|
-
init_method=lambda x: x,
|
|
572
|
-
)
|
|
573
|
-
self.n_heads = args.n_heads
|
|
574
|
-
|
|
575
|
-
def setup_cache(self, max_batch_size: int, dtype: torch.dtype):
|
|
576
|
-
cache_shape = (
|
|
577
|
-
max_batch_size,
|
|
578
|
-
self.max_seq_len,
|
|
579
|
-
self.n_local_kv_heads,
|
|
580
|
-
self.head_dim,
|
|
581
|
-
)
|
|
582
|
-
self.register_buffer(
|
|
583
|
-
"key_cache",
|
|
584
|
-
torch.zeros(
|
|
585
|
-
cache_shape,
|
|
586
|
-
dtype=dtype,
|
|
587
|
-
),
|
|
588
|
-
persistent=False,
|
|
589
|
-
)
|
|
590
|
-
self.register_buffer(
|
|
591
|
-
"value_cache",
|
|
592
|
-
torch.zeros(
|
|
593
|
-
cache_shape,
|
|
594
|
-
dtype=dtype,
|
|
595
|
-
),
|
|
596
|
-
persistent=False,
|
|
597
|
-
)
|
|
598
|
-
|
|
599
|
-
def forward(
|
|
600
|
-
self,
|
|
601
|
-
x: torch.Tensor,
|
|
602
|
-
mask: torch.Tensor,
|
|
603
|
-
freqs_cis: torch.Tensor,
|
|
604
|
-
position_ids: torch.LongTensor,
|
|
605
|
-
):
|
|
606
|
-
self.key_cache = self.key_cache.to(x.device)
|
|
607
|
-
self.value_cache = self.value_cache.to(x.device)
|
|
608
|
-
|
|
609
|
-
xq, xk, xv = [F.linear(x, w) for w in [self.wq.weight, self.wk.weight, self.wv.weight]]
|
|
610
|
-
|
|
611
|
-
bs, slen, _ = xq.shape
|
|
612
|
-
|
|
613
|
-
xq = xq.view(bs, slen, self.n_local_heads, self.head_dim)
|
|
614
|
-
xk = xk.view(bs, xk.shape[1], self.n_local_kv_heads, self.head_dim)
|
|
615
|
-
xv = xv.view(bs, xv.shape[1], self.n_local_kv_heads, self.head_dim)
|
|
616
|
-
|
|
617
|
-
xq, xk = apply_rotary_emb(xq, xk, freqs_cis)
|
|
618
|
-
|
|
619
|
-
self.key_cache[:bs, position_ids, ...] = xk
|
|
620
|
-
self.value_cache[:bs, position_ids, ...] = xv
|
|
621
|
-
|
|
622
|
-
# TODO: we can avoid slicing on first dimension by always padding to max_batch_size()
|
|
623
|
-
xk = self.key_cache[:bs, ...]
|
|
624
|
-
xv = self.value_cache[:bs, ...]
|
|
625
|
-
|
|
626
|
-
xq, xk, xv = [tensor.transpose(1, 2) for tensor in (xq, xk, xv)]
|
|
627
|
-
|
|
628
|
-
xk = xk.repeat_interleave(self.n_rep, dim=1)
|
|
629
|
-
xv = xv.repeat_interleave(self.n_rep, dim=1)
|
|
630
|
-
|
|
631
|
-
attn_output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=mask, dropout_p=0.0)
|
|
632
|
-
|
|
633
|
-
attn_output = attn_output.transpose(1, 2).contiguous().reshape(bs, slen, -1)
|
|
634
|
-
|
|
635
|
-
out = F.linear(attn_output, self.wo.weight)
|
|
636
|
-
out = reduce_from_tensor_model_parallel_region(out)
|
|
637
|
-
return out
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
class FeedForward(nn.Module):
|
|
641
|
-
def __init__(
|
|
642
|
-
self,
|
|
643
|
-
dim: int,
|
|
644
|
-
hidden_dim: int,
|
|
645
|
-
multiple_of: int,
|
|
646
|
-
ffn_dim_multiplier: float | None,
|
|
647
|
-
):
|
|
648
|
-
"""
|
|
649
|
-
Initialize the FeedForward module.
|
|
650
|
-
Args:
|
|
651
|
-
dim (int): Input dimension.
|
|
652
|
-
hidden_dim (int): Hidden dimension of the feedforward layer.
|
|
653
|
-
multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
|
|
654
|
-
ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
|
|
655
|
-
Attributes:
|
|
656
|
-
w1 (ColumnParallelLinear): Linear transformation for the first layer.
|
|
657
|
-
w2 (RowParallelLinear): Linear transformation for the second layer.
|
|
658
|
-
w3 (ColumnParallelLinear): Linear transformation for the third layer.
|
|
659
|
-
"""
|
|
660
|
-
super().__init__()
|
|
661
|
-
hidden_dim = int(2 * hidden_dim / 3)
|
|
662
|
-
# custom dim factor multiplier
|
|
663
|
-
if ffn_dim_multiplier is not None:
|
|
664
|
-
hidden_dim = int(ffn_dim_multiplier * hidden_dim)
|
|
665
|
-
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
|
666
|
-
|
|
667
|
-
self.w1 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
|
|
668
|
-
self.w2 = RowParallelLinear(hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x)
|
|
669
|
-
self.w3 = ColumnParallelLinear(dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x)
|
|
670
|
-
|
|
671
|
-
def forward(self, x):
|
|
672
|
-
x1, x3 = [F.linear(x, w) for w in [self.w1.weight, self.w3.weight]]
|
|
673
|
-
x1 = F.silu(x1)
|
|
674
|
-
x_in = x1 * x3
|
|
675
|
-
out = F.linear(x_in, self.w2.weight)
|
|
676
|
-
out = reduce_from_tensor_model_parallel_region(out)
|
|
677
|
-
return out
|
|
678
|
-
|
|
679
|
-
|
|
680
|
-
class TransformerBlock(nn.Module):
|
|
681
|
-
def __init__(self, layer_id: int, args: ModelArgs):
|
|
682
|
-
"""
|
|
683
|
-
Initialize a TransformerBlock.
|
|
684
|
-
Args:
|
|
685
|
-
layer_id (int): Identifier for the layer.
|
|
686
|
-
args (ModelArgs): Model configuration parameters.
|
|
687
|
-
Attributes:
|
|
688
|
-
n_heads (int): Number of attention heads.
|
|
689
|
-
dim (int): Dimension size of the model.
|
|
690
|
-
head_dim (int): Dimension size of each attention head.
|
|
691
|
-
attention (Attention): Attention module.
|
|
692
|
-
feed_forward (FeedForward): FeedForward module.
|
|
693
|
-
layer_id (int): Identifier for the layer.
|
|
694
|
-
attention_norm (RMSNorm): Layer normalization for attention output.
|
|
695
|
-
ffn_norm (RMSNorm): Layer normalization for feedforward output.
|
|
696
|
-
"""
|
|
697
|
-
super().__init__()
|
|
698
|
-
self.n_heads = args.n_heads
|
|
699
|
-
self.dim = args.dim
|
|
700
|
-
self.head_dim = args.dim // args.n_heads
|
|
701
|
-
self.attention = Attention(args)
|
|
702
|
-
self.feed_forward = FeedForward(
|
|
703
|
-
dim=args.dim,
|
|
704
|
-
hidden_dim=4 * args.dim,
|
|
705
|
-
multiple_of=args.multiple_of,
|
|
706
|
-
ffn_dim_multiplier=args.ffn_dim_multiplier,
|
|
707
|
-
)
|
|
708
|
-
self.layer_id = layer_id
|
|
709
|
-
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
|
710
|
-
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
|
711
|
-
|
|
712
|
-
def setup_cache(self, max_batch_size: int, dtype: torch.dtype):
|
|
713
|
-
self.attention.setup_cache(max_batch_size, dtype)
|
|
714
|
-
|
|
715
|
-
def forward(
|
|
716
|
-
self,
|
|
717
|
-
x: torch.Tensor,
|
|
718
|
-
freqs_cis: torch.Tensor,
|
|
719
|
-
mask: torch.Tensor,
|
|
720
|
-
position_ids: torch.LongTensor,
|
|
721
|
-
) -> torch.Tensor:
|
|
722
|
-
"""
|
|
723
|
-
Perform a forward pass through the TransformerBlock.
|
|
724
|
-
Args:
|
|
725
|
-
x (torch.Tensor): Input tensor.
|
|
726
|
-
start_pos (int): Starting position for attention caching.
|
|
727
|
-
freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
|
|
728
|
-
mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None.
|
|
729
|
-
Returns:
|
|
730
|
-
torch.Tensor: Output tensor after applying attention and feedforward layers.
|
|
731
|
-
"""
|
|
732
|
-
h = self.attention.forward(
|
|
733
|
-
x=self.attention_norm(x),
|
|
734
|
-
freqs_cis=freqs_cis,
|
|
735
|
-
mask=mask,
|
|
736
|
-
position_ids=position_ids,
|
|
737
|
-
)
|
|
738
|
-
h = h + x
|
|
739
|
-
out = h + self.feed_forward.forward(self.ffn_norm(h))
|
|
740
|
-
return out
|
|
741
|
-
|
|
742
|
-
|
|
743
|
-
class TilePositionEmbedding(nn.Module):
|
|
744
|
-
def __init__(
|
|
745
|
-
self,
|
|
746
|
-
num_tiles: int,
|
|
747
|
-
width: int,
|
|
748
|
-
gated: bool = False,
|
|
749
|
-
):
|
|
750
|
-
super().__init__()
|
|
751
|
-
self.num_tiles = num_tiles
|
|
752
|
-
self.width = width
|
|
753
|
-
self.embedding = nn.Parameter(torch.randn(num_tiles, num_tiles, 1, width) / math.sqrt(width))
|
|
754
|
-
self.gated = gated
|
|
755
|
-
if gated:
|
|
756
|
-
self.gate = nn.Parameter(torch.zeros(1))
|
|
757
|
-
|
|
758
|
-
self._register_load_state_dict_pre_hook(self.load_hook)
|
|
759
|
-
|
|
760
|
-
def load_hook(
|
|
761
|
-
self,
|
|
762
|
-
state_dict,
|
|
763
|
-
prefix,
|
|
764
|
-
local_metadata,
|
|
765
|
-
strict,
|
|
766
|
-
missing_keys,
|
|
767
|
-
unexpected_keys,
|
|
768
|
-
error_msgs,
|
|
769
|
-
):
|
|
770
|
-
# load the weights from the checkpoint
|
|
771
|
-
embed = state_dict.get(prefix + "embedding")
|
|
772
|
-
if embed is not None:
|
|
773
|
-
# reshape the weights to the correct shape
|
|
774
|
-
nt_old, nt_old, _, w = embed.shape
|
|
775
|
-
logger.info(f"Resizing tile embedding from {nt_old}x{nt_old} to {self.num_tiles}x{self.num_tiles}")
|
|
776
|
-
embed_new = TilePositionEmbedding._dynamic_resize(embed, self.num_tiles)
|
|
777
|
-
# assign the weights to the module
|
|
778
|
-
state_dict[prefix + "embedding"] = embed_new
|
|
779
|
-
|
|
780
|
-
@staticmethod
|
|
781
|
-
def _dynamic_resize(embed: torch.Tensor, num_tiles: int):
|
|
782
|
-
nt_old, nt_old, _, w = embed.shape
|
|
783
|
-
embed = embed.permute(2, 3, 0, 1)
|
|
784
|
-
|
|
785
|
-
embed_new = F.interpolate(
|
|
786
|
-
embed,
|
|
787
|
-
size=(num_tiles, num_tiles),
|
|
788
|
-
mode="bilinear",
|
|
789
|
-
align_corners=True,
|
|
790
|
-
)
|
|
791
|
-
# reshape the weights to the correct shape
|
|
792
|
-
embed_new = embed_new.permute(2, 3, 0, 1)
|
|
793
|
-
return embed_new
|
|
794
|
-
|
|
795
|
-
def forward(self, x: torch.Tensor, ar: torch.Tensor, num_tiles: int = None):
|
|
796
|
-
embed = self.embedding
|
|
797
|
-
if num_tiles is None:
|
|
798
|
-
num_tiles = self.num_tiles
|
|
799
|
-
elif num_tiles > self.num_tiles:
|
|
800
|
-
embed = TilePositionEmbedding._dynamic_resize(self.embedding, num_tiles)
|
|
801
|
-
out_pos_embed = torch.zeros(x.shape[0], num_tiles, 1, self.width, device=x.device, dtype=x.dtype)
|
|
802
|
-
for idx, arx in enumerate(ar):
|
|
803
|
-
h, w = arx
|
|
804
|
-
out_pos_embed[idx, : w * h] = embed[:h, :w].reshape(w * h, 1, self.width)
|
|
805
|
-
if self.gated:
|
|
806
|
-
out_pos_embed = out_pos_embed * self.gate.tanh()
|
|
807
|
-
x = x + out_pos_embed
|
|
808
|
-
return x
|
|
809
|
-
|
|
810
|
-
|
|
811
|
-
def _noinit(x):
|
|
812
|
-
return x
|
|
813
|
-
|
|
814
|
-
|
|
815
|
-
class CrossAttention(torch.nn.Module):
|
|
816
|
-
"""Cross attention layer with model-parallel attention layers."""
|
|
817
|
-
|
|
818
|
-
def __init__(
|
|
819
|
-
self,
|
|
820
|
-
dim: int,
|
|
821
|
-
head_dim: int,
|
|
822
|
-
n_heads: int,
|
|
823
|
-
n_kv_heads: int,
|
|
824
|
-
norm_eps: float,
|
|
825
|
-
):
|
|
826
|
-
super().__init__()
|
|
827
|
-
self.world_size = fs_init.get_model_parallel_world_size()
|
|
828
|
-
replication_factor = 1
|
|
829
|
-
if self.world_size > 8:
|
|
830
|
-
replication_factor = self.world_size // MP_SCALE
|
|
831
|
-
n_kv_heads *= replication_factor
|
|
832
|
-
|
|
833
|
-
assert n_heads % n_kv_heads == 0
|
|
834
|
-
|
|
835
|
-
self.wq = ColumnParallelLinear(
|
|
836
|
-
dim,
|
|
837
|
-
n_heads * head_dim,
|
|
838
|
-
bias=False,
|
|
839
|
-
gather_output=False,
|
|
840
|
-
init_method=_noinit,
|
|
841
|
-
)
|
|
842
|
-
|
|
843
|
-
self.wk = ColumnParallelLinear(
|
|
844
|
-
dim,
|
|
845
|
-
n_kv_heads * head_dim,
|
|
846
|
-
bias=False,
|
|
847
|
-
gather_output=False,
|
|
848
|
-
init_method=_noinit,
|
|
849
|
-
)
|
|
850
|
-
self.wv = ColumnParallelLinear(
|
|
851
|
-
dim,
|
|
852
|
-
n_kv_heads * head_dim,
|
|
853
|
-
bias=False,
|
|
854
|
-
gather_output=False,
|
|
855
|
-
init_method=_noinit,
|
|
856
|
-
)
|
|
857
|
-
self.wo = RowParallelLinear(
|
|
858
|
-
n_heads * head_dim,
|
|
859
|
-
dim,
|
|
860
|
-
bias=False,
|
|
861
|
-
input_is_parallel=True,
|
|
862
|
-
init_method=_noinit,
|
|
863
|
-
)
|
|
864
|
-
|
|
865
|
-
self.n_heads = n_heads
|
|
866
|
-
self.head_dim = head_dim
|
|
867
|
-
self.n_kv_heads = n_kv_heads
|
|
868
|
-
|
|
869
|
-
self.q_norm = RMSNorm(
|
|
870
|
-
self.head_dim,
|
|
871
|
-
eps=norm_eps,
|
|
872
|
-
)
|
|
873
|
-
self.k_norm = RMSNorm(
|
|
874
|
-
self.head_dim,
|
|
875
|
-
eps=norm_eps,
|
|
876
|
-
)
|
|
877
|
-
|
|
878
|
-
# cross-attention heads are model parallel similar to
|
|
879
|
-
# self-attention, and we also use the identical KV head
|
|
880
|
-
# combination to ensure parity with the corresponding
|
|
881
|
-
# trunk LLM (i.e., group query attention) -- @dubeya
|
|
882
|
-
# local heads
|
|
883
|
-
assert self.n_heads % self.n_kv_heads == 0
|
|
884
|
-
assert self.n_heads % self.world_size == 0
|
|
885
|
-
assert self.n_kv_heads % self.world_size == 0
|
|
886
|
-
self.n_local_heads = self.n_heads // self.world_size
|
|
887
|
-
self.n_local_kv_heads = self.n_kv_heads // self.world_size
|
|
888
|
-
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
|
889
|
-
|
|
890
|
-
def _compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor:
|
|
891
|
-
bsz = xattn_tokens.shape[0]
|
|
892
|
-
xk = self.wk(xattn_tokens)
|
|
893
|
-
xv = self.wv(xattn_tokens)
|
|
894
|
-
|
|
895
|
-
_, seqlen_y, _ = xk.shape
|
|
896
|
-
|
|
897
|
-
xk = xk.view(bsz, seqlen_y, self.n_local_kv_heads, self.head_dim)
|
|
898
|
-
xv = xv.view(bsz, seqlen_y, self.n_local_kv_heads, self.head_dim)
|
|
899
|
-
|
|
900
|
-
xk, xv = [tensor.transpose(1, 2) for tensor in (xk, xv)]
|
|
901
|
-
|
|
902
|
-
# repeat k/v heads if n_kv_heads < n_heads
|
|
903
|
-
xk = xk.repeat_interleave(self.n_rep, dim=1)
|
|
904
|
-
xv = xv.repeat_interleave(self.n_rep, dim=1)
|
|
905
|
-
|
|
906
|
-
xk = self.k_norm(xk)
|
|
907
|
-
|
|
908
|
-
return torch.stack([xk, xv])
|
|
909
|
-
|
|
910
|
-
def compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor:
|
|
911
|
-
return self._compute_xattn_kv_cache(xattn_tokens)
|
|
912
|
-
|
|
913
|
-
def forward(
|
|
914
|
-
self,
|
|
915
|
-
x: torch.Tensor,
|
|
916
|
-
xattn_mask: torch.Tensor,
|
|
917
|
-
full_text_row_masked_out_mask: torch.Tensor,
|
|
918
|
-
xattn_cache: torch.Tensor,
|
|
919
|
-
) -> torch.Tensor:
|
|
920
|
-
xq = F.linear(x, self.wq.weight)
|
|
921
|
-
bsz, seqlen, _ = x.shape
|
|
922
|
-
|
|
923
|
-
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
|
924
|
-
xq = self.q_norm(xq)
|
|
925
|
-
xq = xq.transpose(1, 2)
|
|
926
|
-
|
|
927
|
-
xk, xv = xattn_cache
|
|
928
|
-
|
|
929
|
-
output = F.scaled_dot_product_attention(xq, xk, xv, attn_mask=xattn_mask, dropout_p=0.0)
|
|
930
|
-
output = output * full_text_row_masked_out_mask
|
|
931
|
-
output = output.transpose(1, 2).contiguous().reshape(bsz, seqlen, -1)
|
|
932
|
-
|
|
933
|
-
out = F.linear(output, self.wo.weight)
|
|
934
|
-
out = reduce_from_tensor_model_parallel_region(out)
|
|
935
|
-
return out
|
|
936
|
-
|
|
937
|
-
|
|
938
|
-
class CrossAttentionTransformerBlock(torch.nn.Module):
|
|
939
|
-
"""Cross-attention transformer block with tanh-gated attention and feedforward."""
|
|
940
|
-
|
|
941
|
-
def __init__(
|
|
942
|
-
self,
|
|
943
|
-
args: ModelArgs,
|
|
944
|
-
layer_id: int,
|
|
945
|
-
no_ffn: bool = False,
|
|
946
|
-
) -> None:
|
|
947
|
-
super().__init__()
|
|
948
|
-
self.layer_id = layer_id
|
|
949
|
-
self.n_heads = args.n_heads
|
|
950
|
-
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
|
951
|
-
self.dim = args.dim
|
|
952
|
-
self.head_dim = args.dim // args.n_heads
|
|
953
|
-
self.attention = CrossAttention(
|
|
954
|
-
dim=args.dim,
|
|
955
|
-
head_dim=self.head_dim,
|
|
956
|
-
n_heads=self.n_heads,
|
|
957
|
-
n_kv_heads=self.n_kv_heads,
|
|
958
|
-
norm_eps=args.norm_eps,
|
|
959
|
-
)
|
|
960
|
-
|
|
961
|
-
self.attention_norm = RMSNorm(
|
|
962
|
-
args.dim,
|
|
963
|
-
eps=args.norm_eps,
|
|
964
|
-
)
|
|
965
|
-
self.gate_attn = torch.nn.Parameter(torch.zeros(1))
|
|
966
|
-
|
|
967
|
-
self.feed_forward = FeedForward(
|
|
968
|
-
dim=args.dim,
|
|
969
|
-
hidden_dim=4 * args.dim,
|
|
970
|
-
ffn_dim_multiplier=args.ffn_dim_multiplier,
|
|
971
|
-
multiple_of=args.multiple_of,
|
|
972
|
-
)
|
|
973
|
-
self.ffn_norm = RMSNorm(
|
|
974
|
-
args.dim,
|
|
975
|
-
eps=args.norm_eps,
|
|
976
|
-
)
|
|
977
|
-
self.gate_ffwd = torch.nn.Parameter(torch.zeros(1))
|
|
978
|
-
|
|
979
|
-
self.no_ffn = no_ffn
|
|
980
|
-
|
|
981
|
-
def compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor:
|
|
982
|
-
return self.attention.compute_xattn_kv_cache(xattn_tokens)
|
|
983
|
-
|
|
984
|
-
def forward(
|
|
985
|
-
self,
|
|
986
|
-
x: torch.Tensor,
|
|
987
|
-
xattn_mask: torch.Tensor,
|
|
988
|
-
full_text_row_masked_out_mask: tuple[torch.Tensor, torch.Tensor],
|
|
989
|
-
xattn_cache: torch.Tensor,
|
|
990
|
-
) -> torch.Tensor:
|
|
991
|
-
_attn_out = self.attention(
|
|
992
|
-
x=self.attention_norm(x),
|
|
993
|
-
xattn_mask=xattn_mask,
|
|
994
|
-
xattn_cache=xattn_cache,
|
|
995
|
-
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
|
|
996
|
-
)
|
|
997
|
-
h = x + self.gate_attn.tanh() * _attn_out
|
|
998
|
-
_ffn = self.feed_forward(self.ffn_norm(h))
|
|
999
|
-
_ffn = full_text_row_masked_out_mask[:, 0] * _ffn # type: ignore
|
|
1000
|
-
h = h + self.gate_ffwd.tanh() * _ffn * float(not self.no_ffn)
|
|
1001
|
-
return h
|
|
1002
|
-
|
|
1003
|
-
|
|
1004
|
-
class DummyCrossAttentionTransformerBlock:
|
|
1005
|
-
"""Dummy cross-attention transformer block with tanh-gated attention and feedforward."""
|
|
1006
|
-
|
|
1007
|
-
def __call__(
|
|
1008
|
-
self,
|
|
1009
|
-
x: torch.Tensor,
|
|
1010
|
-
*args,
|
|
1011
|
-
**kwargs,
|
|
1012
|
-
) -> torch.Tensor:
|
|
1013
|
-
return x
|
|
1014
|
-
|
|
1015
|
-
|
|
1016
|
-
class DummySelfAttentionTransformerBlock:
|
|
1017
|
-
"""Dummy self-attention transformer block"""
|
|
1018
|
-
|
|
1019
|
-
def __call__(
|
|
1020
|
-
self,
|
|
1021
|
-
x: torch.Tensor,
|
|
1022
|
-
*args,
|
|
1023
|
-
**kwargs,
|
|
1024
|
-
) -> torch.Tensor:
|
|
1025
|
-
return x
|
|
1026
|
-
|
|
1027
|
-
|
|
1028
|
-
class CrossAttentionTransformerVision(torch.nn.Module):
|
|
1029
|
-
def __init__(self, args: ModelArgs) -> None:
|
|
1030
|
-
super().__init__()
|
|
1031
|
-
return_intermediate = "3,7,15,23,30"
|
|
1032
|
-
self.vision_input_dim = 1280
|
|
1033
|
-
self.image_res = args.vision_chunk_size
|
|
1034
|
-
self.max_num_chunks = args.vision_max_num_chunks
|
|
1035
|
-
if return_intermediate is not None:
|
|
1036
|
-
return_intermediate = [int(layer) for layer in return_intermediate.split(",")]
|
|
1037
|
-
self.vision_input_dim = (len(return_intermediate) + 1) * self.vision_input_dim
|
|
1038
|
-
self.patch_size = 14
|
|
1039
|
-
self.vision_encoder = VisionEncoder(
|
|
1040
|
-
max_num_tiles=4,
|
|
1041
|
-
image_size=args.vision_chunk_size,
|
|
1042
|
-
patch_size=self.patch_size,
|
|
1043
|
-
n_global_layers=8,
|
|
1044
|
-
global_model=True,
|
|
1045
|
-
return_intermediate=return_intermediate,
|
|
1046
|
-
)
|
|
1047
|
-
# vision token projection
|
|
1048
|
-
self.vision_projection = ColumnParallelLinear(
|
|
1049
|
-
self.vision_input_dim,
|
|
1050
|
-
args.dim,
|
|
1051
|
-
bias=True,
|
|
1052
|
-
init_method=lambda x: x,
|
|
1053
|
-
)
|
|
1054
|
-
|
|
1055
|
-
def forward(self, images: torch.Tensor, aspect_ratios: torch.Tensor) -> torch.Tensor:
|
|
1056
|
-
# vision_tokens: (B, T, D)
|
|
1057
|
-
# aspect_ratios: (B, T)
|
|
1058
|
-
# h: (B, T, D)
|
|
1059
|
-
vision_tokens = self.vision_encoder(images.to(dtype=torch.get_default_dtype()), aspect_ratios)
|
|
1060
|
-
|
|
1061
|
-
vision_tokens = F.linear(vision_tokens, self.vision_projection.weight, self.vision_projection.bias)
|
|
1062
|
-
vision_tokens = gather_from_tensor_model_parallel_region(vision_tokens)
|
|
1063
|
-
return vision_tokens
|
|
1064
|
-
|
|
1065
|
-
|
|
1066
|
-
class CrossAttentionTransformerText(torch.nn.Module):
|
|
1067
|
-
INFERENCE_IMAGE_TOKEN_ID = 128010
|
|
1068
|
-
|
|
1069
|
-
def __init__(self, args: ModelArgs) -> None:
|
|
1070
|
-
super().__init__()
|
|
1071
|
-
self.world_size = fs_init.get_model_parallel_world_size()
|
|
1072
|
-
assert args.vocab_size > 0
|
|
1073
|
-
self.vocab_size = args.vocab_size
|
|
1074
|
-
self.n_layers = args.n_layers
|
|
1075
|
-
self.dim = args.dim
|
|
1076
|
-
self.head_dim = args.dim // args.n_heads
|
|
1077
|
-
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
|
1078
|
-
self.n_local_kv_heads = self.n_kv_heads // self.world_size
|
|
1079
|
-
assert self.vocab_size % self.world_size == 0
|
|
1080
|
-
self.tok_embeddings = VocabParallelEmbedding(args.vocab_size, args.dim, init_method=lambda x: x)
|
|
1081
|
-
self.pos_embeddings = None
|
|
1082
|
-
# final norm layer (not necessary for post-norm)
|
|
1083
|
-
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
|
|
1084
|
-
|
|
1085
|
-
# output layer
|
|
1086
|
-
self.output = ColumnParallelLinear(args.dim, args.vocab_size, bias=False, init_method=lambda x: x)
|
|
1087
|
-
|
|
1088
|
-
self.n_llama_layers = args.n_layers
|
|
1089
|
-
self.model_dim = args.dim
|
|
1090
|
-
|
|
1091
|
-
# BLOCKS
|
|
1092
|
-
|
|
1093
|
-
self.fusion_schedule = self._init_fusion_schedule(args.vision_num_cross_attention_layers)
|
|
1094
|
-
self.learnable_embedding = VocabParallelEmbedding(
|
|
1095
|
-
max(fs_init.get_model_parallel_world_size(), 8),
|
|
1096
|
-
args.dim,
|
|
1097
|
-
init_method=lambda x: x,
|
|
1098
|
-
)
|
|
1099
|
-
self.num_frozen_embeddings = self.tok_embeddings.num_embeddings
|
|
1100
|
-
self._thresh = self.num_frozen_embeddings - 1
|
|
1101
|
-
|
|
1102
|
-
# transformer blocks
|
|
1103
|
-
self.layers = torch.nn.ModuleList()
|
|
1104
|
-
self.cross_attention_layers = torch.nn.ModuleList()
|
|
1105
|
-
for i in range(args.n_layers):
|
|
1106
|
-
layer_id = i
|
|
1107
|
-
block = TransformerBlock(args=args, layer_id=layer_id)
|
|
1108
|
-
self.layers.append(block)
|
|
1109
|
-
if layer_id in self.fusion_schedule:
|
|
1110
|
-
xa_layer_id = self.fusion_schedule.index(layer_id) + args.n_layers
|
|
1111
|
-
block = CrossAttentionTransformerBlock(
|
|
1112
|
-
args,
|
|
1113
|
-
layer_id=xa_layer_id,
|
|
1114
|
-
)
|
|
1115
|
-
self.cross_attention_layers.append(block)
|
|
1116
|
-
|
|
1117
|
-
# add xattn and dummy layers to avoid conditionals in forward()
|
|
1118
|
-
self.text_and_xattn_layers = []
|
|
1119
|
-
|
|
1120
|
-
for idx, layer in enumerate(self.layers):
|
|
1121
|
-
if idx in self.fusion_schedule:
|
|
1122
|
-
xattn_layer_idx = self.fusion_schedule.index(idx)
|
|
1123
|
-
xattn_layer = self.cross_attention_layers[xattn_layer_idx]
|
|
1124
|
-
else:
|
|
1125
|
-
xattn_layer_idx = 0
|
|
1126
|
-
xattn_layer = DummyCrossAttentionTransformerBlock()
|
|
1127
|
-
|
|
1128
|
-
self.text_and_xattn_layers.append(
|
|
1129
|
-
(
|
|
1130
|
-
layer,
|
|
1131
|
-
xattn_layer,
|
|
1132
|
-
xattn_layer_idx,
|
|
1133
|
-
)
|
|
1134
|
-
)
|
|
1135
|
-
self.freqs_cis = precompute_freqs_cis(
|
|
1136
|
-
args.dim // args.n_heads,
|
|
1137
|
-
args.max_seq_len * 2,
|
|
1138
|
-
args.rope_theta,
|
|
1139
|
-
args.use_scaled_rope,
|
|
1140
|
-
)
|
|
1141
|
-
|
|
1142
|
-
self.args = args
|
|
1143
|
-
self.cache_is_setup = False
|
|
1144
|
-
self.max_seq_len = args.max_seq_len
|
|
1145
|
-
|
|
1146
|
-
def _init_fusion_schedule(
|
|
1147
|
-
self,
|
|
1148
|
-
num_layers: int,
|
|
1149
|
-
) -> list[int]:
|
|
1150
|
-
llama_layers = list(range(self.n_llama_layers))
|
|
1151
|
-
|
|
1152
|
-
# uniformly spread the layers
|
|
1153
|
-
k = math.ceil(len(llama_layers) / num_layers)
|
|
1154
|
-
return llama_layers[::-1][::k][:num_layers][::-1]
|
|
1155
|
-
|
|
1156
|
-
def get_partially_trainable_embedding(self, x):
|
|
1157
|
-
xz = torch.zeros_like(x, device=x.device)
|
|
1158
|
-
oz = torch.ones_like(x, device=x.device)
|
|
1159
|
-
x_orig = torch.minimum(x, torch.tensor(self._thresh, device=x.device))
|
|
1160
|
-
x_new = torch.maximum(x, torch.tensor(self._thresh + 1, device=x.device)) - self.num_frozen_embeddings
|
|
1161
|
-
|
|
1162
|
-
mask_orig = torch.where(x >= self.num_frozen_embeddings, xz, oz).unsqueeze(-1)
|
|
1163
|
-
mask_new = torch.where(x < self.num_frozen_embeddings, xz, oz).unsqueeze(-1)
|
|
1164
|
-
|
|
1165
|
-
x_orig = self.tok_embeddings(x_orig)
|
|
1166
|
-
x_new = self.learnable_embedding(x_new).type_as(x_orig)
|
|
1167
|
-
return x_orig * mask_orig.type_as(x_orig) + x_new * mask_new.type_as(x_new)
|
|
1168 -
1169 -    def forward(
1170 -        self,
1171 -        position_ids: torch.LongTensor,
1172 -        h: torch.Tensor,
1173 -        xattn_mask: torch.Tensor,
1174 -        full_text_row_masked_out_mask: torch.Tensor,
1175 -        xattn_caches: torch.Tensor,
1176 -        text_only_inference: bool = False,
1177 -    ):
1178 -        assert self.cache_is_setup, "Please set up cache before calling forward"
1179 -        self.mask_cache = self.mask_cache.to(h.device)
1180 -        self.freqs_cis = self.freqs_cis.to(h.device)
1181 -        mask = self.mask_cache.index_select(2, position_ids)
1182 -        freqs_cis = self.freqs_cis.index_select(0, position_ids)
1183 -
1184 -        for (
1185 -            layer,
1186 -            xattn_layer,
1187 -            xattn_layer_idx,
1188 -        ) in self.text_and_xattn_layers:
1189 -            if not text_only_inference:
1190 -                h = xattn_layer(
1191 -                    x=h,
1192 -                    xattn_mask=xattn_mask,
1193 -                    xattn_cache=xattn_caches[xattn_layer_idx],
1194 -                    full_text_row_masked_out_mask=full_text_row_masked_out_mask,
1195 -                )
1196 -            h = layer(
1197 -                x=h,
1198 -                mask=mask,
1199 -                freqs_cis=freqs_cis,
1200 -                position_ids=position_ids,
1201 -            )
1202 -
1203 -        h = self.norm(h)
1204 -
1205 -        output = F.linear(h, self.output.weight)
1206 -        output = gather_from_tensor_model_parallel_region(output)
1207 -        return output.float()
1208 -
1209 -    def setup_cache(self, max_batch_size: int, device: torch.device, dtype=torch.bfloat16):
1210 -        # Set up the text kv caches
1211 -        ones = torch.ones(
1212 -            (self.max_seq_len, self.max_seq_len),
1213 -            dtype=torch.bool,
1214 -            device=device,
1215 -        )
1216 -        self.register_buffer(
1217 -            "mask_cache",
1218 -            torch.tril(
1219 -                ones,
1220 -            )
1221 -            .unsqueeze(0)
1222 -            .unsqueeze(0),
1223 -            persistent=False,
1224 -        )
1225 -        for layer in self.layers:
1226 -            layer.setup_cache(max_batch_size, dtype=dtype)
1227 -        self.cache_is_setup = True
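
Note on the removed setup_cache: mask_cache is registered as a (1, 1, max_seq_len, max_seq_len) lower-triangular boolean buffer, and forward slices out only the rows for the requested positions via index_select. A small self-contained sketch with an illustrative max_seq_len:

import torch

max_seq_len = 6
mask_cache = torch.tril(torch.ones(max_seq_len, max_seq_len, dtype=torch.bool)).unsqueeze(0).unsqueeze(0)

# during a decode step, only the rows for the current positions are selected
position_ids = torch.tensor([4])
mask = mask_cache.index_select(2, position_ids)
print(mask.shape)              # torch.Size([1, 1, 1, 6])
print(mask[0, 0, 0].tolist())  # [True, True, True, True, True, False]
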
1228 -
1229 -    def _get_xattn_mask(
1230 -        self,
1231 -        num_tokens,
1232 -        text_device,
1233 -        text_dtype,
1234 -        vision_tokens,
1235 -        cross_attention_masks,
1236 -    ) -> tuple[Tensor, Tensor]:
1237 -        assert vision_tokens is not None, "Vision tokens must be provided"
1238 -        vision_seqlen = vision_tokens.shape[3]
1239 -        assert vision_tokens.shape[1] == cross_attention_masks.shape[2], (
1240 -            f"Mismatch in number of images given and number of masks given {vision_tokens.shape} {cross_attention_masks.shape}"
1241 -        )
1242 -        assert vision_tokens.shape[2] == cross_attention_masks.shape[3], (
1243 -            f"Vision tokens shape {vision_tokens.shape} mismatch with xattn shape {cross_attention_masks.shape}"
1244 -        )
1245 -        assert num_tokens == cross_attention_masks.shape[1], (
1246 -            f"Mismatch in text sequence length and cross attention mask sequence length {num_tokens} {cross_attention_masks.shape}"
1247 -        )
1248 -        _, _, _, num_image_tokens, image_token_dim = tuple(vision_tokens.shape)
1249 -        bsz, ntext, nimg, nchunks = cross_attention_masks.shape
1250 -        cross_attention_masks = (
1251 -            cross_attention_masks.repeat_interleave(vision_seqlen, dim=3).view(bsz, ntext, -1).unsqueeze(1)
1252 -        )
1253 -        full_text_row_masked_out_mask = _get_full_row_masked_out_mask(
1254 -            cross_attention_masks,
1255 -            get_negative_inf_value(cross_attention_masks.dtype),
1256 -        )
1257 -        cross_attention_masks *= full_text_row_masked_out_mask
1258 -
1259 -        return (
1260 -            cross_attention_masks.to(device=text_device, dtype=text_dtype),
1261 -            full_text_row_masked_out_mask.to(device=text_device),
1262 -        )
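
Note on the removed _get_xattn_mask: the per-chunk mask of shape (bsz, ntext, nimg, nchunks) is expanded over the vision sequence length and flattened to (bsz, 1, ntext, nimg * nchunks * vision_seqlen); text rows that are -inf everywhere are then zeroed out via full_text_row_masked_out_mask. A shape-only sketch with small illustrative sizes; the any()-over-the-last-dimension test below stands in for the _get_full_row_masked_out_mask helper, which is not shown in this hunk:

import torch

neg_inf = torch.finfo(torch.float32).min
bsz, ntext, nimg, nchunks, vision_seqlen = 1, 6, 1, 4, 5   # illustrative sizes

masks = torch.full((bsz, ntext, nimg, nchunks), neg_inf)
masks[:, 2:, :, :1] = 0.0   # text positions 2.. may attend to the first chunk

xattn_mask = masks.repeat_interleave(vision_seqlen, dim=3).view(bsz, ntext, -1).unsqueeze(1)
# a text row that is -inf everywhere attends to nothing and gets a 0 in the row mask
full_row = (xattn_mask != neg_inf).any(dim=-1, keepdim=True).type_as(xattn_mask)

print(xattn_mask.shape)      # torch.Size([1, 1, 6, 20])
print(full_row[0, 0, :, 0])  # tensor([0., 0., 1., 1., 1., 1.])
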
1263 -
1264 -
1265 - class CrossAttentionTransformer(torch.nn.Module):
1266 -    def __init__(self, args: ModelArgs) -> None:
1267 -        super().__init__()
1268 -        self.params = args
1269 -
1270 -        self.model_dim = args.dim
1271 -        self.vision_model = CrossAttentionTransformerVision(args)
1272 -        self.text_model = CrossAttentionTransformerText(args)
1273 -        self.image_res = args.vision_chunk_size
1274 -        self.max_num_chunks = args.vision_max_num_chunks
1275 -        self.image_transform = partial(
1276 -            VariableSizeImageTransform(size=args.vision_chunk_size),
1277 -            max_num_chunks=args.vision_max_num_chunks,
1278 -        )
1279 -
1280 -    def setup_cache(self, max_batch_size: int, device: torch.device, dtype: torch.dtype):
1281 -        self.text_model.setup_cache(max_batch_size, device, dtype)
1282 -
1283 -    def compute_vision_tokens_masks(
1284 -        self,
1285 -        batch_images: list[list[PIL_Image.Image]],
1286 -        batch_masks: list[list[list[int]]],
1287 -        total_len: int,
1288 -        device: torch.device,
1289 -    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
1290 -        skip_vision_encoder = False
1291 -
1292 -        assert len(batch_images) == len(batch_masks), "Images and masks must have the same length"
1293 -
1294 -        max_num_images = max(len(x) for x in batch_images)
1295 -        bsz = len(batch_images)
1296 -
1297 -        if max_num_images == 0:
1298 -            num_chunks = [[self.max_num_chunks] for _ in batch_images]
1299 -            skip_vision_encoder = True
1300 -        else:
1301 -            images_and_aspect_ratios = [[self.image_transform(im) for im in row] for row in batch_images]
1302 -            transformed_images = [[x[0] for x in row] for row in images_and_aspect_ratios]
1303 -
1304 -            aspect_ratios = torch.ones(bsz, max_num_images, 2, dtype=torch.int64)
1305 -            for i, row in enumerate(images_and_aspect_ratios):
1306 -                if len(row) > 0:
1307 -                    aspect_ratios[i, : len(row)] = torch.stack([torch.tensor(x[1]) for x in row])
1308 -
1309 -            stacked_images, num_chunks = _stack_images(
1310 -                transformed_images,
1311 -                max_num_chunks=self.max_num_chunks,
1312 -                image_res=self.params.vision_chunk_size,
1313 -                max_num_images=max_num_images,
1314 -            )
1315 -            stacked_images = stacked_images.to(device=device)
1316 -
1317 -        if skip_vision_encoder:
1318 -            vision_tokens = torch.zeros(
1319 -                (
1320 -                    bsz,
1321 -                    max_num_images,
1322 -                    self.max_num_chunks,
1323 -                    int((self.vision_model.image_res / self.vision_model.patch_size) ** 2 + 1),
1324 -                    self.model_dim,
1325 -                ),
1326 -            )
1327 -        else:
1328 -            vision_tokens = self.vision_model(stacked_images, aspect_ratios).to(device=device)
1329 -
1330 -        bsz, nimg, nchunk, ntok, image_token_dim = tuple(vision_tokens.shape)
1331 -        xattn_caches = torch.stack(
1332 -            [
1333 -                layer.compute_xattn_kv_cache(vision_tokens.view(bsz, -1, image_token_dim))
1334 -                for layer in self.text_model.cross_attention_layers
1335 -            ]
1336 -        )
1337 -        padded_masks = _pad_masks(
1338 -            batch_masks,
1339 -            num_chunks,
1340 -            total_len,
1341 -            self.max_num_chunks,
1342 -        )
1343 -
1344 -        cross_attention_masks, full_text_row_masked_out_mask = self.text_model._get_xattn_mask(
1345 -            num_tokens=total_len,
1346 -            text_device=vision_tokens.device.type,
1347 -            text_dtype=next(self.text_model.parameters()).dtype,
1348 -            vision_tokens=vision_tokens,
1349 -            cross_attention_masks=padded_masks,
1350 -        )
1351 -
1352 -        return (xattn_caches, cross_attention_masks, full_text_row_masked_out_mask)
1353 -
1354 -    def forward(
1355 -        self,
1356 -        position_ids: torch.Tensor,
1357 -        tokens: torch.Tensor,
1358 -        cross_attention_masks: torch.Tensor,
1359 -        full_text_row_masked_out_mask: torch.Tensor,
1360 -        xattn_caches: torch.Tensor,
1361 -        text_only_inference: bool = False,
1362 -    ) -> torch.Tensor:
1363 -        h = self.text_model.get_partially_trainable_embedding(tokens[:, position_ids])
1364 -        logits = self.text_model.forward(
1365 -            position_ids=position_ids,
1366 -            h=h,
1367 -            xattn_mask=cross_attention_masks[:, :, position_ids],
1368 -            full_text_row_masked_out_mask=full_text_row_masked_out_mask[:, :, position_ids],
1369 -            xattn_caches=xattn_caches,
1370 -            text_only_inference=text_only_inference,
1371 -        )
1372 -        return logits
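
Note on the removed CrossAttentionTransformer: callers were expected to run compute_vision_tokens_masks once per prompt and then reuse its outputs across forward calls. A hedged usage sketch; `model`, `images`, and `tokens` are assumed to come from elsewhere (checkpoint loading and tokenization are omitted), and the batch_masks value is illustrative:

# Sketch only: assumes `model` is a CrossAttentionTransformer built from a loaded
# ModelArgs/checkpoint, `images` is list[list[PIL.Image]], and `tokens` is a
# (bsz, total_len) LongTensor of prompt token ids.
import torch

device = torch.device("cuda")
model.setup_cache(max_batch_size=1, device=device, dtype=torch.bfloat16)

xattn_caches, xattn_masks, full_row_mask = model.compute_vision_tokens_masks(
    batch_images=images,
    batch_masks=[[[0, -1]]],   # each image attends from position 0 to the end of the text
    total_len=tokens.shape[1],
    device=device,
)

# prefill over the whole prompt; decode steps would pass a single new position instead
position_ids = torch.arange(tokens.shape[1], device=device)
logits = model.forward(
    position_ids=position_ids,
    tokens=tokens,
    cross_attention_masks=xattn_masks,
    full_text_row_masked_out_mask=full_row_mask,
    xattn_caches=xattn_caches,
)
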
1373 -
1374 -
1375 - def _stack_images(
1376 -    images: list[list[PIL_Image.Image]],
1377 -    max_num_chunks: int,
1378 -    image_res: int,
1379 -    max_num_images: int,
1380 - ) -> tuple[torch.Tensor, list[int]]:
1381 -    """
1382 -    Takes a list of list of images and stacks them into a tensor.
1383 -    This function is needed since images can be of completely
1384 -    different resolutions and aspect ratios.
1385 -    """
1386 -    out_images, out_num_chunks = [], []
1387 -    for imgs_sample in images:
1388 -        out_images_i = torch.zeros(
1389 -            max_num_images,
1390 -            max_num_chunks,
1391 -            3,
1392 -            image_res,
1393 -            image_res,
1394 -        )
1395 -        _num_chunks = []
1396 -        for j, chunks_image in enumerate(imgs_sample):
1397 -            out_images_i[j, : chunks_image.shape[0]] = chunks_image
1398 -            _num_chunks.append(chunks_image.shape[0])
1399 -        out_images.append(out_images_i)
1400 -        out_num_chunks.append(_num_chunks)
1401 -    return torch.stack(out_images), out_num_chunks
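
Note on the removed _stack_images: each sample is padded to a fixed (max_num_images, max_num_chunks, 3, image_res, image_res) grid and the per-image chunk counts are recorded. A self-contained sketch of the same padding pattern with illustrative sizes:

import torch

# Pad a ragged batch of chunked images to a fixed grid, mirroring _stack_images above.
batch = [[torch.rand(2, 3, 560, 560)], []]   # sample 0: one 2-chunk image; sample 1: no images
max_num_images, max_num_chunks, image_res = 1, 4, 560

out, counts = [], []
for sample in batch:
    grid = torch.zeros(max_num_images, max_num_chunks, 3, image_res, image_res)
    counts.append([img.shape[0] for img in sample])
    for j, img in enumerate(sample):
        grid[j, : img.shape[0]] = img
    out.append(grid)

stacked = torch.stack(out)
print(stacked.shape)  # torch.Size([2, 1, 4, 3, 560, 560])
print(counts)         # [[2], []]
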
1402 -
1403 -
1404 - def _pad_masks(
1405 -    all_masks: list[list[list[int]]],
1406 -    all_num_chunks: list[list[int]],
1407 -    total_len: int,
1408 -    max_num_chunks: int,
1409 - ) -> torch.Tensor:
1410 -    dtype = torch.get_default_dtype()
1411 -    inf_value = get_negative_inf_value(dtype)
1412 -
1413 -    bsz = len(all_masks)
1414 -    max_num_media = max([len(m) for m in all_masks])
1415 -
1416 -    out_masks = torch.full(
1417 -        (bsz, total_len, max_num_media, max_num_chunks),
1418 -        inf_value,
1419 -        dtype=dtype,
1420 -    )
1421 -
1422 -    for idx, (mask, num_chunks) in enumerate(zip(all_masks, all_num_chunks, strict=False)):
1423 -        for mask_idx, (mask_elem, mask_num_chunks) in enumerate(zip(mask, num_chunks, strict=False)):
1424 -            if len(mask_elem) == 2:
1425 -                mask_elem[1] = min(mask_elem[1], total_len)
1426 -                if mask_elem[1] == -1:
1427 -                    mask_elem[1] = total_len
1428 -                out_masks[idx, mask_elem[0] : mask_elem[1], mask_idx, :mask_num_chunks].fill_(0.0)
1429 -
1430 -    return out_masks
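
Note on the removed _pad_masks: each per-image [start, end] span (with -1 meaning "until the end of the text") becomes a block of zeros in an otherwise -inf additive mask of shape (bsz, total_len, max_num_media, max_num_chunks). A self-contained sketch of the same span-to-mask expansion, using torch.finfo(dtype).min in place of the removed get_negative_inf_value helper:

import torch

total_len, max_num_chunks = 6, 2
all_masks = [[[2, -1]]]      # one sample, one image visible from text position 2 onward
all_num_chunks = [[2]]       # that image produced 2 chunks

dtype = torch.get_default_dtype()
neg_inf = torch.finfo(dtype).min  # stand-in for get_negative_inf_value(dtype)

out = torch.full((1, total_len, 1, max_num_chunks), neg_inf, dtype=dtype)
for idx, (masks, chunks) in enumerate(zip(all_masks, all_num_chunks)):
    for img_idx, (span, n) in enumerate(zip(masks, chunks)):
        start, end = span
        end = total_len if end == -1 else min(end, total_len)
        out[idx, start:end, img_idx, :n] = 0.0

print((out[0, :, 0, 0] == 0).tolist())  # [False, False, True, True, True, True]
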