sglang 0.3.4__py3-none-any.whl → 0.3.4.post2__py3-none-any.whl

This diff shows the contents of two publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.
Files changed (53)
  1. sglang/bench_latency.py +2 -1
  2. sglang/lang/chat_template.py +17 -0
  3. sglang/launch_server_llavavid.py +1 -1
  4. sglang/srt/configs/__init__.py +3 -0
  5. sglang/srt/configs/model_config.py +27 -2
  6. sglang/srt/configs/qwen2vl.py +133 -0
  7. sglang/srt/constrained/fsm_cache.py +10 -3
  8. sglang/srt/conversation.py +27 -0
  9. sglang/srt/hf_transformers_utils.py +16 -1
  10. sglang/srt/layers/attention/__init__.py +16 -5
  11. sglang/srt/layers/attention/double_sparsity_backend.py +22 -6
  12. sglang/srt/layers/attention/flashinfer_backend.py +174 -54
  13. sglang/srt/layers/attention/triton_backend.py +22 -6
  14. sglang/srt/layers/attention/triton_ops/prefill_attention.py +26 -4
  15. sglang/srt/layers/linear.py +89 -63
  16. sglang/srt/layers/logits_processor.py +5 -5
  17. sglang/srt/layers/rotary_embedding.py +112 -0
  18. sglang/srt/layers/sampler.py +51 -39
  19. sglang/srt/lora/lora.py +3 -1
  20. sglang/srt/managers/data_parallel_controller.py +1 -1
  21. sglang/srt/managers/detokenizer_manager.py +4 -0
  22. sglang/srt/managers/image_processor.py +186 -13
  23. sglang/srt/managers/io_struct.py +10 -0
  24. sglang/srt/managers/schedule_batch.py +238 -68
  25. sglang/srt/managers/scheduler.py +69 -50
  26. sglang/srt/managers/tokenizer_manager.py +24 -4
  27. sglang/srt/managers/tp_worker.py +26 -111
  28. sglang/srt/managers/tp_worker_overlap_thread.py +209 -0
  29. sglang/srt/mem_cache/memory_pool.py +56 -10
  30. sglang/srt/mem_cache/radix_cache.py +4 -3
  31. sglang/srt/model_executor/cuda_graph_runner.py +87 -28
  32. sglang/srt/model_executor/forward_batch_info.py +83 -3
  33. sglang/srt/model_executor/model_runner.py +32 -11
  34. sglang/srt/models/chatglm.py +3 -3
  35. sglang/srt/models/deepseek_v2.py +2 -2
  36. sglang/srt/models/mllama.py +1004 -0
  37. sglang/srt/models/qwen2_vl.py +724 -0
  38. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
  39. sglang/srt/sampling/sampling_batch_info.py +13 -3
  40. sglang/srt/sampling/sampling_params.py +5 -7
  41. sglang/srt/server.py +12 -0
  42. sglang/srt/server_args.py +10 -0
  43. sglang/srt/utils.py +22 -0
  44. sglang/test/run_eval.py +2 -0
  45. sglang/test/runners.py +20 -1
  46. sglang/test/srt/sampling/penaltylib/utils.py +1 -0
  47. sglang/test/test_utils.py +100 -3
  48. sglang/version.py +1 -1
  49. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/METADATA +17 -18
  50. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/RECORD +53 -48
  51. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/LICENSE +0 -0
  52. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/WHEEL +0 -0
  53. {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/mllama.py (new file)
@@ -0,0 +1,1004 @@
+# Adapted from:
+# https://github.com/vllm-project/vllm/blob/7193774b1ff8603ad5bf4598e5efba0d9a39b436/vllm/model_executor/models/mllama.py
+"""PyTorch Mllama model."""
+import math
+from typing import Iterable, List, Optional, Tuple, Union
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint
+import transformers.models.mllama.configuration_mllama as config_mllama
+import vllm.distributed.parallel_state as ps
+from torch import nn
+from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast
+from transformers.models.mllama.modeling_mllama import (
+    _prepare_aspect_ratio_attention_mask,
+)
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE,
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+from sglang.srt.layers.activation import get_act_fn
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.schedule_batch import ImageInputs
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.llama import LlamaDecoderLayer, LlamaMLP
+
+
+class ColumnParallelConv2dPatch(torch.nn.Module):
+    """Conv2D Patching layer with model parallelism.
+    Column parallel over unfolded input.
+    Arguments:
+        in_channels: Input channels.
+        out_channels: Output channels.
+        kernel_size: Size of convolution kernel.
+        stride (default 1): Stride for convolution.
+        bias (default False): Use bias in Conv2d.
+    Input: (bsz, in_channels, width, height)
+    Output: (bsz, num_tokens, out_channels)
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: Union[int, Tuple[int, int]],
+        stride: Union[int, Tuple[int, int]],
+        bias: bool = False,
+    ) -> None:
+        super().__init__()
+        if isinstance(kernel_size, int):
+            kernel_size = (kernel_size, kernel_size)
+        self._unfold = torch.nn.Unfold(kernel_size=kernel_size, stride=stride)
+        self._linear = ColumnParallelLinear(
+            in_channels * kernel_size[0] * kernel_size[1],
+            out_channels,
+            bias=bias,
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self._unfold(x)
+        x = x.permute(0, 2, 1)
+        x, _ = self._linear(x)
+        return x
+
+
+class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
+
+    def __init__(self, config: config_mllama.MllamaVisionConfig, is_gated: bool = True):
+        super().__init__()
+        self.max_num_tiles = config.max_num_tiles
+        self.hidden_size = config.hidden_size
+        self.max_aspect_ratio_id = config.max_aspect_ratio_id
+        self.is_gated = is_gated
+
+        self.embedding = nn.Embedding(
+            self.max_aspect_ratio_id + 1, self.max_num_tiles * self.hidden_size
+        )
+        if is_gated:
+            self.gate = nn.Parameter(torch.zeros(1))
+
+    def forward(
+        self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
+    ) -> torch.Tensor:
+        embeddings = self.embedding(aspect_ratio_ids)
+        embeddings = embeddings.reshape(-1, self.max_num_tiles, 1, self.hidden_size)
+
+        if self.is_gated:
+            embeddings = embeddings * self.gate.tanh()
+
+        hidden_state = hidden_state + embeddings
+        return hidden_state
+
+
+class MllamaPrecomputedPositionEmbedding(nn.Module):
+    def __init__(self, config: config_mllama.MllamaVisionConfig):
+        super().__init__()
+        self.max_num_tiles = config.max_num_tiles
+        self.max_aspect_ratio_id = config.max_aspect_ratio_id
+        self.num_patches = (config.image_size // config.patch_size) ** 2 + 1
+        self.hidden_size = config.hidden_size
+        self.scale = config.hidden_size**-0.5
+
+        self.gate = nn.Parameter(torch.zeros(1))
+
+        # position embedding
+        position_embedding = torch.randn(self.num_patches, self.hidden_size)
+        self.embedding = nn.Parameter(self.scale * position_embedding)
+
+        # tile position embedding
+        self.tile_embedding = nn.Embedding(
+            self.max_aspect_ratio_id + 1,
+            self.max_num_tiles * self.num_patches * self.hidden_size,
+        )
+
+    def forward(
+        self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
+    ) -> torch.Tensor:
+        # position embeddings
+        gated_position_embedding = (1 - self.gate.tanh()) * self.embedding
+        hidden_state = hidden_state + gated_position_embedding.view(
+            1, 1, self.num_patches, self.hidden_size
+        )
+
+        # precomputed tile position embeddings
+        tile_position_embedding = self.tile_embedding(aspect_ratio_ids)
+        batch_size = hidden_state.shape[0]
+        tile_position_embedding = tile_position_embedding.reshape(
+            batch_size, self.max_num_tiles, self.num_patches, self.hidden_size
+        )
+        gated_tile_position_embedding = self.gate.tanh() * tile_position_embedding
+        hidden_state = hidden_state + gated_tile_position_embedding
+
+        return hidden_state
+
+
+class MllamaVisionSdpaAttention(nn.Module):
+    def __init__(self, config: config_mllama.MllamaVisionConfig):
+        super().__init__()
+
+        model_parallel_size = get_tensor_model_parallel_world_size()
+        self.embed_dim = config.hidden_size
+        self.num_heads = config.attention_heads
+        self.head_dim = config.hidden_size // config.attention_heads
+        self.num_local_heads = self.num_heads // model_parallel_size
+        self.q_size = self.num_local_heads * self.head_dim
+        self.kv_size = self.num_local_heads * self.head_dim
+
+        self.qkv_proj = QKVParallelLinear(
+            self.embed_dim,
+            self.head_dim,
+            self.num_heads,
+            bias=False,
+        )
+        self.o_proj = RowParallelLinear(
+            self.num_heads * self.head_dim,
+            self.embed_dim,
+            bias=False,
+            input_is_parallel=True,
+        )
+
+    def forward(
+        self,
+        hidden_state: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_state)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q = q.view(
+            q.shape[0], q.shape[1], self.num_local_heads, self.head_dim
+        ).transpose(1, 2)
+        k = k.view(
+            k.shape[0], k.shape[1], self.num_local_heads, self.head_dim
+        ).transpose(1, 2)
+        v = v.view(
+            v.shape[0], v.shape[1], self.num_local_heads, self.head_dim
+        ).transpose(1, 2)
+
+        # TODO: remove padding in image encoder
+        attn_output = F.scaled_dot_product_attention(
+            q, k, v, attn_mask=attention_mask, dropout_p=0.0
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(
+            attn_output.shape[0], attn_output.shape[1], -1
+        )
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class MllamaVisionMLP(nn.Module):
+    def __init__(self, config, quant_config: Optional[QuantizationConfig] = None):
+        super().__init__()
+        self.config = config
+        self.activation_fn = get_act_fn(config.hidden_act)
+        self.fc1 = ColumnParallelLinear(
+            config.hidden_size,
+            config.intermediate_size,
+            bias=True,
+            quant_config=quant_config,
+        )
+        self.fc2 = RowParallelLinear(
+            config.intermediate_size,
+            config.hidden_size,
+            bias=True,
+            quant_config=quant_config,
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states, _ = self.fc1(hidden_states)
+        hidden_states = self.activation_fn(hidden_states)
+        hidden_states, _ = self.fc2(hidden_states)
+
+        return hidden_states
+
+
+class MllamaVisionEncoderLayer(nn.Module):
+    def __init__(
+        self, config: config_mllama.MllamaVisionConfig, is_gated: bool = False
+    ):
+        super().__init__()
+
+        self.hidden_size = config.hidden_size
+        self.num_attention_heads = config.attention_heads
+        self.is_gated = is_gated
+        self.intermediate_size = config.intermediate_size
+
+        self.self_attn = MllamaVisionSdpaAttention(config)
+        self.mlp = MllamaVisionMLP(config)
+
+        self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
+        self.post_attention_layernorm = nn.LayerNorm(
+            self.hidden_size, eps=config.norm_eps
+        )
+
+        # there used to be an if else here, no code path
+        if is_gated:
+            self.gate_attn = nn.Parameter(torch.ones(1) * math.pi / 4)
+            self.gate_ffn = nn.Parameter(torch.ones(1) * math.pi / 4)
+
+    def forward(
+        self,
+        hidden_state: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+    ):
+        # Self Attention
+        residual = hidden_state
+        hidden_state = self.input_layernorm(hidden_state)
+        hidden_state = self.self_attn(hidden_state, attention_mask=attention_mask)
+        gate_attn = 1 if not self.is_gated else self.gate_attn.tanh()
+        hidden_state = residual + gate_attn * hidden_state
+
+        # Feed forward
+        residual = hidden_state
+        hidden_state = self.post_attention_layernorm(hidden_state)
+        hidden_state = self.mlp(hidden_state)
+        gate_ffn = 1 if not self.is_gated else self.gate_ffn.tanh()
+        hidden_state = residual + gate_ffn * hidden_state
+
+        return hidden_state
+
+
+class MllamaVisionEncoder(nn.Module):
+    def __init__(
+        self,
+        config: config_mllama.MllamaVisionConfig,
+        num_layers=32,
+        is_gated=False,
+        output_hidden_states=None,
+    ):
+        super().__init__()
+        self.config = config
+        self.layers = nn.ModuleList(
+            [MllamaVisionEncoderLayer(config, is_gated) for _ in range(num_layers)]
+        )
+        self.output_hidden_states = output_hidden_states or []
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> Union[Tuple, BaseModelOutput]:
+        encoder_states = ()
+
+        for i, encoder_layer in enumerate(self.layers):
+            if i in self.output_hidden_states:
+                encoder_states = encoder_states + (hidden_states,)
+            hidden_states = encoder_layer(
+                hidden_states,
+                attention_mask,
+            )
+
+        if len(self.layers) - 1 in self.output_hidden_states:
+            encoder_states = encoder_states + (hidden_states,)
+
+        return hidden_states, encoder_states
+
+
+class MllamaVisionModel(nn.Module):
+    def __init__(self, config: config_mllama.MllamaVisionConfig):
+        super().__init__()
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+        self.max_num_tiles = config.max_num_tiles
+        self.hidden_size = config.hidden_size
+        self.in_channels = config.num_channels
+        self.intermediate_layers_indices = config.intermediate_layers_indices
+
+        self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
+        self.scale = config.hidden_size**-0.5
+
+        self.patch_embedding = ColumnParallelConv2dPatch(
+            in_channels=config.num_channels,
+            out_channels=self.hidden_size,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            bias=False,
+        )
+
+        self.class_embedding = nn.Parameter(self.scale * torch.randn(self.hidden_size))
+        self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(config)
+
+        self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
+            config, is_gated=True
+        )
+        self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
+            config, is_gated=True
+        )
+
+        # layer norms
+        self.layernorm_pre = nn.LayerNorm(self.hidden_size)
+        self.layernorm_post = nn.LayerNorm(self.hidden_size)
+
+        # encoders
+        self.transformer = MllamaVisionEncoder(
+            config,
+            config.num_hidden_layers,
+            is_gated=False,
+            output_hidden_states=config.intermediate_layers_indices,
+        )
+        self.global_transformer = MllamaVisionEncoder(
+            config, config.num_global_layers, is_gated=True
+        )
+
+    def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
+        batch_size, _, hidden_size = hidden_state.shape
+        class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
+        hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
+        return hidden_state
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+        aspect_ratio_ids: torch.Tensor,
+        aspect_ratio_mask: torch.Tensor,
+    ) -> torch.Tensor:
+        batch_size, num_concurrent_media, num_tiles, num_channels, height, width = (
+            pixel_values.shape
+        )
+
+        pixel_values = pixel_values.reshape(
+            batch_size * num_concurrent_media * num_tiles, num_channels, height, width
+        )
+        aspect_ratio_ids = aspect_ratio_ids.reshape(
+            batch_size * num_concurrent_media, -1
+        )
+
+        # patch embedding
+        patch_embeds = self.patch_embedding(
+            pixel_values.to(self.layernorm_pre.weight.dtype)
+        )
+        hidden_state = patch_embeds
+        hidden_state = ps.get_tp_group().all_gather(hidden_state)
+
+        # tile embeddings
+        _, num_patches, dim = hidden_state.shape
+        hidden_state = hidden_state.reshape(
+            batch_size * num_concurrent_media, num_tiles, -1, dim
+        )
+        hidden_state = self.pre_tile_positional_embedding(
+            hidden_state, aspect_ratio_ids
+        )
+
+        # apply cls token
+        hidden_state = hidden_state.reshape(
+            batch_size * num_concurrent_media * num_tiles, num_patches, dim
+        )
+        hidden_state = self.apply_class_embedding(hidden_state)
+        num_patches += 1
+
+        # apply position embeddings
+        hidden_state = hidden_state.reshape(
+            batch_size * num_concurrent_media, num_tiles, num_patches, dim
+        )
+        hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids)
+
+        # apply encoder
+        hidden_state = self.layernorm_pre(hidden_state)
+
+        # Compute the number of tokens to pad
+        num_padding_patches = (8 - (hidden_state.shape[-2] % 8)) % 8
+        # Compute padding tuple for pad function
+        padding = (
+            0,
+            0,
+            0,
+            num_padding_patches,
+        )  # (pad_left, pad_right, pad_left for dim -2, pad_right for dim -2)
+        # Pad the tensor
+        hidden_state = F.pad(hidden_state, padding, mode="constant", value=0)
+        slice_index = -num_padding_patches if num_padding_patches > 0 else None
+
+        attention_mask = aspect_ratio_mask.reshape(
+            batch_size * num_concurrent_media, -1
+        )
+        attention_mask = _prepare_aspect_ratio_attention_mask(
+            aspect_ratio_mask=attention_mask,
+            num_patches=self.num_patches,
+            target_length=hidden_state.shape[2],
+            dtype=self.layernorm_pre.weight.dtype,
+        )
+
+        hidden_state = hidden_state.view(batch_size * num_concurrent_media, -1, dim)
+        output = self.transformer(
+            hidden_state,
+            attention_mask=attention_mask,
+        )
+        hidden_state, intermediate_hidden_states = output[0], output[1]
+        intermediate_hidden_states = torch.stack(intermediate_hidden_states, dim=-1)
+
+        # apply global encoder
+        hidden_state = self.layernorm_post(hidden_state)
+        hidden_state = hidden_state.reshape(
+            batch_size * num_concurrent_media,
+            num_tiles,
+            num_patches + num_padding_patches,
+            dim,
+        )
+        hidden_state = self.post_tile_positional_embedding(
+            hidden_state, aspect_ratio_ids
+        )
+        hidden_state = hidden_state.reshape(
+            batch_size * num_concurrent_media,
+            num_tiles * (num_patches + num_padding_patches),
+            dim,
+        )
+        hidden_state = self.global_transformer(
+            hidden_state, attention_mask=attention_mask
+        )[0]
+        hidden_state = hidden_state.reshape(
+            batch_size * num_concurrent_media,
+            num_tiles,
+            num_patches + num_padding_patches,
+            dim,
+        )
+        hidden_state = hidden_state[:, :, :slice_index]
+
+        # adding intermediate layer outputs
+        hidden_state = hidden_state.reshape(
+            batch_size, num_concurrent_media, num_tiles, num_patches, dim
+        )
+        intermediate_hidden_states = intermediate_hidden_states.reshape(
+            batch_size * num_concurrent_media,
+            num_tiles,
+            num_patches + num_padding_patches,
+            -1,
+        )
+        intermediate_hidden_states = intermediate_hidden_states[:, :, :slice_index]
+        intermediate_hidden_states = intermediate_hidden_states.reshape(
+            batch_size, num_concurrent_media, num_tiles, num_patches, -1
+        )
+        hidden_state = torch.cat([hidden_state, intermediate_hidden_states], dim=-1)
+        return hidden_state
+
+
+class MllamaTextRMSNorm(nn.Module):
+    def __init__(self, hidden_size, eps=1e-6):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
+class MllamaTextCrossAttention(nn.Module):
+    def __init__(
+        self,
+        config: Optional[config_mllama.MllamaTextConfig] = None,
+        layer_id: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.model_parallel_size = get_tensor_model_parallel_world_size()
+        self.num_heads = self.config.num_attention_heads
+        self.num_local_heads = self.num_heads // self.model_parallel_size
+        self.num_key_value_heads = self.config.num_key_value_heads
+        self.num_local_key_value_heads = (
+            self.num_key_value_heads // self.model_parallel_size
+        )
+        self.dropout = config.dropout
+        self.hidden_size = config.hidden_size
+        self.head_dim = config.hidden_size // self.num_heads
+        self.layer_id = layer_id
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.q_local_size = self.num_local_heads * self.head_dim
+        self.kv_local_size = self.num_local_key_value_heads * self.head_dim
+
+        self.qkv_proj = QKVParallelLinear(
+            self.hidden_size,
+            self.head_dim,
+            self.num_heads,
+            self.num_key_value_heads,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.o_proj = RowParallelLinear(
+            self.num_heads * self.head_dim,
+            self.hidden_size,
+            bias=False,
+            input_is_parallel=True,
+            quant_config=quant_config,
+        )
+        # vllm.model_executor.layers.layernorm.RMSNorm has precision issue,
+        # use huggingface's instead
+        self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+        self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps)
+        self.scaling = self.head_dim**-0.5
+
+        self.attn = RadixAttention(
+            self.num_local_heads,
+            self.head_dim,
+            self.scaling,
+            self.num_local_key_value_heads,
+            layer_id=layer_id,
+            is_cross_attention=True,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor],
+        cross_attention_states: Optional[torch.Tensor],
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv_dec, _ = self.qkv_proj(hidden_states)
+        q, _, _ = qkv_dec.split(
+            [self.q_local_size, self.kv_local_size, self.kv_local_size], dim=-1
+        )
+        if cross_attention_states is None:
+            k = None
+            v = None
+        else:
+            qkv_enc, _ = self.qkv_proj(cross_attention_states)
+            _, k, v = qkv_enc.split(
+                [self.q_local_size, self.kv_local_size, self.kv_local_size], dim=-1
+            )
+            k = k.view(-1, self.num_local_key_value_heads, self.head_dim)
+            v = v.view(-1, self.num_local_key_value_heads, self.head_dim)
+            k = self.k_norm(k)
+        q = q.view(-1, self.num_local_heads, self.head_dim)
+        q = self.q_norm(q)
+
+        output = self.attn(q, k, v, forward_batch)
+        out, _ = self.o_proj(output)
+        return out
+
+
+class MllamaCrossAttentionDecoderLayer(torch.nn.Module):
+    """Cross-attention transformer block with tanh-gated attention
+    and feedforward."""
+
+    def __init__(
+        self,
+        config: config_mllama.MllamaTextConfig,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig],
+    ) -> None:
+        super().__init__()
+        self.layer_id = layer_id
+        self.cross_attn = MllamaTextCrossAttention(
+            config=config,
+            layer_id=layer_id,
+            quant_config=quant_config,
+        )
+
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.cross_attn_attn_gate = torch.nn.Parameter(torch.zeros(1))
+
+        self.mlp = LlamaMLP(
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+        )
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1))
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        cross_attention_states: torch.Tensor,
+        cross_attention_mask: torch.Tensor,
+        full_text_row_masked_out_mask: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        residual = hidden_states
+        hidden_states = self.input_layernorm(hidden_states)
+
+        hidden_states = self.cross_attn(
+            hidden_states=hidden_states,
+            attention_mask=cross_attention_mask,
+            cross_attention_states=cross_attention_states,
+            forward_batch=forward_batch,
+        )
+        hidden_states = full_text_row_masked_out_mask * hidden_states
+        hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states
+
+        residual = hidden_states
+        hidden_states = self.post_attention_layernorm(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = full_text_row_masked_out_mask * hidden_states
+        hidden_states = residual + self.cross_attn_mlp_gate.tanh() * hidden_states
+        return hidden_states
+
+
+class MllamaTextModel(nn.Module):
+    config_class = config_mllama.MllamaTextConfig
+    base_model_prefix = "model"
+
+    def __init__(
+        self,
+        config: config_mllama.MllamaTextConfig,
+        quant_config: Optional[QuantizationConfig],
+        cache_config=None,
+    ):
+        super().__init__()
+        self.padding_id = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size + 8, config.hidden_size
+        )
+        self.cross_attention_layers = config.cross_attention_layers
+
+        layers = []
+        for layer_id in range(config.num_hidden_layers):
+            if layer_id in self.cross_attention_layers:
+                layers.append(
+                    MllamaCrossAttentionDecoderLayer(
+                        config, layer_id, quant_config=quant_config
+                    )
+                )
+            else:
+                # TODO: force LlamaDecoderLayer to config.attention_bias=False
+                layers.append(
+                    LlamaDecoderLayer(
+                        config, quant_config=quant_config, layer_id=layer_id
+                    )
+                )
+
+        self.layers = nn.ModuleList(layers)
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        positions: Optional[torch.LongTensor],
+        cross_attention_states: Optional[torch.LongTensor],
+        cross_attention_mask: Optional[torch.LongTensor],
+        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]],
+        forward_batch: ForwardBatch,
+        skip_cross_attention: bool,
+    ) -> torch.Tensor:
+        inputs_embeds = self.embed_tokens(input_ids)
+        hidden_states = inputs_embeds
+
+        for _, decoder_layer in enumerate(self.layers):
+            if isinstance(decoder_layer, MllamaCrossAttentionDecoderLayer):
+                if not skip_cross_attention:
+                    hidden_states = decoder_layer(
+                        hidden_states=hidden_states,
+                        cross_attention_states=cross_attention_states,
+                        cross_attention_mask=cross_attention_mask,
+                        full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+                        forward_batch=forward_batch,
+                    )
+            elif isinstance(decoder_layer, LlamaDecoderLayer):
+                hidden_states, residual = decoder_layer(
+                    positions=positions,
+                    hidden_states=hidden_states,
+                    forward_batch=forward_batch,
+                    residual=None,
+                )
+                hidden_states = hidden_states + residual
+            else:
+                raise ValueError(f"Unknown decoder layer type {type(decoder_layer)}")
+        hidden_states = self.norm(hidden_states)
+        return hidden_states
+
+
+class MllamaForCausalLM(nn.Module):
+    config_class = config_mllama.MllamaTextConfig
+    base_model_prefix = "language_model"
+    _no_split_modules = [
+        "MllamaCrossAttentionDecoderLayer",
+        "MllamaSelfAttentionDecoderLayer",
+    ]
+
+    def __init__(
+        self,
+        config: config_mllama.MllamaTextConfig,
+        quant_config: Optional[QuantizationConfig],
+        cache_config=None,
+    ):
+        super().__init__()
+        self.vocab_size = config.vocab_size
+        self.model = MllamaTextModel(config, cache_config, quant_config)
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
+            quant_config=quant_config,
+        )
+
+    def forward(
+        self,
+        input_ids: torch.LongTensor,
+        positions: Optional[torch.LongTensor],
+        cross_attention_states: Optional[torch.LongTensor],
+        cross_attention_mask: Optional[torch.LongTensor],
+        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]],
+        forward_batch: ForwardBatch,
+        skip_cross_attention: bool,
+    ) -> torch.Tensor:
+        hidden_states = self.model(
+            input_ids=input_ids,
+            positions=positions,
+            cross_attention_states=cross_attention_states,
+            cross_attention_mask=cross_attention_mask,
+            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+            forward_batch=forward_batch,
+            skip_cross_attention=skip_cross_attention,
+        )
+        return hidden_states
+
+
+class MllamaForConditionalGeneration(nn.Module):
+    def __init__(
+        self,
+        config: config_mllama.MllamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        cache_config=None,
+    ):
+        super().__init__()
+        self.vocab_size = config.text_config.vocab_size
+        self.hidden_size = config.text_config.hidden_size
+        self.max_num_tiles = config.vision_config.max_num_tiles
+        self.vision_output_dim = config.vision_config.vision_output_dim
+        self.pad_token_id = (
+            config.pad_token_id if config.pad_token_id is not None else -1
+        )
+        self.image_size = config.vision_config.image_size
+
+        self.vision_model = MllamaVisionModel(config.vision_config)
+        self.language_model = MllamaForCausalLM(
+            config.text_config,
+            cache_config=cache_config,
+            quant_config=quant_config,
+        )
+        self.multi_modal_projector = nn.Linear(
+            config.vision_config.vision_output_dim,
+            config.text_config.hidden_size,
+            bias=True,
+        )
+        self.logits_processor = LogitsProcessor(config.text_config)
+        self.capture_mode = False
+
+    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+        pixel_values = image_inputs.pixel_values
+        pad_values = image_inputs.pad_values
+
+        num_concurrent_media, num_tiles = pixel_values.shape[1:3]
+        num_patches = self.vision_model.num_patches
+        image_len = num_concurrent_media * num_tiles * num_patches
+        image_inputs.num_image_tokens = image_len
+
+        pad_ids = pad_values * ((image_len + len(pad_values)) // len(pad_values))
+
+        return pad_ids[:image_len] + input_ids
+
+    def _batch_image_inputs(self, forward_batch: ForwardBatch):
+        if forward_batch.forward_mode.is_decode() or all(forward_batch.encoder_cached):
+            return None, None, None, None
+
+        # pixel_values: shape (bs, num_image, num_tiles, 3, image_res, image_res)
+        max_num_images = max_num_tiles = bs = 0
+        for i, im in enumerate(forward_batch.image_inputs):
+            if not forward_batch.encoder_cached[i] and im is not None:
+                max_num_images = max(max_num_images, im.pixel_values.shape[1])
+                max_num_tiles = max(max_num_tiles, im.pixel_values.shape[2])
+                bs += 1
+
+        if max_num_images * max_num_tiles * bs == 0:
+            return None, None, None, None
+
+        with forward_batch.out_cache_loc.device:
+            batched_images = torch.zeros(
+                bs,
+                max_num_images,
+                max_num_tiles,
+                3,
+                self.image_size,
+                self.image_size,
+                dtype=torch.float32,
+            )
+            batched_ar_ids = torch.ones(
+                bs, max_num_images, dtype=torch.int64, device="cuda"
+            )
+            batched_ar_mask = torch.zeros(
+                bs, max_num_images, max_num_tiles, dtype=torch.int64
+            )
+            i = 0
+            encoder_lens_need = []
+            for k, im in enumerate(forward_batch.image_inputs):
+                if forward_batch.encoder_cached[k] or im is None:
+                    continue
+
+                encoder_lens_need.append(forward_batch.encoder_lens[k])
+                for j in range(im.pixel_values.shape[1]):
+                    img = im.pixel_values[0, j]
+                    num_tiles = img.shape[0]
+                    batched_images[i, j, :num_tiles] = img
+                    batched_ar_ids[i, j] = im.aspect_ratio_ids[0, j]
+                    batched_ar_mask[i, j, :num_tiles] = im.aspect_ratio_mask[0, j]
+                i += 1
+
+        return batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need
+
+    def flat_encoder_result(
+        self, cross_attention_states: torch.Tensor, encoder_lens_need: List[int]
+    ):
+        # NOTE: not all encoders need computation, some are cached
+        head_dim = cross_attention_states.shape[-1]
+        total_encoder_len = sum(encoder_lens_need)
+        cross_attention_states_flat = torch.zeros(
+            total_encoder_len,
+            head_dim,
+            device=cross_attention_states.device,
+            dtype=cross_attention_states.dtype,
+        )
+
+        i = start_pos = 0
+        for encoder_len in encoder_lens_need:
+            if encoder_len == 0:
+                continue
+            end_pos = start_pos + encoder_len
+            cross_attention_states_flat[start_pos:end_pos] = cross_attention_states[i][
+                :encoder_len
+            ]
+            i += 1
+            start_pos += encoder_len
+
+        return cross_attention_states_flat
+
+    def get_full_text_row_masked_out_mask(self, forward_batch: ForwardBatch):
+        if forward_batch.forward_mode.is_decode():
+            full_text_row_masked_out_mask = forward_batch.encoder_lens != 0
+        else:
+            full_text_row_masked_out_mask = torch.ones(
+                forward_batch.extend_seq_lens.sum(), dtype=torch.bool
+            )
+            start_pos = 0
+
+            for seq_len, encoder_len in zip(
+                forward_batch.seq_lens.tolist(), forward_batch.encoder_lens_cpu
+            ):
+                if encoder_len == 0:
+                    full_text_row_masked_out_mask[start_pos : start_pos + seq_len] = (
+                        False
+                    )
+                start_pos += encoder_len
+
+        full_text_row_masked_out_mask = full_text_row_masked_out_mask.to(
+            forward_batch.seq_lens.device
+        )
+
+        return full_text_row_masked_out_mask.reshape(-1, 1)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> Union[Tuple, CausalLMOutputWithPast]:
+        batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need = (
+            self._batch_image_inputs(forward_batch)
+        )
+
+        # TODO: support multi-image by this mask
+        cross_attention_mask = None
+        cross_attention_states = None
+
+        if self.capture_mode:
+            # NOTE: when doing cuda graph capture, we do not want to skip cross attention
+            # Make it a constant value to avoid cuda graph capture issue
+            skip_cross_attention = False
+        else:
+            # NOTE: we do not need image_inputs when prefill
+            assert len(forward_batch.encoder_lens) == len(forward_batch.seq_lens)
+            assert len(forward_batch.encoder_lens_cpu) == len(forward_batch.seq_lens)
+            skip_cross_attention = forward_batch.encoder_lens.max() == 0
+
+        if not skip_cross_attention:
+            full_text_row_masked_out_mask = self.get_full_text_row_masked_out_mask(
+                forward_batch
+            )
+        else:
+            full_text_row_masked_out_mask = None
+
+        if batched_images is not None:
+            # NOTE: llama's reference implementation runs vision model on CPU
+            cross_attention_states = self.vision_model(
+                batched_images, batched_ar_ids, batched_ar_mask
+            )
+            cross_attention_states = self.multi_modal_projector(cross_attention_states)
+
+            bs, _, _, _, image_token_dim = cross_attention_states.shape
+            cross_attention_states = cross_attention_states.view(
+                bs, -1, image_token_dim
+            )
+
+            cross_attention_states = self.flat_encoder_result(
+                cross_attention_states, encoder_lens_need
+            )
+
+        hidden_states = self.language_model(
+            input_ids=input_ids,
+            positions=positions,
+            cross_attention_states=cross_attention_states,
+            cross_attention_mask=cross_attention_mask,
+            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
+            forward_batch=forward_batch,
+            skip_cross_attention=skip_cross_attention,
+        )
+        return self.logits_processor(
+            input_ids, hidden_states, self.language_model.lm_head.weight, forward_batch
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        updated_params = set()
+        for name, loaded_weight in weights:
+            if "patch_embedding.weight" in name:
+                name = name.replace(
+                    "patch_embedding.weight", "patch_embedding._linear.weight"
+                )
+                loaded_weight = loaded_weight.view(loaded_weight.shape[0], -1)
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
+                updated_params.add(name)
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict.pop(name)
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+EntryClass = MllamaForConditionalGeneration