sglang 0.3.4__py3-none-any.whl → 0.3.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +2 -1
- sglang/lang/chat_template.py +17 -0
- sglang/launch_server_llavavid.py +1 -1
- sglang/srt/configs/__init__.py +3 -0
- sglang/srt/configs/model_config.py +27 -2
- sglang/srt/configs/qwen2vl.py +133 -0
- sglang/srt/constrained/fsm_cache.py +10 -3
- sglang/srt/conversation.py +27 -0
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/__init__.py +16 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -6
- sglang/srt/layers/attention/flashinfer_backend.py +174 -54
- sglang/srt/layers/attention/triton_backend.py +22 -6
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +26 -4
- sglang/srt/layers/linear.py +89 -63
- sglang/srt/layers/logits_processor.py +5 -5
- sglang/srt/layers/rotary_embedding.py +112 -0
- sglang/srt/layers/sampler.py +51 -39
- sglang/srt/lora/lora.py +3 -1
- sglang/srt/managers/data_parallel_controller.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +4 -0
- sglang/srt/managers/image_processor.py +186 -13
- sglang/srt/managers/io_struct.py +10 -0
- sglang/srt/managers/schedule_batch.py +238 -68
- sglang/srt/managers/scheduler.py +69 -50
- sglang/srt/managers/tokenizer_manager.py +24 -4
- sglang/srt/managers/tp_worker.py +26 -111
- sglang/srt/managers/tp_worker_overlap_thread.py +209 -0
- sglang/srt/mem_cache/memory_pool.py +56 -10
- sglang/srt/mem_cache/radix_cache.py +4 -3
- sglang/srt/model_executor/cuda_graph_runner.py +87 -28
- sglang/srt/model_executor/forward_batch_info.py +83 -3
- sglang/srt/model_executor/model_runner.py +32 -11
- sglang/srt/models/chatglm.py +3 -3
- sglang/srt/models/deepseek_v2.py +2 -2
- sglang/srt/models/mllama.py +1004 -0
- sglang/srt/models/qwen2_vl.py +724 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
- sglang/srt/sampling/sampling_batch_info.py +13 -3
- sglang/srt/sampling/sampling_params.py +5 -7
- sglang/srt/server.py +12 -0
- sglang/srt/server_args.py +10 -0
- sglang/srt/utils.py +22 -0
- sglang/test/run_eval.py +2 -0
- sglang/test/runners.py +20 -1
- sglang/test/srt/sampling/penaltylib/utils.py +1 -0
- sglang/test/test_utils.py +100 -3
- sglang/version.py +1 -1
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/METADATA +17 -18
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/RECORD +53 -48
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/LICENSE +0 -0
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1004 @@
|
|
1
|
+
# Adapted from:
|
2
|
+
# https://github.com/vllm-project/vllm/blob/7193774b1ff8603ad5bf4598e5efba0d9a39b436/vllm/model_executor/models/mllama.py
|
3
|
+
"""PyTorch Mllama model."""
|
4
|
+
import math
|
5
|
+
from typing import Iterable, List, Optional, Tuple, Union
|
6
|
+
|
7
|
+
import torch
|
8
|
+
import torch.nn.functional as F
|
9
|
+
import torch.utils.checkpoint
|
10
|
+
import transformers.models.mllama.configuration_mllama as config_mllama
|
11
|
+
import vllm.distributed.parallel_state as ps
|
12
|
+
from torch import nn
|
13
|
+
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast
|
14
|
+
from transformers.models.mllama.modeling_mllama import (
|
15
|
+
_prepare_aspect_ratio_attention_mask,
|
16
|
+
)
|
17
|
+
from vllm.distributed import get_tensor_model_parallel_world_size
|
18
|
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
19
|
+
DEFAULT_VOCAB_PADDING_SIZE,
|
20
|
+
ParallelLMHead,
|
21
|
+
VocabParallelEmbedding,
|
22
|
+
)
|
23
|
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
24
|
+
|
25
|
+
from sglang.srt.layers.activation import get_act_fn
|
26
|
+
from sglang.srt.layers.layernorm import RMSNorm
|
27
|
+
from sglang.srt.layers.linear import (
|
28
|
+
ColumnParallelLinear,
|
29
|
+
QKVParallelLinear,
|
30
|
+
RowParallelLinear,
|
31
|
+
)
|
32
|
+
from sglang.srt.layers.logits_processor import LogitsProcessor
|
33
|
+
from sglang.srt.layers.quantization import QuantizationConfig
|
34
|
+
from sglang.srt.layers.radix_attention import RadixAttention
|
35
|
+
from sglang.srt.managers.schedule_batch import ImageInputs
|
36
|
+
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
37
|
+
from sglang.srt.models.llama import LlamaDecoderLayer, LlamaMLP
|
38
|
+
|
39
|
+
|
40
|
+
class ColumnParallelConv2dPatch(torch.nn.Module):
    """Patchify images with an unfold + column-parallel linear projection.

    This is equivalent to a Conv2d: the image is unfolded into flattened
    patches, and each patch is projected by a ``ColumnParallelLinear`` so the
    output channels can be split across tensor-parallel ranks.

    Arguments:
        in_channels: Number of input channels.
        out_channels: Number of output channels (before any tensor-parallel
            split).
        kernel_size: Patch/kernel size; an int means a square kernel.
        stride: Stride forwarded to ``torch.nn.Unfold`` (int or pair).
        bias: Whether the projection has a bias term (default False).

    Input: (bsz, in_channels, height, width)
    Output: (bsz, num_tokens, channels). NOTE(review): ``channels`` is
        presumably this rank's shard of ``out_channels``, because
        ``MllamaVisionModel`` all-gathers the result. Confirm against
        ColumnParallelLinear's ``gather_output`` default.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]],
        bias: bool = False,
    ) -> None:
        super().__init__()
        # Normalise a square kernel to (kh, kw) so the flattened patch length
        # can be derived from it.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        patch_dim = in_channels * kernel_size[0] * kernel_size[1]

        self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride)
        self._linear = ColumnParallelLinear(patch_dim, out_channels, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Unfold yields (bsz, patch_dim, num_tokens); move tokens to dim 1.
        patches = self._unfold(x).permute(0, 2, 1)
        # The linear layer returns a pair; only the first element is kept.
        projected, _ = self._linear(patches)
        return projected
|
76
|
+
|
77
|
+
|
78
|
+
class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
    """Learned per-aspect-ratio embedding added to every tile.

    Each aspect-ratio id maps to one row of ``max_num_tiles * hidden_size``
    values. The row is split into one ``hidden_size`` vector per tile and
    broadcast over the patch axis. With ``is_gated`` the embedding is scaled
    by ``tanh(gate)`` before being added.
    """

    def __init__(self, config: config_mllama.MllamaVisionConfig, is_gated: bool = True):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.max_aspect_ratio_id = config.max_aspect_ratio_id
        self.is_gated = is_gated

        # Ids 0..max_aspect_ratio_id inclusive, hence the +1.
        num_ids = self.max_aspect_ratio_id + 1
        row_width = self.max_num_tiles * self.hidden_size
        self.embedding = nn.Embedding(num_ids, row_width)
        if is_gated:
            # Zero-initialised gate (tanh(0) == 0) until weights are loaded.
            self.gate = nn.Parameter(torch.zeros(1))

    def forward(
        self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
    ) -> torch.Tensor:
        # -> (-1, max_num_tiles, 1, hidden_size); the singleton axis
        # broadcasts across patches.
        tile_embeds = self.embedding(aspect_ratio_ids).reshape(
            -1, self.max_num_tiles, 1, self.hidden_size
        )
        if self.is_gated:
            tile_embeds = tile_embeds * self.gate.tanh()
        return hidden_state + tile_embeds
|
104
|
+
|
105
|
+
|
106
|
+
class MllamaPrecomputedPositionEmbedding(nn.Module):
    """Gated mix of a shared patch-position embedding and a per-aspect-ratio
    tile-position embedding.

    The shared embedding is weighted by ``1 - tanh(gate)`` and the
    aspect-ratio-specific tile embedding by ``tanh(gate)``; both are added to
    the input.
    """

    def __init__(self, config: config_mllama.MllamaVisionConfig):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.max_aspect_ratio_id = config.max_aspect_ratio_id
        # Patches per tile plus one slot for the class token.
        self.num_patches = (config.image_size // config.patch_size) ** 2 + 1
        self.hidden_size = config.hidden_size
        self.scale = config.hidden_size**-0.5

        self.gate = nn.Parameter(torch.zeros(1))

        # Shared position embedding: one vector per patch position.
        self.embedding = nn.Parameter(
            self.scale * torch.randn(self.num_patches, self.hidden_size)
        )

        # Tile position embedding: one flattened
        # (max_num_tiles, num_patches, hidden_size) table per aspect-ratio id.
        self.tile_embedding = nn.Embedding(
            self.max_aspect_ratio_id + 1,
            self.max_num_tiles * self.num_patches * self.hidden_size,
        )

    def forward(
        self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
    ) -> torch.Tensor:
        gate = self.gate.tanh()

        # Shared embedding, broadcast over the leading (batch, tile) axes.
        shared = ((1 - gate) * self.embedding).view(
            1, 1, self.num_patches, self.hidden_size
        )
        hidden_state = hidden_state + shared

        # Aspect-ratio-specific tile embedding.
        tile = self.tile_embedding(aspect_ratio_ids).reshape(
            hidden_state.shape[0],
            self.max_num_tiles,
            self.num_patches,
            self.hidden_size,
        )
        return hidden_state + gate * tile
|
146
|
+
|
147
|
+
|
148
|
+
class MllamaVisionSdpaAttention(nn.Module):
    """Non-causal multi-head self-attention for the vision encoder.

    Heads are split across tensor-parallel ranks. q, k and v use the same
    number of heads, and attention is computed with
    ``F.scaled_dot_product_attention`` and the supplied mask.
    """

    def __init__(self, config: config_mllama.MllamaVisionConfig):
        super().__init__()

        tp_size = get_tensor_model_parallel_world_size()
        self.embed_dim = config.hidden_size
        self.num_heads = config.attention_heads
        self.head_dim = config.hidden_size // config.attention_heads
        self.num_local_heads = self.num_heads // tp_size
        # No grouped-query attention: k/v slices are as wide as q.
        self.q_size = self.kv_size = self.num_local_heads * self.head_dim

        self.qkv_proj = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            self.num_heads,
            bias=False,
        )
        self.o_proj = RowParallelLinear(
            self.num_heads * self.head_dim,
            self.embed_dim,
            bias=False,
            input_is_parallel=True,
        )

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_state)

        def to_heads(t: torch.Tensor) -> torch.Tensor:
            # (B, L, local_heads * head_dim) -> (B, local_heads, L, head_dim)
            return t.view(
                t.shape[0], t.shape[1], self.num_local_heads, self.head_dim
            ).transpose(1, 2)

        q, k, v = map(
            to_heads, qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        )

        # TODO: remove padding in image encoder
        attn = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attention_mask, dropout_p=0.0
        )

        # (B, local_heads, L, head_dim) -> (B, L, local_heads * head_dim)
        attn = attn.transpose(1, 2).contiguous()
        attn = attn.reshape(attn.shape[0], attn.shape[1], -1)
        output, _ = self.o_proj(attn)
        return output
|
201
|
+
|
202
|
+
|
203
|
+
class MllamaVisionMLP(nn.Module):
    """Vision feed-forward block: fc1 -> activation -> fc2.

    fc1 is column-parallel and fc2 is row-parallel; both carry a bias. The
    activation is chosen from ``config.hidden_act``.
    """

    def __init__(self, config, quant_config: Optional[QuantizationConfig] = None):
        super().__init__()
        self.config = config
        self.activation_fn = get_act_fn(config.hidden_act)
        # Both projections share the same bias / quantisation settings.
        shared = dict(bias=True, quant_config=quant_config)
        self.fc1 = ColumnParallelLinear(
            config.hidden_size, config.intermediate_size, **shared
        )
        self.fc2 = RowParallelLinear(
            config.intermediate_size, config.hidden_size, **shared
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        up, _ = self.fc1(hidden_states)
        down, _ = self.fc2(self.activation_fn(up))
        return down
|
227
|
+
|
228
|
+
|
229
|
+
class MllamaVisionEncoderLayer(nn.Module):
    """Pre-norm vision encoder layer with optional tanh-gated residuals.

    LayerNorm -> self-attention -> residual add, then
    LayerNorm -> MLP -> residual add. With ``is_gated`` each branch is scaled
    by ``tanh`` of its own learned gate; otherwise the branches are added
    unscaled.
    """

    def __init__(
        self, config: config_mllama.MllamaVisionConfig, is_gated: bool = False
    ):
        super().__init__()

        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.attention_heads
        self.is_gated = is_gated
        self.intermediate_size = config.intermediate_size

        self.self_attn = MllamaVisionSdpaAttention(config)
        self.mlp = MllamaVisionMLP(config)

        self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
        self.post_attention_layernorm = nn.LayerNorm(
            self.hidden_size, eps=config.norm_eps
        )

        if is_gated:
            # Both gates are initialised to pi/4 before weights are loaded.
            self.gate_attn = nn.Parameter(torch.ones(1) * math.pi / 4)
            self.gate_ffn = nn.Parameter(torch.ones(1) * math.pi / 4)

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        if self.is_gated:
            attn_scale = self.gate_attn.tanh()
            ffn_scale = self.gate_ffn.tanh()
        else:
            attn_scale = ffn_scale = 1

        # Self-attention sub-block.
        attn_out = self.self_attn(
            self.input_layernorm(hidden_state), attention_mask=attention_mask
        )
        hidden_state = hidden_state + attn_scale * attn_out

        # Feed-forward sub-block.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_state))
        return hidden_state + ffn_scale * mlp_out
|
273
|
+
|
274
|
+
|
275
|
+
class MllamaVisionEncoder(nn.Module):
    """Stack of ``MllamaVisionEncoderLayer`` that can record intermediate
    hidden states.

    ``output_hidden_states`` is a collection of layer indices. For each index
    ``i`` in it, the hidden state *entering* layer ``i`` is recorded. If
    ``num_layers - 1`` is listed, the final output is also appended after the
    loop. ``forward`` returns ``(final_hidden_states, recorded_states_tuple)``.
    """

    def __init__(
        self,
        config: config_mllama.MllamaVisionConfig,
        num_layers=32,
        is_gated=False,
        output_hidden_states=None,
    ):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList(
            MllamaVisionEncoderLayer(config, is_gated) for _ in range(num_layers)
        )
        self.output_hidden_states = output_hidden_states or []

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        recorded = []
        for idx, layer in enumerate(self.layers):
            if idx in self.output_hidden_states:
                recorded.append(hidden_states)
            hidden_states = layer(hidden_states, attention_mask)

        # NOTE(review): if index num_layers-1 is requested, both the input to
        # the last layer (recorded in the loop) and the final output (here)
        # are recorded.
        if len(self.layers) - 1 in self.output_hidden_states:
            recorded.append(hidden_states)

        return hidden_states, tuple(recorded)
|
309
|
+
|
310
|
+
|
311
|
+
class MllamaVisionModel(nn.Module):
    """Mllama vision tower.

    Patchifies tiled images, adds aspect-ratio and position embeddings, runs a
    local (ungated) encoder followed by a gated global encoder, and
    concatenates selected intermediate-layer outputs onto the final features.
    """

    def __init__(self, config: config_mllama.MllamaVisionConfig):
        super().__init__()
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.in_channels = config.num_channels
        self.intermediate_layers_indices = config.intermediate_layers_indices

        # Patches per tile, plus one for the class token prepended in forward().
        self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
        self.scale = config.hidden_size**-0.5

        # Patch projection with stride == kernel size, i.e. non-overlapping patches.
        self.patch_embedding = ColumnParallelConv2dPatch(
            in_channels=config.num_channels,
            out_channels=self.hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=False,
        )

        self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size))
        self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(config)

        # Aspect-ratio embeddings applied before and after the local encoder.
        self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
            config, is_gated=True
        )
        self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
            config, is_gated=True
        )

        # layer norms
        self.layernorm_pre = nn.LayerNorm(self.hidden_size)
        self.layernorm_post = nn.LayerNorm(self.hidden_size)

        # encoders
        # Local (ungated) encoder; records the hidden states entering the layers
        # listed in config.intermediate_layers_indices.
        self.transformer = MllamaVisionEncoder(
            config,
            config.num_hidden_layers,
            is_gated=False,
            output_hidden_states=config.intermediate_layers_indices,
        )
        # Global (gated) encoder; collects no intermediate states.
        self.global_transformer = MllamaVisionEncoder(
            config, config.num_global_layers, is_gated=True
        )

    def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Prepend the learned class token along dim 1.

        (N, L, hidden) -> (N, L + 1, hidden)
        """
        batch_size, _, hidden_size = hidden_state.shape
        class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
        hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
        return hidden_state

    def forward(
        self,
        pixel_values: torch.Tensor,
        aspect_ratio_ids: torch.Tensor,
        aspect_ratio_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Encode tiled images.

        Args:
            pixel_values: 6-D tensor
                (batch, num_concurrent_media, num_tiles, channels, height, width).
            aspect_ratio_ids: reshaped to (batch * num_concurrent_media, -1).
                NOTE(review): presumably one id per image, so the -1 is 1.
                Confirm against the image processor.
            aspect_ratio_mask: reshaped to (batch * num_concurrent_media, -1)
                and expanded into the encoder attention mask.

        Returns:
            (batch, num_concurrent_media, num_tiles, num_patches, dim * (1 + n)).
            Here ``num_patches`` includes the class token and ``n`` is the
            number of intermediate states recorded by ``self.transformer``.
        """
        batch_size, num_concurrent_media, num_tiles, num_channels, height, width = (
            pixel_values.shape
        )

        # Fold media and tiles into the batch axis for patchification.
        pixel_values = pixel_values.reshape(
            batch_size * num_concurrent_media * num_tiles, num_channels, height, width
        )
        aspect_ratio_ids = aspect_ratio_ids.reshape(
            batch_size * num_concurrent_media, -1
        )

        # patch embedding
        patch_embeds = self.patch_embedding(
            pixel_values.to(self.layernorm_pre.weight.dtype)
        )
        hidden_state = patch_embeds
        # NOTE(review): presumably reassembles the column-parallel output slices
        # into the full hidden size. Confirm all_gather's default dim.
        hidden_state = ps.get_tp_group().all_gather(hidden_state)

        # tile embeddings
        _, num_patches, dim = hidden_state.shape
        # -> (batch * media, num_tiles, num_patches, dim)
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media, num_tiles, -1, dim
        )
        hidden_state = self.pre_tile_positional_embedding(
            hidden_state, aspect_ratio_ids
        )

        # apply cls token
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media * num_tiles, num_patches, dim
        )
        hidden_state = self.apply_class_embedding(hidden_state)
        num_patches += 1

        # apply position embeddings
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media, num_tiles, num_patches, dim
        )
        hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids)

        # apply encoder
        hidden_state = self.layernorm_pre(hidden_state)

        # Compute the number of tokens to pad
        # (rounds the per-tile patch count up to a multiple of 8).
        num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
        # Compute padding tuple for pad function
        padding = (
            0,
            0,
            0,
            num_padding_patches,
        )  # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
        # Pad the tensor
        hidden_state = F.pad(hidden_state, padding, mode="constant", value=0)
        # End index used later to strip the padding; None keeps every patch.
        slice_index = -num_padding_patches if num_padding_patches > 0 else None

        attention_mask = aspect_ratio_mask.reshape(
            batch_size * num_concurrent_media, -1
        )
        attention_mask = _prepare_aspect_ratio_attention_mask(
            aspect_ratio_mask=attention_mask,
            num_patches=self.num_patches,
            target_length=hidden_state.shape[2],
            dtype=self.layernorm_pre.weight.dtype,
        )

        # Local encoder: all (padded) tiles of one image form a single sequence.
        hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim)
        output = self.transformer(
            hidden_state,
            attention_mask=attention_mask,
        )
        hidden_state, intermediate_hidden_states = output[0], output[1]
        # Stack the recorded states on a new trailing axis:
        # (batch * media, seq, dim, n).
        intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1)

        # apply global encoder
        hidden_state = self.layernorm_post(hidden_state)
        # Back to a per-tile layout for the post-tile aspect-ratio embedding.
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media,
            num_tiles,
            num_patches + num_padding_patches,
            dim,
        )
        hidden_state = self.post_tile_positional_embedding(
            hidden_state, aspect_ratio_ids
        )
        # Flatten the tiles again so the global encoder attends across them.
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media,
            num_tiles * (num_patches + num_padding_patches),
            dim,
        )
        hidden_state = self.global_transformer(
            hidden_state, attention_mask=attention_mask
        )[0]
        hidden_state = hidden_state.reshape(
            batch_size * num_concurrent_media,
            num_tiles,
            num_patches + num_padding_patches,
            dim,
        )
        # Drop the padding patches added before the local encoder.
        hidden_state = hidden_state[:, :, :slice_index]

        # adding intermediate layer outputs
        hidden_state = hidden_state.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, dim
        )
        # Same un-padding for the intermediate states; the -1 flattens the
        # trailing (dim, n) axes into dim * n.
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size * num_concurrent_media,
            num_tiles,
            num_patches + num_padding_patches,
            -1,
        )
        intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index]
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, -1
        )
        # Final features followed by the intermediate features on the last axis.
        hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1)
        return hidden_state
|
486
|
+
|
487
|
+
|
488
|
+
class MllamaTextRMSNorm(nn.Module):
    """HuggingFace-style RMS normalisation computed in float32.

    The input is upcast to float32 and divided by its root-mean-square over
    the last dimension (with ``eps`` added inside the root). It is then cast
    back to the input dtype and finally scaled by the learned ``weight``.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        x = hidden_states.to(torch.float32)
        mean_sq = x.pow(2).mean(-1, keepdim=True)
        normed = (x * torch.rsqrt(mean_sq + self.variance_epsilon)).to(orig_dtype)
        return self.weight * normed

    def extra_repr(self):
        shape = tuple(self.weight.shape)
        return f"{shape}, eps={self.variance_epsilon}"
|
503
|
+
|
504
|
+
|
505
|
+
class MllamaTextCrossAttention(nn.Module):
    """Cross-attention from text tokens to ``cross_attention_states``.

    Queries come from the text hidden states. Keys and values come from
    ``cross_attention_states`` (presumably projected vision features) through
    the same fused QKV projection. q and k each get a per-head RMSNorm before
    ``RadixAttention`` (built with ``is_cross_attention=True``). If
    ``cross_attention_states`` is None, k and v are passed as None.
    """

    def __init__(
        self,
        config: Optional[config_mllama.MllamaTextConfig] = None,
        layer_id: Optional[int] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.model_parallel_size = get_tensor_model_parallel_world_size()
        tp = self.model_parallel_size

        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_local_heads = self.num_heads // tp
        self.num_local_key_value_heads = self.num_key_value_heads // tp
        self.dropout = config.dropout
        self.hidden_size = config.hidden_size
        self.head_dim = config.hidden_size // self.num_heads
        self.layer_id = layer_id
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.q_local_size = self.num_local_heads * self.head_dim
        self.kv_local_size = self.num_local_key_value_heads * self.head_dim

        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.num_heads,
            self.num_key_value_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.num_heads * self.head_dim,
            self.hidden_size,
            bias=False,
            input_is_parallel=True,
            quant_config=quant_config,
        )
        # vllm.model_executor.layers.layernorm.RMSNorm has precision issue,
        # use huggingface's instead
        self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.scaling = self.head_dim**-0.5

        self.attn = RadixAttention(
            self.num_local_heads,
            self.head_dim,
            self.scaling,
            self.num_local_key_value_heads,
            layer_id=layer_id,
            is_cross_attention=True,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        cross_attention_states: Optional[torch.Tensor],
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        # NOTE(review): attention_mask is accepted but not used here; masking is
        # presumably handled by the attention backend via forward_batch.
        split_sizes = [self.q_local_size, self.kv_local_size, self.kv_local_size]

        # Queries from the text stream; the k/v slices of this projection are unused.
        qkv_text, _ = self.qkv_proj(hidden_states)
        q = qkv_text.split(split_sizes, dim=-1)[0]
        q = self.q_norm(q.view(-1, self.num_local_heads, self.head_dim))

        k = v = None
        if cross_attention_states is not None:
            # Keys/values from the cross stream; the q slice is unused.
            qkv_cross, _ = self.qkv_proj(cross_attention_states)
            _, k, v = qkv_cross.split(split_sizes, dim=-1)
            k = self.k_norm(k.view(-1, self.num_local_key_value_heads, self.head_dim))
            v = v.view(-1, self.num_local_key_value_heads, self.head_dim)

        out, _ = self.o_proj(self.attn(q, k, v, forward_batch))
        return out
|
587
|
+
|
588
|
+
|
589
|
+
class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
    """Cross-attention decoder block with tanh-gated residual branches.

    Both the cross-attention branch and the MLP branch are multiplied by
    ``full_text_row_masked_out_mask`` and then by ``tanh`` of a learned gate
    before being added back to the residual stream.
    """

    def __init__(
        self,
        config: config_mllama.MllamaTextConfig,
        layer_id: int,
        quant_config: Optional[QuantizationConfig],
    ) -> None:
        super().__init__()
        self.layer_id = layer_id
        self.cross_attn = MllamaTextCrossAttention(
            config=config,
            layer_id=layer_id,
            quant_config=quant_config,
        )

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Zero-initialised gate: tanh(0) == 0 until weights are loaded.
        self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1))

        self.mlp = LlamaMLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )
        self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1))

    def forward(
        self,
        hidden_states: torch.Tensor,
        cross_attention_states: torch.Tensor,
        cross_attention_mask: torch.Tensor,
        full_text_row_masked_out_mask: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        row_mask = full_text_row_masked_out_mask

        # Gated cross-attention branch.
        attn_out = self.cross_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=cross_attention_mask,
            cross_attention_states=cross_attention_states,
            forward_batch=forward_batch,
        )
        hidden_states = hidden_states + self.cross_attn_attn_gate.tanh() * (
            row_mask * attn_out
        )

        # Gated feed-forward branch.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + self.cross_attn_mlp_gate.tanh() * (row_mask * mlp_out)
