sglang 0.4.4.post3__py3-none-any.whl → 0.4.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +49 -7
- sglang/lang/chat_template.py +24 -0
- sglang/srt/_custom_ops.py +59 -92
- sglang/srt/configs/model_config.py +5 -0
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/conversation.py +29 -4
- sglang/srt/custom_op.py +5 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/entrypoints/engine.py +0 -5
- sglang/srt/layers/attention/flashattention_backend.py +678 -83
- sglang/srt/layers/attention/flashinfer_backend.py +5 -7
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
- sglang/srt/layers/moe/ep_moe/layer.py +79 -80
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
- sglang/srt/layers/moe/fused_moe_native.py +5 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +416 -50
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/topk.py +49 -3
- sglang/srt/layers/quantization/__init__.py +5 -1
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
- sglang/srt/layers/quantization/fp8.py +3 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -4
- sglang/srt/layers/quantization/moe_wna16.py +503 -0
- sglang/srt/layers/quantization/utils.py +1 -1
- sglang/srt/layers/quantization/w8a8_int8.py +2 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/rotary_embedding.py +63 -12
- sglang/srt/managers/cache_controller.py +34 -11
- sglang/srt/managers/mm_utils.py +202 -156
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
- sglang/srt/managers/multimodal_processors/clip.py +7 -26
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
- sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
- sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
- sglang/srt/managers/multimodal_processors/llava.py +34 -14
- sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
- sglang/srt/managers/multimodal_processors/mlama.py +10 -23
- sglang/srt/managers/multimodal_processors/mllama4.py +161 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
- sglang/srt/managers/schedule_batch.py +185 -128
- sglang/srt/managers/scheduler.py +4 -4
- sglang/srt/managers/tokenizer_manager.py +1 -1
- sglang/srt/managers/utils.py +1 -6
- sglang/srt/mem_cache/hiradix_cache.py +62 -52
- sglang/srt/mem_cache/memory_pool.py +72 -6
- sglang/srt/mem_cache/paged_allocator.py +39 -0
- sglang/srt/metrics/collector.py +23 -53
- sglang/srt/model_executor/cuda_graph_runner.py +8 -6
- sglang/srt/model_executor/forward_batch_info.py +10 -10
- sglang/srt/model_executor/model_runner.py +60 -57
- sglang/srt/model_loader/loader.py +8 -0
- sglang/srt/models/clip.py +12 -7
- sglang/srt/models/deepseek_janus_pro.py +10 -15
- sglang/srt/models/deepseek_v2.py +212 -121
- sglang/srt/models/deepseek_vl2.py +105 -104
- sglang/srt/models/gemma3_mm.py +14 -80
- sglang/srt/models/llama.py +16 -5
- sglang/srt/models/llama4.py +420 -0
- sglang/srt/models/llava.py +31 -19
- sglang/srt/models/llavavid.py +16 -7
- sglang/srt/models/minicpmo.py +63 -147
- sglang/srt/models/minicpmv.py +17 -27
- sglang/srt/models/mllama.py +29 -14
- sglang/srt/models/mllama4.py +154 -0
- sglang/srt/models/qwen2.py +9 -6
- sglang/srt/models/qwen2_5_vl.py +21 -31
- sglang/srt/models/qwen2_vl.py +20 -21
- sglang/srt/openai_api/adapter.py +18 -6
- sglang/srt/platforms/interface.py +371 -0
- sglang/srt/server_args.py +99 -14
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
- sglang/srt/speculative/eagle_utils.py +140 -28
- sglang/srt/speculative/eagle_worker.py +93 -24
- sglang/srt/utils.py +104 -51
- sglang/test/test_custom_ops.py +55 -0
- sglang/test/test_utils.py +13 -26
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/METADATA +4 -3
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/RECORD +99 -84
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/top_level.txt +0 -0
@@ -11,7 +11,11 @@ from sglang.srt.configs.deepseekvl2 import (
|
|
11
11
|
)
|
12
12
|
from sglang.srt.layers.linear import ReplicatedLinear
|
13
13
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
14
|
-
from sglang.srt.managers.
|
14
|
+
from sglang.srt.managers.mm_utils import (
|
15
|
+
MultiModalityDataPaddingPatternImageTokens,
|
16
|
+
general_mm_embed_routine,
|
17
|
+
)
|
18
|
+
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
|
15
19
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
16
20
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
17
21
|
from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM
|
@@ -150,7 +154,6 @@ class DeepseekVL2MlpProjector(nn.Module):
|
|
150
154
|
return x
|
151
155
|
|
152
156
|
|
153
|
-
# todo
|
154
157
|
class DeepseekVL2ForCausalLM(nn.Module):
|
155
158
|
|
156
159
|
def __init__(
|
@@ -215,32 +218,15 @@ class DeepseekVL2ForCausalLM(nn.Module):
|
|
215
218
|
forward_batch: ForwardBatch,
|
216
219
|
**kwargs: object,
|
217
220
|
):
|
218
|
-
|
219
|
-
if (
|
220
|
-
forward_batch.forward_mode.is_extend()
|
221
|
-
and forward_batch.contains_image_inputs()
|
222
|
-
):
|
223
|
-
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
|
224
|
-
extend_seq_lens_cpu = forward_batch.extend_seq_lens.cpu().numpy()
|
225
|
-
for idx, image in enumerate(forward_batch.mm_inputs):
|
226
|
-
if image is None:
|
227
|
-
continue
|
228
|
-
start_idx = extend_start_loc_cpu[idx]
|
229
|
-
end_idx = start_idx + extend_seq_lens_cpu[idx]
|
230
|
-
images_emb_mask = image.images_emb_mask.to(device="cuda")
|
231
|
-
image_features = self.get_image_feature(image)
|
232
|
-
input_embeds[start_idx:end_idx] = input_embeds[
|
233
|
-
start_idx:end_idx
|
234
|
-
].masked_scatter(images_emb_mask.unsqueeze(-1), image_features)
|
235
|
-
|
236
|
-
outputs = self.language_model.forward(
|
221
|
+
hs = general_mm_embed_routine(
|
237
222
|
input_ids=input_ids,
|
238
223
|
positions=positions,
|
239
224
|
forward_batch=forward_batch,
|
240
|
-
|
225
|
+
image_data_embedding_func=self.get_image_feature,
|
226
|
+
language_model=self.language_model,
|
241
227
|
)
|
242
228
|
|
243
|
-
return
|
229
|
+
return hs
|
244
230
|
|
245
231
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
246
232
|
stacked_params_mapping = [
|
@@ -263,94 +249,109 @@ class DeepseekVL2ForCausalLM(nn.Module):
|
|
263
249
|
weights_loader(param, loaded_weight)
|
264
250
|
|
265
251
|
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
252
|
+
helper = MultiModalityDataPaddingPatternImageTokens(
|
253
|
+
image_token_id=image_inputs.im_token_id
|
254
|
+
)
|
255
|
+
return helper.pad_input_tokens(input_ids, image_inputs)
|
256
|
+
|
257
|
+
def get_image_feature(self, items: List[MultimodalDataItem]):
|
258
|
+
|
259
|
+
images_spatial_crop = torch.cat(
|
260
|
+
[item.image_spatial_crop for item in items], dim=0
|
261
|
+
)
|
262
|
+
|
263
|
+
assert images_spatial_crop.dim() == 3
|
264
|
+
|
265
|
+
# TODO: can it be batched ?
|
277
266
|
images_in_this_batch = []
|
278
|
-
|
279
|
-
|
280
|
-
|
281
|
-
|
282
|
-
|
283
|
-
|
284
|
-
|
285
|
-
# [hw, D]
|
286
|
-
global_features = images_embeds[tile_index]
|
287
|
-
|
288
|
-
# [num_height_tiles * num_width_tiles, hw, D]
|
289
|
-
local_features = images_embeds[
|
290
|
-
tile_index + 1 : tile_index + 1 + num_tiles_in_image
|
291
|
-
]
|
292
|
-
tile_index += num_tiles_in_image + 1
|
293
|
-
|
294
|
-
# format global and local features
|
295
|
-
# ----------------- global view add newline -----------------
|
296
|
-
# [hw, D] -> [h, w, D]
|
297
|
-
global_features = global_features.view(h, w, n_dim)
|
298
|
-
|
299
|
-
# [D] -> [h, 1, D]
|
300
|
-
new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)
|
301
|
-
|
302
|
-
# cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
|
303
|
-
global_features = torch.cat([global_features, new_lines_in_global], dim=1)
|
304
|
-
|
305
|
-
# [h, w + 1, D] -> [h * (w + 1), D]
|
306
|
-
global_features = global_features.view(-1, n_dim)
|
307
|
-
|
308
|
-
# ----------------- local view add newline -----------------
|
309
|
-
# [num_height_tiles * num_width_tiles, h * w, D] ->
|
310
|
-
# [num_height_tiles * h, num_width_tiles * w, D]
|
311
|
-
local_features = rearrange(
|
312
|
-
local_features,
|
313
|
-
"(th tw) (h w) d -> (th h) (tw w) d",
|
314
|
-
th=num_height_tiles,
|
315
|
-
tw=num_width_tiles,
|
316
|
-
h=h,
|
317
|
-
w=w,
|
267
|
+
for item in items:
|
268
|
+
assert item.pixel_values.dim() == 4
|
269
|
+
image_feature = self.vision.forward_features(
|
270
|
+
item.pixel_values.type(next(self.vision.parameters()).dtype).to(
|
271
|
+
device=next(self.vision.parameters()).device
|
272
|
+
)
|
318
273
|
)
|
274
|
+
images_embeds = self.projector(image_feature)
|
275
|
+
_, hw, n_dim = images_embeds.shape
|
276
|
+
h = w = int(hw**0.5)
|
277
|
+
tile_index = 0
|
278
|
+
for jdx in range(item.image_spatial_crop.shape[1]):
|
279
|
+
num_width_tiles, num_height_tiles = item.image_spatial_crop[0, jdx]
|
280
|
+
if num_width_tiles == 0 or num_height_tiles == 0:
|
281
|
+
break
|
282
|
+
num_tiles_in_image = num_width_tiles * num_height_tiles
|
283
|
+
|
284
|
+
# [hw, D]
|
285
|
+
global_features = images_embeds[tile_index]
|
286
|
+
|
287
|
+
# [num_height_tiles * num_width_tiles, hw, D]
|
288
|
+
local_features = images_embeds[
|
289
|
+
tile_index + 1 : tile_index + 1 + num_tiles_in_image
|
290
|
+
]
|
291
|
+
tile_index += num_tiles_in_image + 1
|
319
292
|
|
320
|
-
|
321
|
-
|
322
|
-
|
323
|
-
|
324
|
-
|
325
|
-
h
|
326
|
-
|
293
|
+
# format global and local features
|
294
|
+
# ----------------- global view add newline -----------------
|
295
|
+
# [hw, D] -> [h, w, D]
|
296
|
+
global_features = global_features.view(h, w, n_dim)
|
297
|
+
|
298
|
+
# [D] -> [h, 1, D]
|
299
|
+
new_lines_in_global = repeat(self.image_newline, "d -> h 1 d", h=h)
|
327
300
|
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
# [num_height_tiles * h, num_width_tiles * w + 1, D]
|
332
|
-
# --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
|
333
|
-
local_features = local_features.view(-1, n_dim)
|
334
|
-
|
335
|
-
# merge global and local tiles
|
336
|
-
if self.global_view_pos == "head":
|
337
|
-
global_local_features = torch.cat(
|
338
|
-
[
|
339
|
-
global_features,
|
340
|
-
self.view_seperator[None, :],
|
341
|
-
local_features,
|
342
|
-
]
|
301
|
+
# cat([h, w, D], [h, 1, D], dim=1) -> [h, w + 1, D]
|
302
|
+
global_features = torch.cat(
|
303
|
+
[global_features, new_lines_in_global], dim=1
|
343
304
|
)
|
344
|
-
|
345
|
-
|
346
|
-
|
347
|
-
|
348
|
-
|
349
|
-
|
350
|
-
|
305
|
+
|
306
|
+
# [h, w + 1, D] -> [h * (w + 1), D]
|
307
|
+
global_features = global_features.view(-1, n_dim)
|
308
|
+
|
309
|
+
# ----------------- local view add newline -----------------
|
310
|
+
# [num_height_tiles * num_width_tiles, h * w, D] ->
|
311
|
+
# [num_height_tiles * h, num_width_tiles * w, D]
|
312
|
+
local_features = rearrange(
|
313
|
+
local_features,
|
314
|
+
"(th tw) (h w) d -> (th h) (tw w) d",
|
315
|
+
th=num_height_tiles,
|
316
|
+
tw=num_width_tiles,
|
317
|
+
h=h,
|
318
|
+
w=w,
|
351
319
|
)
|
352
320
|
|
353
|
-
|
321
|
+
# [D] -> [num_height_tiles * h, 1, D]
|
322
|
+
new_lines_in_local = repeat(
|
323
|
+
self.image_newline,
|
324
|
+
"d -> (th h) 1 d",
|
325
|
+
th=num_height_tiles,
|
326
|
+
h=h,
|
327
|
+
)
|
328
|
+
|
329
|
+
# [num_height_tiles * h, num_width_tiles * w + 1, D]
|
330
|
+
local_features = torch.cat([local_features, new_lines_in_local], dim=1)
|
331
|
+
|
332
|
+
# [num_height_tiles * h, num_width_tiles * w + 1, D]
|
333
|
+
# --> [(num_height_tiles * h) * (num_width_tiles * w + 1), D]
|
334
|
+
local_features = local_features.view(-1, n_dim)
|
335
|
+
|
336
|
+
# merge global and local tiles
|
337
|
+
if self.global_view_pos == "head":
|
338
|
+
global_local_features = torch.cat(
|
339
|
+
[
|
340
|
+
global_features,
|
341
|
+
self.view_seperator[None, :],
|
342
|
+
local_features,
|
343
|
+
]
|
344
|
+
)
|
345
|
+
else:
|
346
|
+
global_local_features = torch.cat(
|
347
|
+
[
|
348
|
+
local_features,
|
349
|
+
self.view_seperator[None, :],
|
350
|
+
global_features,
|
351
|
+
]
|
352
|
+
)
|
353
|
+
|
354
|
+
images_in_this_batch.append(global_local_features)
|
354
355
|
|
355
356
|
return torch.cat(images_in_this_batch, dim=0)
|
356
357
|
|
sglang/srt/models/gemma3_mm.py
CHANGED
@@ -21,14 +21,7 @@ from typing import Dict, Iterable, List, Optional, Set, Tuple, TypedDict
|
|
21
21
|
|
22
22
|
import torch
|
23
23
|
from torch import nn
|
24
|
-
from transformers import
|
25
|
-
AutoModel,
|
26
|
-
BatchFeature,
|
27
|
-
Gemma3Config,
|
28
|
-
Gemma3Processor,
|
29
|
-
PreTrainedModel,
|
30
|
-
)
|
31
|
-
from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
|
24
|
+
from transformers import AutoModel, Gemma3Config, PreTrainedModel
|
32
25
|
|
33
26
|
from sglang.srt.hf_transformers_utils import get_processor
|
34
27
|
from sglang.srt.layers.layernorm import Gemma3RMSNorm
|
@@ -38,7 +31,11 @@ from sglang.srt.managers.mm_utils import (
|
|
38
31
|
MultiModalityDataPaddingPatternTokenPairs,
|
39
32
|
general_mm_embed_routine,
|
40
33
|
)
|
41
|
-
from sglang.srt.managers.schedule_batch import
|
34
|
+
from sglang.srt.managers.schedule_batch import (
|
35
|
+
MultimodalDataItem,
|
36
|
+
MultimodalInputs,
|
37
|
+
flatten_nested_list,
|
38
|
+
)
|
42
39
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
43
40
|
from sglang.srt.model_loader.weight_utils import (
|
44
41
|
default_weight_loader,
|
@@ -274,17 +271,16 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
|
|
274
271
|
"""
|
275
272
|
return self.language_model.get_attention_sliding_window_size()
|
276
273
|
|
277
|
-
def get_image_feature(self,
|
274
|
+
def get_image_feature(self, items: List[MultimodalDataItem]):
|
278
275
|
"""
|
279
276
|
Projects the last hidden state from the vision model into language model space.
|
280
277
|
|
281
|
-
Args:
|
282
|
-
pixel_values (`torch.FloatTensor]` of shape `(batch_size, channels, height, width)`)
|
283
|
-
The tensors corresponding to the input images.
|
284
278
|
Returns:
|
285
279
|
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
286
280
|
"""
|
287
|
-
pixel_values =
|
281
|
+
pixel_values = torch.stack(
|
282
|
+
flatten_nested_list([item.pixel_values for item in items]), dim=0
|
283
|
+
)
|
288
284
|
pixel_values = pixel_values.to("cuda")
|
289
285
|
pixel_values = pixel_values.to(dtype=self.language_model.dtype())
|
290
286
|
|
@@ -292,61 +288,6 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
|
|
292
288
|
image_features = self.multi_modal_projector(vision_outputs)
|
293
289
|
return image_features
|
294
290
|
|
295
|
-
def embed_mm_inputs(
|
296
|
-
self,
|
297
|
-
input_ids: torch.Tensor,
|
298
|
-
forward_batch: ForwardBatch,
|
299
|
-
image_input: MultimodalInputs,
|
300
|
-
) -> torch.Tensor:
|
301
|
-
if input_ids is None:
|
302
|
-
raise ValueError("Unimplemented")
|
303
|
-
# boolean-masking image tokens
|
304
|
-
special_image_mask = torch.isin(
|
305
|
-
input_ids,
|
306
|
-
torch.tensor(image_input.pad_values, device=input_ids.device),
|
307
|
-
).unsqueeze(-1)
|
308
|
-
num_image_tokens_in_input_ids = special_image_mask.sum()
|
309
|
-
|
310
|
-
inputs_embeds = None
|
311
|
-
if num_image_tokens_in_input_ids == 0:
|
312
|
-
inputs_embeds = self.get_input_embeddings()(input_ids)
|
313
|
-
return inputs_embeds
|
314
|
-
else:
|
315
|
-
# print(f"image tokens from input_ids: {inputs_embeds[special_image_mask].numel()}")
|
316
|
-
image_features = self.get_image_feature(image_input.pixel_values)
|
317
|
-
|
318
|
-
# print(f"image tokens from image embeddings: {image_features.numel()}")
|
319
|
-
num_image_tokens_in_embedding = (
|
320
|
-
image_features.shape[0] * image_features.shape[1]
|
321
|
-
)
|
322
|
-
|
323
|
-
if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
|
324
|
-
num_image = num_image_tokens_in_input_ids // image_features.shape[1]
|
325
|
-
image_features = image_features[:num_image, :]
|
326
|
-
logger.warning(
|
327
|
-
f"Number of images does not match number of special image tokens in the input text. "
|
328
|
-
f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
|
329
|
-
"tokens from image embeddings."
|
330
|
-
)
|
331
|
-
|
332
|
-
# Important: clamp after extracting original image boundaries
|
333
|
-
input_ids.clamp_(min=0, max=self.vocab_size - 1)
|
334
|
-
|
335
|
-
inputs_embeds = self.get_input_embeddings()(input_ids)
|
336
|
-
|
337
|
-
special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
|
338
|
-
inputs_embeds.device
|
339
|
-
)
|
340
|
-
|
341
|
-
image_features = image_features.to(
|
342
|
-
inputs_embeds.device, inputs_embeds.dtype
|
343
|
-
)
|
344
|
-
inputs_embeds = inputs_embeds.masked_scatter(
|
345
|
-
special_image_mask, image_features
|
346
|
-
)
|
347
|
-
|
348
|
-
return inputs_embeds
|
349
|
-
|
350
291
|
@torch.no_grad()
|
351
292
|
def forward(
|
352
293
|
self,
|
@@ -405,22 +346,15 @@ class Gemma3ForConditionalGeneration(PreTrainedModel):
|
|
405
346
|
else:
|
406
347
|
llm_input_ids = input_ids
|
407
348
|
|
408
|
-
|
349
|
+
hs = general_mm_embed_routine(
|
409
350
|
input_ids=llm_input_ids,
|
410
351
|
forward_batch=forward_batch,
|
411
|
-
|
412
|
-
|
413
|
-
)
|
414
|
-
|
415
|
-
outputs = self.language_model(
|
416
|
-
input_ids=None,
|
352
|
+
language_model=self.language_model,
|
353
|
+
image_data_embedding_func=self.get_image_feature,
|
417
354
|
positions=positions,
|
418
|
-
forward_batch=forward_batch,
|
419
|
-
input_embeds=inputs_embeds,
|
420
|
-
**kwargs,
|
421
355
|
)
|
422
356
|
|
423
|
-
return
|
357
|
+
return hs
|
424
358
|
|
425
359
|
def tie_weights(self):
|
426
360
|
return self.language_model.tie_weights()
|
sglang/srt/models/llama.py
CHANGED
@@ -17,7 +17,7 @@
|
|
17
17
|
"""Inference-only LLaMA model compatible with HuggingFace weights."""
|
18
18
|
|
19
19
|
import logging
|
20
|
-
from typing import Any, Dict, Iterable, List, Optional,
|
20
|
+
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
21
21
|
|
22
22
|
import torch
|
23
23
|
from torch import nn
|
@@ -63,6 +63,7 @@ class LlamaMLP(nn.Module):
|
|
63
63
|
hidden_act: str,
|
64
64
|
quant_config: Optional[QuantizationConfig] = None,
|
65
65
|
prefix: str = "",
|
66
|
+
reduce_results: bool = True,
|
66
67
|
) -> None:
|
67
68
|
super().__init__()
|
68
69
|
self.gate_up_proj = MergedColumnParallelLinear(
|
@@ -78,6 +79,7 @@ class LlamaMLP(nn.Module):
|
|
78
79
|
bias=False,
|
79
80
|
quant_config=quant_config,
|
80
81
|
prefix=add_prefix("down_proj", prefix),
|
82
|
+
reduce_results=reduce_results,
|
81
83
|
)
|
82
84
|
if hidden_act != "silu":
|
83
85
|
raise ValueError(
|
@@ -281,7 +283,7 @@ class LlamaModel(nn.Module):
|
|
281
283
|
self.layers = make_layers(
|
282
284
|
config.num_hidden_layers,
|
283
285
|
lambda idx, prefix: LlamaDecoderLayer(
|
284
|
-
config=config,
|
286
|
+
config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
|
285
287
|
),
|
286
288
|
prefix="model.layers",
|
287
289
|
)
|
@@ -375,9 +377,7 @@ class LlamaForCausalLM(nn.Module):
|
|
375
377
|
super().__init__()
|
376
378
|
self.config = config
|
377
379
|
self.quant_config = quant_config
|
378
|
-
self.model =
|
379
|
-
config, quant_config=quant_config, prefix=add_prefix("model", prefix)
|
380
|
-
)
|
380
|
+
self.model = self._init_model(config, quant_config, add_prefix("model", prefix))
|
381
381
|
# Llama 3.2 1B Instruct set tie_word_embeddings to True
|
382
382
|
# Llama 3.1 8B Instruct set tie_word_embeddings to False
|
383
383
|
if self.config.tie_word_embeddings:
|
@@ -402,6 +402,14 @@ class LlamaForCausalLM(nn.Module):
|
|
402
402
|
|
403
403
|
self.capture_aux_hidden_states = False
|
404
404
|
|
405
|
+
def _init_model(
|
406
|
+
self,
|
407
|
+
config: LlamaConfig,
|
408
|
+
quant_config: Optional[QuantizationConfig] = None,
|
409
|
+
prefix: str = "",
|
410
|
+
):
|
411
|
+
return LlamaModel(config, quant_config=quant_config, prefix=prefix)
|
412
|
+
|
405
413
|
@torch.no_grad()
|
406
414
|
def forward(
|
407
415
|
self,
|
@@ -428,6 +436,9 @@ class LlamaForCausalLM(nn.Module):
|
|
428
436
|
else:
|
429
437
|
return self.pooler(hidden_states, forward_batch)
|
430
438
|
|
439
|
+
def get_input_embeddings(self) -> nn.Embedding:
|
440
|
+
return self.model.embed_tokens
|
441
|
+
|
431
442
|
def get_hidden_dim(self, module_name):
|
432
443
|
# return input_dim, output_dim
|
433
444
|
if module_name in ["q_proj", "o_proj", "qkv_proj"]:
|