sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +149 -34
- sglang/bench_serving.py +73 -14
- sglang/compile_deep_gemm.py +13 -7
- sglang/launch_server.py +2 -0
- sglang/srt/batch_invariant_ops/__init__.py +2 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
- sglang/srt/checkpoint_engine/__init__.py +9 -0
- sglang/srt/checkpoint_engine/update.py +317 -0
- sglang/srt/compilation/backend.py +1 -1
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/deepseek_ocr.py +542 -10
- sglang/srt/configs/deepseekvl2.py +95 -194
- sglang/srt/configs/kimi_linear.py +160 -0
- sglang/srt/configs/mamba_utils.py +66 -0
- sglang/srt/configs/model_config.py +30 -7
- sglang/srt/constants.py +7 -0
- sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
- sglang/srt/disaggregation/decode.py +34 -6
- sglang/srt/disaggregation/nixl/conn.py +2 -2
- sglang/srt/disaggregation/prefill.py +25 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
- sglang/srt/distributed/parallel_state.py +9 -12
- sglang/srt/entrypoints/engine.py +31 -20
- sglang/srt/entrypoints/grpc_server.py +0 -1
- sglang/srt/entrypoints/http_server.py +94 -94
- sglang/srt/entrypoints/openai/protocol.py +7 -1
- sglang/srt/entrypoints/openai/serving_chat.py +42 -0
- sglang/srt/entrypoints/openai/serving_completions.py +10 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/environ.py +23 -2
- sglang/srt/eplb/expert_distribution.py +64 -1
- sglang/srt/eplb/expert_location.py +106 -36
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/minimax_m2.py +367 -0
- sglang/srt/grpc/compile_proto.py +3 -0
- sglang/srt/layers/activation.py +6 -0
- sglang/srt/layers/attention/ascend_backend.py +233 -5
- sglang/srt/layers/attention/attention_registry.py +3 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
- sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
- sglang/srt/layers/attention/fla/kda.py +1359 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
- sglang/srt/layers/attention/flashattention_backend.py +19 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
- sglang/srt/layers/attention/mamba/mamba.py +20 -11
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
- sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
- sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
- sglang/srt/layers/attention/nsa/transform_index.py +1 -1
- sglang/srt/layers/attention/nsa_backend.py +157 -23
- sglang/srt/layers/attention/triton_backend.py +4 -1
- sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
- sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
- sglang/srt/layers/attention/utils.py +78 -0
- sglang/srt/layers/communicator.py +24 -1
- sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/layernorm.py +35 -6
- sglang/srt/layers/logits_processor.py +9 -20
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
- sglang/srt/layers/moe/ep_moe/layer.py +78 -289
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
- sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
- sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
- sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +35 -10
- sglang/srt/layers/moe/utils.py +3 -4
- sglang/srt/layers/pooler.py +21 -2
- sglang/srt/layers/quantization/__init__.py +13 -84
- sglang/srt/layers/quantization/auto_round.py +394 -0
- sglang/srt/layers/quantization/awq.py +0 -3
- sglang/srt/layers/quantization/base_config.py +7 -0
- sglang/srt/layers/quantization/fp8.py +68 -63
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gguf.py +566 -0
- sglang/srt/layers/quantization/modelopt_quant.py +168 -11
- sglang/srt/layers/quantization/mxfp4.py +30 -38
- sglang/srt/layers/quantization/unquant.py +23 -45
- sglang/srt/layers/quantization/w4afp8.py +38 -2
- sglang/srt/layers/radix_attention.py +5 -2
- sglang/srt/layers/rotary_embedding.py +130 -46
- sglang/srt/layers/sampler.py +12 -1
- sglang/srt/lora/lora_registry.py +9 -0
- sglang/srt/managers/async_mm_data_processor.py +122 -0
- sglang/srt/managers/data_parallel_controller.py +30 -3
- sglang/srt/managers/detokenizer_manager.py +3 -0
- sglang/srt/managers/io_struct.py +29 -4
- sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
- sglang/srt/managers/schedule_batch.py +74 -15
- sglang/srt/managers/scheduler.py +185 -144
- sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
- sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
- sglang/srt/managers/scheduler_pp_mixin.py +7 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
- sglang/srt/managers/session_controller.py +6 -5
- sglang/srt/managers/tokenizer_manager.py +165 -78
- sglang/srt/managers/tp_worker.py +24 -1
- sglang/srt/mem_cache/base_prefix_cache.py +23 -4
- sglang/srt/mem_cache/common.py +1 -0
- sglang/srt/mem_cache/hicache_storage.py +7 -1
- sglang/srt/mem_cache/memory_pool.py +253 -57
- sglang/srt/mem_cache/memory_pool_host.py +12 -5
- sglang/srt/mem_cache/radix_cache.py +4 -0
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
- sglang/srt/metrics/collector.py +46 -3
- sglang/srt/model_executor/cuda_graph_runner.py +15 -3
- sglang/srt/model_executor/forward_batch_info.py +55 -14
- sglang/srt/model_executor/model_runner.py +77 -170
- sglang/srt/model_executor/npu_graph_runner.py +7 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/bailing_moe.py +9 -2
- sglang/srt/models/deepseek_nextn.py +11 -2
- sglang/srt/models/deepseek_v2.py +296 -78
- sglang/srt/models/glm4.py +391 -77
- sglang/srt/models/glm4_moe.py +322 -354
- sglang/srt/models/glm4_moe_nextn.py +4 -14
- sglang/srt/models/glm4v.py +196 -55
- sglang/srt/models/glm4v_moe.py +29 -197
- sglang/srt/models/gpt_oss.py +1 -10
- sglang/srt/models/kimi_linear.py +678 -0
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/llama_eagle3.py +11 -1
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minimax_m2.py +922 -0
- sglang/srt/models/nvila.py +355 -0
- sglang/srt/models/nvila_lite.py +184 -0
- sglang/srt/models/qwen2.py +23 -2
- sglang/srt/models/qwen2_moe.py +30 -15
- sglang/srt/models/qwen3.py +35 -5
- sglang/srt/models/qwen3_moe.py +18 -12
- sglang/srt/models/qwen3_next.py +7 -0
- sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
- sglang/srt/multimodal/processors/base_processor.py +1 -0
- sglang/srt/multimodal/processors/glm4v.py +1 -1
- sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
- sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
- sglang/srt/multiplex/multiplexing_mixin.py +209 -0
- sglang/srt/multiplex/pdmux_context.py +164 -0
- sglang/srt/parser/conversation.py +7 -1
- sglang/srt/parser/reasoning_parser.py +28 -1
- sglang/srt/sampling/custom_logit_processor.py +67 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
- sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
- sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
- sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
- sglang/srt/server_args.py +459 -199
- sglang/srt/single_batch_overlap.py +2 -4
- sglang/srt/speculative/draft_utils.py +16 -0
- sglang/srt/speculative/eagle_info.py +42 -36
- sglang/srt/speculative/eagle_info_v2.py +68 -25
- sglang/srt/speculative/eagle_utils.py +261 -16
- sglang/srt/speculative/eagle_worker.py +11 -3
- sglang/srt/speculative/eagle_worker_v2.py +15 -9
- sglang/srt/speculative/spec_info.py +305 -31
- sglang/srt/speculative/spec_utils.py +44 -8
- sglang/srt/tracing/trace.py +121 -12
- sglang/srt/utils/common.py +142 -74
- sglang/srt/utils/hf_transformers_utils.py +38 -12
- sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
- sglang/test/kits/radix_cache_server_kit.py +50 -0
- sglang/test/runners.py +31 -7
- sglang/test/simple_eval_common.py +5 -3
- sglang/test/simple_eval_humaneval.py +1 -0
- sglang/test/simple_eval_math.py +1 -0
- sglang/test/simple_eval_mmlu.py +1 -0
- sglang/test/simple_eval_mmmu_vlm.py +1 -0
- sglang/test/test_deterministic.py +235 -12
- sglang/test/test_deterministic_utils.py +2 -1
- sglang/test/test_utils.py +7 -1
- sglang/version.py +1 -1
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
- sglang/srt/models/vila.py +0 -306
- /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen3.py
CHANGED
@@ -29,6 +29,7 @@ from sglang.srt.model_loader.weight_utils import (
 )
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     add_prefix,
     get_cmo_stream,
@@ -88,8 +89,16 @@ class Qwen3Attention(nn.Module):
         self.max_position_embeddings = max_position_embeddings
         self.tp_rank = get_tensor_model_parallel_rank()

-        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
-        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps)
+        norm_kwargs = (
+            dict(
+                weight_dtype=torch.float32,
+                cast_x_before_out_mul=True,
+            )
+            if get_global_server_args().rl_on_policy_target == "fsdp"
+            else {}
+        )
+        self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs)
+        self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps, **norm_kwargs)

         self.qkv_proj = QKVParallelLinear(
             hidden_size,
@@ -158,10 +167,18 @@ class Qwen3Attention(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            hidden_states = hidden_states.bfloat16()
+
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self._apply_qk_norm(q, k)
         q, k = self.rotary_emb(positions, q, k)
+
+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            q = q.to(torch.bfloat16)
+            k = k.to(torch.bfloat16)
+
         attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
@@ -204,9 +221,22 @@ class Qwen3DecoderLayer(nn.Module):
             quant_config=quant_config,
             prefix=add_prefix("mlp", prefix),
         )
-        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        norm_kwargs = (
+            dict(
+                weight_dtype=torch.float32,
+                cast_x_before_out_mul=True,
+                override_orig_dtype=torch.float32,
+                fp32_residual=True,
+            )
+            if get_global_server_args().rl_on_policy_target == "fsdp"
+            else {}
+        )
+        self.input_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs
+        )
         self.post_attention_layernorm = RMSNorm(
-            config.hidden_size, eps=config.rms_norm_eps
+            config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs
         )

         self.layer_scatter_modes = LayerScatterModes.init_new(
@@ -331,7 +361,7 @@ class Qwen3ForCausalLM(nn.Module):
             self.pp_group.send(
                 self.model.embed_tokens.weight, dst=self.pp_group.last_rank
             )
-        else:
+        elif self.pp_group.is_last_rank:
             emb_token_weight = self.pp_group.recv(
                 size=(config.vocab_size, config.hidden_size),
                 dtype=next(self.model.parameters()).dtype,
sglang/srt/models/qwen3_moe.py
CHANGED
@@ -241,16 +241,14 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         )

     def op_experts(self, state):
-        state.
+        state.combine_input = self.experts.run_moe_core(
             dispatch_output=state.dispatch_output,
         )

     def op_combine_a(self, state):
         if self.ep_size > 1:
             self.experts.dispatcher.combine_a(
-
-                topk_ids=state.dispatch_output.topk_ids,
-                topk_weights=state.dispatch_output.topk_weights,
+                combine_input=state.pop("combine_input"),
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
             state.pop("dispatch_output")
@@ -539,10 +537,16 @@ class Qwen3MoeDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        captured_last_layer_outputs: Optional[List[torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:

-        hidden_states, residual =
-
+        hidden_states, residual = (
+            self.layer_communicator.prepare_attn_and_capture_last_layer_outputs(
+                hidden_states,
+                residual,
+                forward_batch,
+                captured_last_layer_outputs=captured_last_layer_outputs,
+            )
         )

         if hidden_states.shape[0] != 0:
@@ -774,13 +778,15 @@ class Qwen3MoeForCausalLM(nn.Module):
         self.capture_aux_hidden_states = True
         if layer_ids is None:
             num_layers = self.config.num_hidden_layers
-            self.model.
-
-
-
-
+            self.model.set_eagle3_layers_to_capture(
+                [
+                    2,
+                    num_layers // 2,
+                    num_layers - 3,
+                ]
+            )  # Specific layers for EAGLE3 support
         else:
-            self.model.
+            self.model.set_eagle3_layers_to_capture([val + 1 for val in layer_ids])

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
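In the last hunk above, the default EAGLE3 aux-hidden-state capture points are layers 2, num_layers // 2, and num_layers - 3 (shifted by +1 when explicit layer_ids are passed). A standalone restatement of just that arithmetic; the helper name is made up for illustration, only the values come from the diff.

```python
def default_eagle3_capture_layers(num_layers: int) -> list[int]:
    # Early, middle, and late layers are captured for EAGLE3 speculative decoding.
    return [2, num_layers // 2, num_layers - 3]


print(default_eagle3_capture_layers(48))  # [2, 24, 45]
```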
sglang/srt/models/qwen3_next.py
CHANGED
@@ -478,6 +478,13 @@ class Qwen3GatedDeltaNet(nn.Module):
         # reshape input data into 2D tensor
         core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
         z = z.reshape(-1, z.shape[-1])
+
+        # Add padding for DP-Attn
+        if is_dp_attention_enabled():
+            core_attn_out_pad = torch.zeros_like(z)
+            core_attn_out_pad[: core_attn_out.shape[0], :] = core_attn_out
+            core_attn_out = core_attn_out_pad
+
         core_attn_out = self.norm(core_attn_out, z)
         core_attn_out = core_attn_out.reshape(z_shape_og)
         core_attn_out = core_attn_out.reshape(*core_attn_out.shape[:-2], -1)
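The added block zero-pads core_attn_out up to the (DP-attention padded) length of z before the gated norm. A standalone toy of the same padding step, using stand-in tensors rather than the real Qwen3GatedDeltaNet state:

```python
import torch

z = torch.randn(10, 4)             # padded length under data-parallel attention
core_attn_out = torch.randn(7, 4)  # only the locally computed rows

# Pad with zero rows so the leading dimensions match before the gated norm.
core_attn_out_pad = torch.zeros_like(z)
core_attn_out_pad[: core_attn_out.shape[0], :] = core_attn_out
core_attn_out = core_attn_out_pad
print(core_attn_out.shape)  # torch.Size([10, 4])
```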
sglang/srt/multimodal/customized_mm_processor_utils.py
ADDED
@@ -0,0 +1,35 @@
+from typing import Dict, Type
+
+from transformers import PretrainedConfig, ProcessorMixin
+
+# Useful for registering a custom processor different from Hugging Face's default.
+_CUSTOMIZED_MM_PROCESSOR: Dict[str, Type[ProcessorMixin]] = dict()
+
+
+def register_customized_processor(
+    processor_class: Type[ProcessorMixin],
+):
+    """Class decorator that maps a config class's model_type field to a customized processor class.
+
+    Args:
+        processor_class: A processor class that inherits from ProcessorMixin
+
+    Example:
+        ```python
+        @register_customized_processor(MyCustomProcessor)
+        class MyModelConfig(PretrainedConfig):
+            model_type = "my_model"
+
+        ```
+    """
+
+    def decorator(config_class: PretrainedConfig):
+        if not hasattr(config_class, "model_type"):
+            raise ValueError(
+                f"Class {config_class.__name__} with register_customized_processor should "
+                f"have a 'model_type' class attribute."
+            )
+        _CUSTOMIZED_MM_PROCESSOR[config_class.model_type] = processor_class
+        return config_class
+
+    return decorator
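A hypothetical usage sketch of the new decorator, assuming sglang 0.5.4.post2 and transformers are installed; MyProcessor and MyConfig are made-up names, and _CUSTOMIZED_MM_PROCESSOR is the module-level registry defined in the file above.

```python
from transformers import PretrainedConfig, ProcessorMixin

from sglang.srt.multimodal.customized_mm_processor_utils import (
    _CUSTOMIZED_MM_PROCESSOR,
    register_customized_processor,
)


class MyProcessor(ProcessorMixin):
    """Placeholder processor class for illustration."""


@register_customized_processor(MyProcessor)
class MyConfig(PretrainedConfig):
    model_type = "my_model"


# The decorator records the mapping from the config's model_type to the processor.
assert _CUSTOMIZED_MM_PROCESSOR["my_model"] is MyProcessor
```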
sglang/srt/multimodal/processors/base_processor.py
CHANGED
@@ -185,6 +185,7 @@ class BaseMultimodalProcessor(ABC):
             "aspect_ratio_mask": Modality.IMAGE,
             "num_patches": Modality.IMAGE,
             "patch_pixel_values": Modality.IMAGE,
+            "block_sizes": Modality.IMAGE,
             # Audio-related attributes
             "audio_features": Modality.AUDIO,
             "audio_feature_lens": Modality.AUDIO,
sglang/srt/multimodal/processors/glm4v.py
CHANGED
@@ -17,7 +17,7 @@ class Glm4vImageProcessor(SGLangBaseProcessor):
     def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
         super().__init__(hf_config, server_args, _processor, *args, **kwargs)

-        # GLM-
+        # GLM-V specific tokens
         self.IMAGE_TOKEN = "<|image|>"
         self.VIDEO_TOKEN = "<|video|>"
         self.IMAGE_START_TOKEN = "<|begin_of_image|>"
sglang/srt/multimodal/processors/{vila.py → nvila.py}
CHANGED
@@ -1,64 +1,72 @@
-from typing import Any
+from typing import Any

 import torch.nn as nn
 from transformers.configuration_utils import PretrainedConfig
 from transformers.processing_utils import ProcessorMixin
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase

-from sglang.srt.managers.io_struct import
-
-
-    ImageDataInputItem,
-)
-from sglang.srt.models.vila import VILAForConditionalGeneration
+from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.models.nvila import NVILAForConditionalGeneration
+from sglang.srt.models.nvila_lite import NVILALiteForConditionalGeneration
 from sglang.srt.multimodal.processors.base_processor import (
     BaseMultimodalProcessor,
     MultimodalSpecialTokens,
 )
 from sglang.srt.server_args import ServerArgs

+NUM_VIDEO_FRAMES = 8

-class VILAProcessor(ProcessorMixin):
-    """A stub class for the VILA processor."""
-
-    tokenizer: PreTrainedTokenizerBase
-
-
-class VILAMultimodalProcessor(BaseMultimodalProcessor):
-    models: List[Type[nn.Module]] = [VILAForConditionalGeneration]

-
+class NVILAMultimodalProcessor(BaseMultimodalProcessor):
+    models: list[type[nn.Module]] = [
+        NVILAForConditionalGeneration,
+        NVILALiteForConditionalGeneration,
+    ]

     def __init__(
         self,
         hf_config: PretrainedConfig,
         server_args: ServerArgs,
-        _processor:
+        _processor: ProcessorMixin,
         *args,
         **kwargs,
     ) -> None:
         super().__init__(hf_config, server_args, _processor, *args, **kwargs)
+
+        self._processor: ProcessorMixin
+
+        tokenizer: PreTrainedTokenizerBase = getattr(self._processor, "tokenizer")
+
         self.mm_tokens = MultimodalSpecialTokens(
-            image_token=
+            image_token=tokenizer.image_token,
             image_token_id=hf_config.image_token_id,
+            video_token=tokenizer.video_token,
             video_token_id=hf_config.video_token_id,
         ).build(_processor)

     async def process_mm_data_async(
         self,
-        image_data
-
-
+        image_data,
+        audio_data,
+        input_text,
+        request_obj: GenerateReqInput,
         **kwargs,
-    ) ->
+    ) -> dict[str, Any] | None:
         base_output = self.load_mm_data(
             prompt=input_text,
             multimodal_tokens=self.mm_tokens,
-            image_data=image_data,
+            image_data=request_obj.image_data,  # type: ignore
+            video_data=request_obj.video_data,  # type: ignore
         )

+        for i, video in enumerate(base_output.videos):  # type: ignore
+            base_output.videos[i] = [x.asnumpy() for x in video]  # type: ignore
+
         mm_items, input_ids, _ = self.process_and_combine_mm_data(
-            base_output,
+            base_output,
+            self.mm_tokens,
+            do_sample_frames=True,
+            num_frames=NUM_VIDEO_FRAMES,
         )

         return {
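The rewritten processor converts decoded frames with asnumpy() and passes do_sample_frames=True and num_frames=NUM_VIDEO_FRAMES (8) down to process_and_combine_mm_data. The helper below is not the sglang implementation, just a standalone illustration of sampling a fixed frame budget uniformly from a decoded video.

```python
import numpy as np

NUM_VIDEO_FRAMES = 8


def sample_frames_uniformly(frames: list, num_frames: int = NUM_VIDEO_FRAMES) -> list:
    # Pick num_frames indices spread evenly across the decoded video.
    if len(frames) <= num_frames:
        return frames
    idx = np.linspace(0, len(frames) - 1, num_frames).round().astype(int)
    return [frames[i] for i in idx]


video = [np.zeros((4, 4, 3), dtype=np.uint8) for _ in range(30)]
print(len(sample_frames_uniformly(video)))  # 8
```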
sglang/srt/multimodal/processors/points_v15_chat.py
CHANGED
@@ -7,12 +7,12 @@ from PIL import Image

 from sglang.srt.models.points_v15_chat import POINTSV15ChatModel
 from sglang.srt.multimodal.processors.qwen_vl import (
-
+    QwenVLImageProcessor,
     resize_image_async,
 )


-class POINTSV15ChatProcessor(
+class POINTSV15ChatProcessor(QwenVLImageProcessor):
     models = [POINTSV15ChatModel]

     def __init__(self, hf_config, server_args, _processor, *args, **kwargs):
sglang/srt/multiplex/multiplexing_mixin.py
ADDED
@@ -0,0 +1,209 @@
+"""
+Mixin class providing multiplexing scheduling logic
+"""
+
+import logging
+
+import torch
+import torch.distributed as dist
+from torch.cuda.streams import ExternalStream
+
+from sglang.srt.distributed.parallel_state import set_pdmux_status
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.multiplex.pdmux_context import (
+    get_current_stream_idx,
+    get_sm_counts,
+    get_stream_groups,
+    initialize_stream_groups,
+    load_pdmux_config,
+    set_current_stream_idx,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class SchedulerMultiplexMixin:
+
+    def init_pdmux(self):
+        # for pd_multiplexing, Init stream_groups, exclude normal stream for prefill only and decode only
+        self.pdmux_config = load_pdmux_config(self.server_args.pdmux_config_path)
+        initialize_stream_groups(self.gpu_id, self.pdmux_config)
+        self.stream_groups = get_stream_groups()
+        self.sm_counts = get_sm_counts()
+        self.real_sm_group_num = len(self.stream_groups)
+        logger.info(
+            f"PD-Multiplexing enabled with {self.real_sm_group_num} stream groups, sm_counts (prefill_sm, decode_sm): {self.sm_counts}"
+        )
+
+    # TODO(jason-fxz): This is a temporary demo
+    def adjust_stream_groups(self) -> tuple[int, tuple[ExternalStream, ExternalStream]]:
+        if not self.running_batch.is_empty() and self.split_prefill_batch:
+            decode_bs = self.running_batch.batch_size()
+            manual_divisions = self.pdmux_config.manual_divisions
+            if manual_divisions:
+                for i in range(len(manual_divisions)):
+                    _, _, threshold = manual_divisions[i]
+                    if decode_bs >= threshold:
+                        stream_idx = i + 1
+            else:
+                stream_idx = max(
+                    1,
+                    min(
+                        self.real_sm_group_num - 2,
+                        decode_bs
+                        * (self.real_sm_group_num - 2)
+                        // self.pdmux_config.decode_bs_divisor,
+                    ),
+                )
+            set_current_stream_idx(stream_idx)
+        elif not self.running_batch.is_empty():
+            set_current_stream_idx(self.real_sm_group_num - 1)
+        else:
+            set_current_stream_idx(0)
+
+        stream_idx = get_current_stream_idx()
+
+        self.tp_worker.model_runner.update_decode_attn_backend(stream_idx)
+        return stream_idx, self.stream_groups[stream_idx]
+
+    def update_split_prefill_batch(self, sm_count: int) -> bool:
+        if self.split_prefill_batch:
+            return False
+
+        # add new request
+        batch = self.get_new_batch_prefill()
+        if batch and not batch.is_empty():
+            batch.forward_mode = (
+                ForwardMode.SPLIT_PREFILL
+            )  # Set forward mode for split prefill
+            self.split_prefill_batch = batch
+            return True
+        return False
+
+    @torch.inference_mode()
+    def event_loop_pdmux(self):
+        """A scheduler loop for pd multiplexing."""
+        decode_done = False
+        prefill_done = False
+        wait_prefill_kernel_done = False
+        adjust_stream_group = False
+        stream_idx = get_current_stream_idx()
+        stream_group = self.stream_groups[stream_idx]
+        prefill_stream = stream_group[0]
+        decode_stream = stream_group[1]
+        torch.cuda.empty_cache()
+
+        logger.debug("Starting event loop for pd multiplexing...")
+
+        while True:
+            with torch.cuda.stream(decode_stream):
+                set_pdmux_status(False)
+                recv_reqs = self.recv_requests()
+                self.process_input_requests(recv_reqs)
+
+            with torch.cuda.stream(prefill_stream):
+                set_pdmux_status(True)
+                sm_count = self.sm_counts[stream_idx][0]
+                if not wait_prefill_kernel_done:
+                    adjust_stream_group = (
+                        self.update_split_prefill_batch(sm_count) or adjust_stream_group
+                    )
+
+            with torch.cuda.stream(decode_stream):
+                set_pdmux_status(False)
+                self.running_batch = self.update_running_batch(self.running_batch)
+                adjust_stream_group = adjust_stream_group or (
+                    stream_idx > 0 and self.running_batch.is_empty()
+                )
+                if self.running_batch.is_empty() and self.split_prefill_batch is None:
+                    self.check_memory()
+                    self.check_tree_cache()
+                    self.new_token_ratio = self.init_new_token_ratio
+                    self.maybe_sleep_on_idle()
+
+            if adjust_stream_group:
+                prefill_stream.synchronize()
+                decode_stream.synchronize()
+                stream_idx, stream_group = self.adjust_stream_groups()
+                prefill_stream = stream_group[0]
+                decode_stream = stream_group[1]
+                adjust_stream_group = False
+                logger.debug(
+                    f"Adjusting stream groups: {stream_idx}, prefill sm: {self.sm_counts[stream_idx][0]}, decode sm: {self.sm_counts[stream_idx][1]}"
+                )
+
+            with torch.cuda.stream(decode_stream):
+                set_pdmux_status(False)
+                # process decode batch
+                if self.running_batch and not self.running_batch.is_empty():
+                    decode_result = self.run_batch(self.running_batch)
+                    decode_done = True
+                else:
+                    decode_done = False
+            with torch.cuda.stream(prefill_stream):
+                set_pdmux_status(True)
+                if (
+                    self.split_prefill_batch
+                    and not self.split_prefill_batch.is_empty()
+                    and not wait_prefill_kernel_done
+                ):
+                    prefill_done = True
+                    forward_count = (
+                        max(
+                            1,
+                            self.pdmux_config.split_forward_token_budget
+                            // self.split_prefill_batch.extend_num_tokens,
+                        )
+                        if self.split_prefill_batch.extend_num_tokens > 0
+                        else self.model_config.num_hidden_layers
+                    )
+                    next_split_index = min(
+                        self.split_prefill_batch.split_index + forward_count,
+                        self.model_config.num_hidden_layers,
+                    )
+                    forward_count = (
+                        next_split_index - self.split_prefill_batch.split_index
+                    )
+
+                    self.split_prefill_batch.split_forward_count = forward_count
+                    prefill_result = self.run_batch(self.split_prefill_batch)
+                    if next_split_index == self.model_config.num_hidden_layers:
+                        self.split_prefill_batch.split_prefill_finished = True
+                        prefill_exe_done = prefill_stream.record_event()
+                    self.split_prefill_batch.split_index = next_split_index
+
+                elif wait_prefill_kernel_done:
+                    prefill_done = True
+                else:
+                    prefill_done = False
+
+            with torch.cuda.stream(decode_stream):
+                set_pdmux_status(False)
+                decode_stream.synchronize()
+                if decode_done:
+                    self.process_batch_result(self.running_batch, decode_result)
+
+            with torch.cuda.stream(prefill_stream):
+                set_pdmux_status(True)
+                if prefill_done and self.split_prefill_batch.split_prefill_finished:
+                    wait_prefill_kernel_done = True
+                    prefill_exe_done_flag = prefill_exe_done.query()
+                    flags = (
+                        torch.ones(1, device="cpu", dtype=torch.int32)
+                        if prefill_exe_done_flag
+                        else torch.zeros(1, device="cpu", dtype=torch.int32)
+                    )
+
+                    self.tp_cpu_group.allreduce(flags, dist.ReduceOp.SUM).wait()
+                    if flags.item() == self.tp_size:
+                        self.process_batch_result(
+                            self.split_prefill_batch, prefill_result
+                        )
+                        if self.running_batch and not self.running_batch.is_empty():
+                            self.running_batch.merge_batch(self.split_prefill_batch)
+                        else:
+                            self.running_batch = self.split_prefill_batch
+
+                        self.split_prefill_batch = None
+                        wait_prefill_kernel_done = False
+                        adjust_stream_group = True