sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +6 -0
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +7 -7
- sglang/srt/disaggregation/decode.py +8 -3
- sglang/srt/disaggregation/mooncake/conn.py +43 -25
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/distributed/parallel_state.py +4 -2
- sglang/srt/entrypoints/context.py +3 -20
- sglang/srt/entrypoints/engine.py +13 -8
- sglang/srt/entrypoints/harmony_utils.py +2 -0
- sglang/srt/entrypoints/http_server.py +4 -5
- sglang/srt/entrypoints/openai/protocol.py +0 -9
- sglang/srt/entrypoints/openai/serving_chat.py +59 -265
- sglang/srt/entrypoints/openai/tool_server.py +4 -3
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/jinja_template_utils.py +6 -0
- sglang/srt/layers/attention/aiter_backend.py +370 -107
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +9 -1
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +8 -10
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/cutlass_moe.py +11 -16
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +60 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +4 -1
- sglang/srt/layers/quantization/__init__.py +5 -3
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +22 -10
- sglang/srt/layers/quantization/modelopt_quant.py +6 -11
- sglang/srt/layers/quantization/mxfp4.py +4 -1
- sglang/srt/layers/quantization/w4afp8.py +20 -11
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +281 -2
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +12 -48
- sglang/srt/lora/lora_registry.py +20 -9
- sglang/srt/lora/mem_pool.py +20 -63
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +21 -29
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +6 -6
- sglang/srt/managers/mm_utils.py +1 -2
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +35 -20
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +15 -7
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/tokenizer_manager.py +25 -26
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +34 -24
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +33 -35
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +22 -3
- sglang/srt/model_executor/forward_batch_info.py +26 -5
- sglang/srt/model_executor/model_runner.py +129 -35
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/models/deepseek_v2.py +74 -35
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +8 -9
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +9 -9
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +136 -19
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/reasoning_parser.py +316 -0
- sglang/srt/server_args.py +115 -139
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +12 -4
- sglang/srt/utils.py +3 -3
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +26 -30
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +127 -115
- sglang/lang/backend/__init__.py +0 -0
- sglang/srt/function_call/harmony_tool_parser.py +0 -130
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/models/llama.py
CHANGED
@@ -532,31 +532,6 @@ class LlamaForCausalLM(nn.Module):
     def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens
 
-    def get_hidden_dim(self, module_name):
-        # return input_dim, output_dim
-        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size
-        elif module_name in ["kv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size // (
-                self.config.num_attention_heads // self.config.num_key_value_heads
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
-    def get_module_name(self, name):
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return params_mapping.get(name, name)
-
     def get_module_name_from_weight_name(self, name):
         for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
             if weight_name in name:
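Note: this release removes the per-model get_hidden_dim / get_module_name LoRA helpers from LlamaForCausalLM (and, further below, from Qwen3ForCausalLM and TorchNativeLlamaForCausalLM); the reworked sglang/srt/lora modules in this same diff presumably absorb that shape logic. For reference, a standalone sketch of what the removed helper computed, using hypothetical config values (hidden_size=4096, 32 attention heads, 8 KV heads, intermediate_size=11008) that are assumptions, not values from the diff:

    # Sketch of the removed per-module (input_dim, output_dim) lookup.
    class Cfg:
        hidden_size = 4096
        num_attention_heads = 32
        num_key_value_heads = 8
        intermediate_size = 11008

    def get_hidden_dim(config, module_name):
        if module_name in ("q_proj", "o_proj", "qkv_proj"):
            return config.hidden_size, config.hidden_size
        if module_name == "kv_proj":
            # KV width shrinks by the GQA ratio (32 // 8 = 4)
            return config.hidden_size, config.hidden_size // (
                config.num_attention_heads // config.num_key_value_heads
            )
        if module_name == "gate_up_proj":
            return config.hidden_size, config.intermediate_size
        if module_name == "down_proj":
            return config.intermediate_size, config.hidden_size
        raise NotImplementedError(module_name)

    print(get_hidden_dim(Cfg, "kv_proj"))  # (4096, 1024)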
sglang/srt/models/llama4.py
CHANGED
@@ -204,7 +204,7 @@ class Llama4Attention(nn.Module):
         super().__init__()
         self.layer_id = layer_id
         self.hidden_size = hidden_size
-        self.use_rope =
+        self.use_rope = (layer_id + 1) % 4 != 0
         self.use_qk_norm = config.use_qk_norm and self.use_rope
 
         attn_tp_rank = get_attention_tp_rank()
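Note: the new expression spells out the rotary-embedding rule: every fourth layer (layer_id 3, 7, 11, ...) runs without RoPE, which appears to correspond to Llama 4's interleaved NoPE layers. A quick check of the rule:

    # Which of the first eight layers keep RoPE under the new expression.
    use_rope = [(layer_id + 1) % 4 != 0 for layer_id in range(8)]
    print(use_rope)  # [True, True, True, False, True, True, True, False]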
sglang/srt/models/qwen2_5_vl.py
CHANGED
@@ -114,7 +114,7 @@ class Qwen2_5_VisionBlock(nn.Module):
         num_heads: int,
         hidden_act="silu",
         norm_layer: Type[nn.Module] = None,
-        attn_implementation: Optional[str] =
+        attn_implementation: Optional[str] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
@@ -123,7 +123,12 @@ class Qwen2_5_VisionBlock(nn.Module):
             norm_layer = partial(nn.LayerNorm, eps=1e-6)
         self.norm1 = Qwen2RMSNorm(dim, eps=1e-6)
         self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
-
+
+        if attn_implementation is None:
+            softmax_in_single_precision = False
+            qkv_backend = None
+            flatten_batch = True
+        elif attn_implementation == "sdpa":
             softmax_in_single_precision = False
             qkv_backend = "sdpa"
             flatten_batch = True
@@ -268,7 +273,6 @@ class Qwen2_5_VisionTransformer(nn.Module):
                num_heads=num_heads,
                hidden_act=vision_config.hidden_act,
                norm_layer=norm_layer,
-               attn_implementation="sdpa",
                quant_config=quant_config,
                prefix=add_prefix(f"blocks.{i}", prefix),
            )
sglang/srt/models/qwen2_audio.py
CHANGED
@@ -52,7 +52,11 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternMultimodalTokens,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
@@ -106,15 +110,10 @@ class Qwen2AudioForConditionalGeneration(nn.Module):
         self.language_model = Qwen2ForCausalLM(
             config.text_config, quant_config, prefix=add_prefix("model", prefix)
         )
+        self.pattern = MultiModalityDataPaddingPatternMultimodalTokens()
 
     def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
-
-        audio_token_id: int = getattr(
-            mm_inputs, "audio_token_id", mm_inputs.im_token_id
-        )
-
-        pattern = MultiModalityDataPaddingPatternMultimodalTokens([audio_token_id])
-        return pattern.pad_input_tokens(input_ids, mm_inputs)
+        return self.pattern.pad_input_tokens(input_ids, mm_inputs)
 
     def get_audio_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         # Extract audio features from input items
@@ -143,7 +142,9 @@ class Qwen2AudioForConditionalGeneration(nn.Module):
             input_ids=input_ids,
             forward_batch=forward_batch,
             language_model=self.language_model,
-
+            data_embedding_funcs={
+                Modality.AUDIO: self.get_audio_feature,
+            },
             positions=positions,
         )
 
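Note: the audio embedding hook now travels through a per-modality mapping (data_embedding_funcs={Modality.AUDIO: ...}) rather than a dedicated keyword argument. A minimal sketch of that dispatch pattern under assumed names, not the actual general_mm_embed_routine API:

    from enum import Enum, auto
    from typing import Callable, Dict, List, Tuple

    class Modality(Enum):
        IMAGE = auto()
        AUDIO = auto()
        VIDEO = auto()

    def embed_items(items: List[Tuple[Modality, str]], funcs: Dict[Modality, Callable]):
        # Each item is routed to the encoder registered for its modality.
        return [funcs[modality](data) for modality, data in items if modality in funcs]

    embeddings = embed_items(
        [(Modality.AUDIO, "raw audio features")],
        {Modality.AUDIO: lambda x: f"embedded({x})"},
    )
    print(embeddings)  # ['embedded(raw audio features)']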
sglang/srt/models/qwen3.py
CHANGED
@@ -330,30 +330,6 @@ class Qwen3ForCausalLM(nn.Module):
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
-    def get_hidden_dim(self, module_name: str) -> Tuple[int]:
-        # return input_dim, output_dim
-        if module_name in ["q_proj", "qkv_proj"]:
-            return (
-                self.config.hidden_size,
-                self.config.head_dim * self.config.num_attention_heads,
-            )
-        elif module_name in ["o_proj"]:
-            return (
-                self.config.head_dim * self.config.num_attention_heads,
-                self.config.hidden_size,
-            )
-        elif module_name in ["kv_proj"]:
-            return (
-                self.config.hidden_size,
-                self.config.head_dim * self.config.num_key_value_heads,
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
     @torch.no_grad()
     def forward(
         self,
sglang/srt/models/registry.py
CHANGED
@@ -83,7 +83,7 @@ def import_model_classes():
         try:
             module = importlib.import_module(name)
         except Exception as e:
-            logger.warning(f"Ignore import error when loading {name}
+            logger.warning(f"Ignore import error when loading {name}: {e}")
             continue
         if hasattr(module, "EntryClass"):
             entry = module.EntryClass
sglang/srt/models/torch_native_llama.py
CHANGED
@@ -416,30 +416,6 @@ class TorchNativeLlamaForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )
 
-    def get_hidden_dim(self, module_name):
-        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size
-        elif module_name in ["kv_proj"]:
-            return self.config.hidden_size, self.config.hidden_size // (
-                self.config.num_attention_heads // self.config.num_key_value_heads
-            )
-        elif module_name == "gate_up_proj":
-            return self.config.hidden_size, self.config.intermediate_size
-        elif module_name == "down_proj":
-            return self.config.intermediate_size, self.config.hidden_size
-        else:
-            raise NotImplementedError()
-
-    def get_module_name(self, name):
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return params_mapping.get(name, name)
-
     def get_module_name_from_weight_name(self, name):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id, num_shard)
sglang/srt/multimodal/processors/base_processor.py
CHANGED
@@ -22,13 +22,19 @@ class BaseMultiModalProcessorOutput:
     input_text: str
 
     # frames loaded from image, in given order
-    images: Optional[list[Union[Image.Image, dict]]] =
+    images: Optional[list[Union[Image.Image, dict]]] = dataclasses.field(
+        default_factory=list
+    )
 
     # videos
-    videos: Optional[list[Union[torch.Tensor, dict]]] =
+    videos: Optional[list[Union[torch.Tensor, dict]]] = dataclasses.field(
+        default_factory=list
+    )
 
     # audios
-    audios: Optional[list[Union[np.ndarray, dict]]] =
+    audios: Optional[list[Union[np.ndarray, dict]]] = dataclasses.field(
+        default_factory=list
+    )
 
     def organize_results(self) -> List[Tuple[Modality, Any]]:
         """
@@ -202,7 +208,7 @@ class BaseMultimodalProcessor(ABC):
 
     def process_mm_data(
         self, input_text, images=None, videos=None, audios=None, **kwargs
-    ):
+    ) -> dict:
         """
         process multimodal data with transformers AutoProcessor
         """
@@ -211,10 +217,14 @@ class BaseMultimodalProcessor(ABC):
         if videos:
             kwargs["videos"] = videos
         if audios:
-
-
+            if self.arch in {
+                "Gemma3nForConditionalGeneration",
+                "Qwen2AudioForConditionalGeneration",
+            }:
                 # Note(Xinyuan): for gemma3n, ref: https://github.com/huggingface/transformers/blob/ccf2ca162e33f381e454cdb74bf4b41a51ab976d/src/transformers/models/gemma3n/processing_gemma3n.py#L107
                 kwargs["audio"] = audios
+            else:
+                kwargs["audios"] = audios
 
         processor = self._processor
         if (
@@ -601,12 +611,6 @@
         all_collected_items: list[MultimodalDataItem] = []
         input_ids = None
 
-        # Handle dict items (already processed)
-        for dict_item in dict_items:
-            all_collected_items.extend(
-                self.collect_mm_items_from_processor_output(dict_item)
-            )
-
         # Handle raw items (need processing)
         if raw_images or raw_audios or raw_videos:
             collected_items, input_ids, ret = self._process_and_collect_mm_items(
@@ -616,10 +620,16 @@
                 videos=raw_videos,
                 **kwargs,
             )
-            all_collected_items
+            all_collected_items = collected_items
         else:
             ret = None
 
+        # Handle dict items (already processed)
+        for dict_item in dict_items:
+            all_collected_items.extend(
+                self.collect_mm_items_from_processor_output(dict_item)
+            )
+
         # Fallback tokenization if no raw items were processed
         if input_ids is None:
            input_ids = self._processor.tokenizer(
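Note: the BaseMultiModalProcessorOutput fields above switch to dataclasses.field(default_factory=list). A standalone sketch of why a dataclass field cannot take a bare mutable default and what default_factory buys (illustration only, not the sglang class):

    import dataclasses

    # A bare mutable default is rejected at class-definition time,
    # which is why default_factory is used instead.
    try:
        @dataclasses.dataclass
        class Broken:
            items: list = []  # ValueError: mutable default
    except ValueError as exc:
        print(exc)

    @dataclasses.dataclass
    class Output:
        items: list = dataclasses.field(default_factory=list)  # fresh list per instance

    a, b = Output(), Output()
    a.items.append("x")
    print(b.items)  # [] -- instances do not share state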
sglang/srt/multimodal/processors/glm4v.py
ADDED
@@ -0,0 +1,132 @@
+import re
+from typing import List, Union
+
+from decord import VideoReader
+from transformers.video_utils import VideoMetadata
+
+from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
+from sglang.srt.models.glm4v import Glm4vForConditionalGeneration
+from sglang.srt.models.glm4v_moe import Glm4vMoeForConditionalGeneration
+from sglang.srt.multimodal.processors.base_processor import (
+    BaseMultimodalProcessor as SGLangBaseProcessor,
+)
+from sglang.srt.multimodal.processors.base_processor import (
+    BaseMultiModalProcessorOutput,
+    MultimodalSpecialTokens,
+)
+
+
+class Glm4vImageProcessor(SGLangBaseProcessor):
+    models = [Glm4vForConditionalGeneration, Glm4vMoeForConditionalGeneration]
+
+    def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
+        super().__init__(hf_config, server_args, _processor, *args, **kwargs)
+
+        # GLM-4.1V and GLM-4.5V specific tokens
+        self.IMAGE_TOKEN = "<|image|>"
+        self.VIDEO_TOKEN = "<|video|>"
+        self.IMAGE_START_TOKEN = "<|begin_of_image|>"
+        self.IMAGE_END_TOKEN = "<|end_of_image|>"
+        self.VIDEO_START_TOKEN = "<|begin_of_video|>"
+        self.VIDEO_END_TOKEN = "<|end_of_video|>"
+
+        # Token IDs
+        self.IM_TOKEN_ID = hf_config.image_token_id
+        self.VIDEO_TOKEN_ID = hf_config.video_token_id
+        self.IMAGE_START_TOKEN_ID = hf_config.image_start_token_id
+        self.IMAGE_END_TOKEN_ID = hf_config.image_end_token_id
+        self.VIDEO_START_TOKEN_ID = hf_config.video_start_token_id
+        self.VIDEO_END_TOKEN_ID = hf_config.video_end_token_id
+
+        # Vision config
+        self.IMAGE_FACTOR = 28
+        self.MIN_PIXELS = 112 * 112
+        self.MAX_PIXELS = 30000 * 28 * 28 * 2
+
+        self.mm_tokens = MultimodalSpecialTokens(
+            image_token=self.IMAGE_TOKEN,
+            image_token_id=self.IM_TOKEN_ID,
+            video_token=self.VIDEO_TOKEN,
+            # Note: For GLM4v videos, it uses the video token before tokenization but uses image token after tokenization
+            video_token_id=self.IM_TOKEN_ID,
+        ).build(_processor)
+
+    # adapted from https://github.com/huggingface/transformers/blob/369c99d0cea403b77bd0aef818527106453fd9fc/src/transformers/video_utils.py#L312
+    async def preprocess_video(self, vr: VideoReader):
+        """
+        Preprocess video using VideoReader from Decord backend.
+
+        Args:
+            vr (VideoReader): VideoReader object from decord
+
+        Returns:
+            tuple: A tuple containing processed frames and metadata
+        """
+        video_fps = vr.get_avg_fps()
+        total_num_frames = len(vr)
+        duration = total_num_frames / video_fps if video_fps else 0
+
+        metadata = VideoMetadata(
+            total_num_frames=int(total_num_frames),
+            fps=float(video_fps),
+            duration=float(duration),
+            video_backend="decord",
+        )
+
+        # Extract all frames
+        indices = list(range(total_num_frames))
+        frames = vr.get_batch(indices).asnumpy()
+        metadata.frames_indices = indices
+
+        return frames, metadata
+
+    async def process_mm_data_async(
+        self,
+        image_data: List[Union[str, bytes]],
+        input_text,
+        request_obj,
+        *args,
+        **kwargs,
+    ):
+        base_output = self.load_mm_data(
+            prompt=input_text,
+            image_data=image_data,
+            video_data=request_obj.video_data,
+            multimodal_tokens=self.mm_tokens,
+        )
+
+        video_metadata = None
+
+        if base_output.videos:
+            videos_processed = [
+                await self.preprocess_video(video) for video in base_output.videos
+            ]
+            base_output.videos, video_metadata = map(list, zip(*videos_processed))
+            # transformer requires the video inputs to be under this format
+            base_output.videos = [base_output.videos]
+            video_metadata = [video_metadata]
+
+        mm_items, input_ids, ret = self.process_and_combine_mm_data(
+            base_output, self.mm_tokens, video_metadata=video_metadata
+        )
+
+        input_ids = input_ids.flatten()
+        mrope_positions, mrope_position_delta = MRotaryEmbedding.get_rope_index_glm4v(
+            input_ids=input_ids.unsqueeze(0),
+            hf_config=self.hf_config,
+            image_grid_thw=getattr(ret, "image_grid_thw", None),
+            video_grid_thw=getattr(ret, "video_grid_thw", None),
+            attention_mask=getattr(ret, "attention_mask", None),
+        )
+        mrope_positions = mrope_positions.squeeze(1)
+
+        mm_inputs = {
+            "input_ids": input_ids.tolist(),
+            "mm_items": mm_items,
+            "im_token_id": self.mm_tokens.image_token_id,
+            "video_token_id": self.mm_tokens.video_token_id,
+            "mrope_positions": mrope_positions,
+            "mrope_position_delta": mrope_position_delta,
+        }
+
+        return mm_inputs
sglang/srt/multimodal/processors/qwen_audio.py
CHANGED
@@ -1,6 +1,6 @@
 import re
 
-from sglang.srt.managers.schedule_batch import Modality
+from sglang.srt.managers.schedule_batch import Modality
 from sglang.srt.models.qwen2_audio import Qwen2AudioForConditionalGeneration
 from sglang.srt.multimodal.processors.base_processor import (
     BaseMultimodalProcessor,
@@ -29,6 +29,8 @@ class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
             audio_token_id=self.audio_token_id,
         ).build(_processor)
 
+        self.ATTR_NAME_TO_MODALITY.update({"feature_attention_mask": Modality.AUDIO})
+
     async def process_mm_data_async(
         self,
         audio_data,
@@ -54,7 +56,7 @@ class Qwen2AudioMultimodalProcessor(BaseMultimodalProcessor):
         input_lengths = (input_lengths - 1) // 2 + 1
         output_lengths = (input_lengths - 2) // 2 + 1
 
-        mm_items[0].
+        mm_items[0].audio_feature_lens = output_lengths
 
         return {
             "mm_items": mm_items,
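Note: the two integer-division steps in the last hunk turn raw audio feature lengths into the output frame counts stored in mm_items[0].audio_feature_lens; each step has the shape of a stride-2 downsampling stage (an interpretation, not something the diff states). A worked example of the arithmetic:

    # Length bookkeeping from the hunk above, for two sample input lengths.
    for mel_frames in (100, 3000):
        after_first = (mel_frames - 1) // 2 + 1
        output = (after_first - 2) // 2 + 1
        print(f"{mel_frames} -> {after_first} -> {output}")
    # 100 -> 50 -> 25
    # 3000 -> 1500 -> 750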