|
647
|
+
|
648
|
+
|
649
|
+
class MllamaTextModel(nn.Module):
    """Llama text decoder with cross-attention layers interleaved.

    Layer ``i`` is a ``MllamaCrossAttentionDecoderLayer`` when ``i`` is in
    ``config.cross_attention_layers``, and a ``LlamaDecoderLayer`` otherwise.
    """

    config_class = config_mllama.MllamaTextConfig
    base_model_prefix = "model"

    def __init__(
        self,
        config: config_mllama.MllamaTextConfig,
        quant_config: Optional[QuantizationConfig],
        cache_config=None,
    ):
        super().__init__()
        self.padding_id = config.pad_token_id
        self.vocab_size = config.vocab_size
        # NOTE(review): the table has 8 rows beyond vocab_size, presumably for
        # extra special tokens in the checkpoint. Confirm.
        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size + 8, config.hidden_size
        )
        self.cross_attention_layers = config.cross_attention_layers

        layers = []
        for layer_id in range(config.num_hidden_layers):
            if layer_id in self.cross_attention_layers:
                layers.append(
                    MllamaCrossAttentionDecoderLayer(
                        config, layer_id, quant_config=quant_config
                    )
                )
            else:
                # TODO: force LlamaDecoderLayer to config.attention_bias=False
                layers.append(
                    LlamaDecoderLayer(
                        config, quant_config=quant_config, layer_id=layer_id
                    )
                )

        self.layers = nn.ModuleList(layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: Optional[torch.LongTensor],
        cross_attention_states: Optional[torch.Tensor],
        cross_attention_mask: Optional[torch.Tensor],
        full_text_row_masked_out_mask: Optional[torch.Tensor],
        forward_batch: ForwardBatch,
        skip_cross_attention: bool,
    ) -> torch.Tensor:
        """Run the decoder stack and return the final normalised hidden states.

        Args:
            input_ids: token ids to embed.
            positions: token positions, forwarded to the Llama layers.
            cross_attention_states: key/value source for the cross-attention
                layers (may be None).
            cross_attention_mask: forwarded to the cross-attention layers.
            full_text_row_masked_out_mask: multiplied into each
                cross-attention layer's branch outputs.
            forward_batch: batch metadata passed to every layer.
            skip_cross_attention: when True, cross-attention layers are
                bypassed (treated as identity).

        Raises:
            ValueError: if a layer is of an unexpected type.
        """
        hidden_states = self.embed_tokens(input_ids)

        for decoder_layer in self.layers:
            if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer):
                if skip_cross_attention:
                    continue
                hidden_states = decoder_layer(
                    hidden_states=hidden_states,
                    cross_attention_states=cross_attention_states,
                    cross_attention_mask=cross_attention_mask,
                    full_text_row_masked_out_mask=full_text_row_masked_out_mask,
                    forward_batch=forward_batch,
                )
            elif isinstance(decoder_layer, LlamaDecoderLayer):
                # Called with residual=None, the layer returns an
                # (output, residual) pair. Fold the residual back in at once,
                # so the next (possibly cross-attention) layer gets a plain
                # hidden state.
                hidden_states, residual = decoder_layer(
                    positions=positions,
                    hidden_states=hidden_states,
                    forward_batch=forward_batch,
                    residual=None,
                )
                hidden_states = hidden_states + residual
            else:
                raise ValueError(f"Unknown decoder layer type {type(decoder_layer)}")
        return self.norm(hidden_states)
|
721
|
+
|
722
|
+
|
723
|
+
class MllamaForCausalLM(nn.Module):
    """Mllama text stack (``MllamaTextModel``) plus its LM head.

    ``forward`` returns the final normalised hidden states; ``lm_head`` is
    not applied here.
    """

    config_class = config_mllama.MllamaTextConfig
    base_model_prefix = "language_model"
    _no_split_modules = [
        "MllamaCrossAttentionDecoderLayer",
        "MllamaSelfAttentionDecoderLayer",
    ]

    def __init__(
        self,
        config: config_mllama.MllamaTextConfig,
        quant_config: Optional[QuantizationConfig],
        cache_config=None,
    ):
        super().__init__()
        self.vocab_size = config.vocab_size
        # MllamaTextModel's positional order is (config, quant_config,
        # cache_config). Pass by keyword so quant_config actually reaches the
        # decoder layers instead of being swapped with cache_config.
        self.model = MllamaTextModel(
            config, quant_config=quant_config, cache_config=cache_config
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
            quant_config=quant_config,
        )

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: Optional[torch.LongTensor],
        cross_attention_states: Optional[torch.Tensor],
        cross_attention_mask: Optional[torch.Tensor],
        full_text_row_masked_out_mask: Optional[torch.Tensor],
        forward_batch: ForwardBatch,
        skip_cross_attention: bool,
    ) -> torch.Tensor:
        """Delegate to ``self.model`` and return its final hidden states."""
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            forward_batch=forward_batch,
            skip_cross_attention=skip_cross_attention,
        )
        return hidden_states
|
768
|
+
|
769
|
+
|
770
|
+
class MllamaForConditionalGeneration(nn.Module):
    """Mllama vision-language model.

    A vision encoder produces image features, which are projected to the text
    hidden size and fed to the cross-attention layers of the text decoder
    (``MllamaForCausalLM``).
    """

    def __init__(
        self,
        config: config_mllama.MllamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config=None,
    ):
        super().__init__()
        self.vocab_size = config.text_config.vocab_size
        self.hidden_size = config.text_config.hidden_size
        self.max_num_tiles = config.vision_config.max_num_tiles
        self.vision_output_dim = config.vision_config.vision_output_dim
        self.pad_token_id = (
            config.pad_token_id if config.pad_token_id is not None else -1
        )
        self.image_size = config.vision_config.image_size

        self.vision_model = MllamaVisionModel(config.vision_config)
        self.language_model = MllamaForCausalLM(
            config.text_config,
            cache_config=cache_config,
            quant_config=quant_config,
        )
        # Maps vision features (vision_output_dim) into the text hidden size.
        self.multi_modal_projector = nn.Linear(
            config.vision_config.vision_output_dim,
            config.text_config.hidden_size,
            bias=True,
        )
        self.logits_processor = LogitsProcessor(config.text_config)
        # When True, forward() never skips cross attention (see forward()).
        # Presumably toggled by the CUDA-graph runner during capture.
        self.capture_mode = False

    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
        """Prepend one placeholder token per image token to ``input_ids``.

        The number of placeholders is
        ``num_concurrent_media * num_tiles * num_patches``. It is also stored on
        ``image_inputs.num_image_tokens``. Placeholder ids cycle through
        ``image_inputs.pad_values``.
        """
        pixel_values = image_inputs.pixel_values
        pad_values = image_inputs.pad_values

        num_concurrent_media, num_tiles = pixel_values.shape[1:3]
        num_patches = self.vision_model.num_patches
        image_len = num_concurrent_media * num_tiles * num_patches
        image_inputs.num_image_tokens = image_len

        # Repeat pad_values enough times to cover image_len, then truncate.
        pad_ids = pad_values * ((image_len + len(pad_values)) // len(pad_values))

        return pad_ids[:image_len] + input_ids

    def _batch_image_inputs(self, forward_batch: ForwardBatch):
        """Batch the images of requests whose encoder output is not cached.

        Returns:
            ``(batched_images, batched_ar_ids, batched_ar_mask,
            encoder_lens_need)``. Row ``r`` of each batched tensor belongs to the
            request whose encoder length is ``encoder_lens_need[r]``. All four
            values are ``None`` in decode mode, when every request is cached, or
            when there is no image to encode. Missing images/tiles are padded
            with zeros, aspect-ratio ids with 1 and the aspect-ratio mask with 0.
        """
        if forward_batch.forward_mode.is_decode() or all(forward_batch.encoder_cached):
            return None, None, None, None

        # pixel_values: shape (bs, num_image, num_tiles, 3, image_res, image_res)
        max_num_images = max_num_tiles = bs = 0
        for i, im in enumerate(forward_batch.image_inputs):
            if not forward_batch.encoder_cached[i] and im is not None:
                max_num_images = max(max_num_images, im.pixel_values.shape[1])
                max_num_tiles = max(max_num_tiles, im.pixel_values.shape[2])
                bs += 1

        if max_num_images * max_num_tiles * bs == 0:
            return None, None, None, None

        # Allocate every batched tensor on out_cache_loc's device. The
        # aspect-ratio ids previously hard-coded device="cuda", which ignored
        # this context.
        with forward_batch.out_cache_loc.device:
            batched_images = torch.zeros(
                bs,
                max_num_images,
                max_num_tiles,
                3,
                self.image_size,
                self.image_size,
                dtype=torch.float32,
            )
            batched_ar_ids = torch.ones(bs, max_num_images, dtype=torch.int64)
            batched_ar_mask = torch.zeros(
                bs, max_num_images, max_num_tiles, dtype=torch.int64
            )

        row = 0
        encoder_lens_need = []
        for k, im in enumerate(forward_batch.image_inputs):
            if forward_batch.encoder_cached[k] or im is None:
                continue

            encoder_lens_need.append(forward_batch.encoder_lens[k])
            for j in range(im.pixel_values.shape[1]):
                img = im.pixel_values[0, j]
                num_tiles = img.shape[0]
                batched_images[row, j, :num_tiles] = img
                batched_ar_ids[row, j] = im.aspect_ratio_ids[0, j]
                batched_ar_mask[row, j, :num_tiles] = im.aspect_ratio_mask[0, j]
            row += 1

        return batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need

    def flat_encoder_result(
        self, cross_attention_states: torch.Tensor, encoder_lens_need: List[int]
    ):
        """Pack per-request encoder outputs into one ``(sum(lens), dim)`` tensor.

        ``cross_attention_states[i]`` is the encoder output for the request whose
        length is ``encoder_lens_need[i]``. Only its first ``encoder_lens_need[i]``
        rows are copied, and zero-length entries contribute nothing.
        """
        # NOTE: not all encoders need computation, some are cached
        head_dim = cross_attention_states.shape[-1]
        total_encoder_len = sum(encoder_lens_need)
        cross_attention_states_flat = torch.zeros(
            total_encoder_len,
            head_dim,
            device=cross_attention_states.device,
            dtype=cross_attention_states.dtype,
        )

        start_pos = 0
        # enumerate keeps the row index aligned with encoder_lens_need even when
        # a zero-length entry is skipped. Previously the index only advanced for
        # non-zero entries, so later requests read the wrong row.
        for i, encoder_len in enumerate(encoder_lens_need):
            if encoder_len == 0:
                continue
            end_pos = start_pos + encoder_len
            cross_attention_states_flat[start_pos:end_pos] = cross_attention_states[i][
                :encoder_len
            ]
            start_pos = end_pos

        return cross_attention_states_flat

    def get_full_text_row_masked_out_mask(self, forward_batch: ForwardBatch):
        """Return a ``(N, 1)`` bool mask that is False for requests without images.

        Decode mode: one entry per request, equal to ``encoder_lens != 0``.

        Extend mode: one entry per token of the flattened batch. Each request
        contributes ``extend_seq_lens[i]`` consecutive tokens, and all of them
        are set to False when that request's encoder length is 0.
        """
        if forward_batch.forward_mode.is_decode():
            full_text_row_masked_out_mask = forward_batch.encoder_lens != 0
        else:
            extend_seq_lens = forward_batch.extend_seq_lens.tolist()
            full_text_row_masked_out_mask = torch.ones(
                sum(extend_seq_lens), dtype=torch.bool
            )
            start_pos = 0

            for extend_len, encoder_len in zip(
                extend_seq_lens, forward_batch.encoder_lens_cpu
            ):
                if encoder_len == 0:
                    full_text_row_masked_out_mask[start_pos : start_pos + extend_len] = (
                        False
                    )
                # Advance by this request's token count in the flattened batch.
                # Advancing by encoder_len, as before, left later text-only
                # requests unmasked.
                start_pos += extend_len

            full_text_row_masked_out_mask = full_text_row_masked_out_mask.to(
                forward_batch.seq_lens.device
            )

        return full_text_row_masked_out_mask.reshape(-1, 1)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Encode uncached images (if any), run the text decoder, return logits."""
        batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need = (
            self._batch_image_inputs(forward_batch)
        )

        # TODO: support multi-image by this mask
        cross_attention_mask = None
        cross_attention_states = None

        if self.capture_mode:
            # NOTE: when doing cuda graph capture, we do not want to skip cross
            # attention. Make it a constant value to avoid cuda graph capture issues.
            skip_cross_attention = False
        else:
            # Every request in the batch must carry an encoder length.
            assert len(forward_batch.encoder_lens) == len(forward_batch.seq_lens)
            assert len(forward_batch.encoder_lens_cpu) == len(forward_batch.seq_lens)
            skip_cross_attention = forward_batch.encoder_lens.max() == 0

        if not skip_cross_attention:
            full_text_row_masked_out_mask = self.get_full_text_row_masked_out_mask(
                forward_batch
            )
        else:
            full_text_row_masked_out_mask = None

        if batched_images is not None:
            # NOTE: llama's reference implementation runs vision model on CPU
            cross_attention_states = self.vision_model(
                batched_images, batched_ar_ids, batched_ar_mask
            )
            cross_attention_states = self.multi_modal_projector(cross_attention_states)

            # Collapse image/tile/patch axes into one token axis per request.
            bs, _, _, _, image_token_dim = cross_attention_states.shape
            cross_attention_states = cross_attention_states.view(
                bs, -1, image_token_dim
            )

            cross_attention_states = self.flat_encoder_result(
                cross_attention_states, encoder_lens_need
            )

        hidden_states = self.language_model(
            input_ids=input_ids,
            positions=positions,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            forward_batch=forward_batch,
            skip_cross_attention=skip_cross_attention,
        )
        return self.logits_processor(
            input_ids, hidden_states, self.language_model.lm_head.weight, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into this model.

        - q/k/v projections are loaded into the fused ``qkv_proj`` parameter.
        - gate/up projections are loaded into the fused ``gate_up_proj``.
        - ``patch_embedding.weight`` is renamed to ``patch_embedding._linear.weight``
          and flattened to 2-D (assumes the patch embedding is implemented as a
          linear layer; confirm against MllamaVisionModel).
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if "patch_embedding.weight" in name:
                name = name.replace(
                    "patch_embedding.weight", "patch_embedding._linear.weight"
                )
                loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Non-fused parameter: popped from params_dict, so a duplicated
                # name in the checkpoint raises KeyError.
                param = params_dict.pop(name)
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
|
1002
|
+
|
1003
|
+
|
1004
|
+
# Model class exposed by this module under the conventional `EntryClass` name
# (presumably discovered by sglang's model loader/registry — confirm there).
EntryClass = MllamaForConditionalGeneration
|