sglang 0.5.2rc1__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +67 -43
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +88 -53
- sglang/srt/entrypoints/openai/protocol.py +7 -4
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +39 -19
- sglang/srt/entrypoints/openai/serving_completions.py +15 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -4
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -7
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +182 -49
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +68 -41
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +0 -18
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +200 -199
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +351 -397
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +191 -139
- sglang/srt/managers/scheduler_metrics_mixin.py +116 -9
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +260 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +18 -33
- sglang/srt/mem_cache/hiradix_cache.py +108 -48
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +121 -57
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +95 -5
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +81 -20
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +502 -77
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +75 -19
- sglang/srt/model_executor/model_runner.py +357 -30
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +346 -48
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +11 -2
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +60 -13
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +40 -9
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +355 -37
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +197 -112
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +46 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +12 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +263 -200
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch_server.py
CHANGED
@@ -47,6 +47,7 @@ class BenchArgs:
     profile: bool = False
     profile_steps: int = 3
     profile_by_stage: bool = False
+    dataset_path: str = ""
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -83,6 +84,12 @@ class BenchArgs:
             "--profile-steps", type=int, default=BenchArgs.profile_steps
         )
         parser.add_argument("--profile-by-stage", action="store_true")
+        parser.add_argument(
+            "--dataset-path",
+            type=str,
+            default=BenchArgs.dataset_path,
+            help="Path to the dataset.",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -138,6 +145,7 @@ def run_one_case(
     profile: bool = False,
     profile_steps: int = 3,
     profile_by_stage: bool = False,
+    dataset_path: str = "",
 ):
     requests.post(url + "/flush_cache")
     input_requests = sample_random_requests(
@@ -146,7 +154,7 @@ def run_one_case(
         num_prompts=batch_size,
         range_ratio=1.0,
         tokenizer=tokenizer,
-        dataset_path=
+        dataset_path=dataset_path,
         random_sample=True,
         return_text=False,
     )
@@ -345,6 +353,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
         run_name="",
         result_filename="",
         tokenizer=tokenizer,
+        dataset_path=bench_args.dataset_path,
     )
     print("=" * 8 + " Warmup End " + "=" * 8 + "\n")
 
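For context, the new `--dataset-path` option simply threads a filesystem path from the CLI through `BenchArgs` and `run_one_case` into `sample_random_requests`; an empty string keeps the previous behavior. The snippet below is a minimal, self-contained sketch of that argparse-to-dataclass plumbing; `BenchArgsSketch` is a hypothetical stand-in, not sglang's actual `BenchArgs`.

```python
# Minimal sketch of the --dataset-path plumbing; BenchArgsSketch is a
# hypothetical stand-in for sglang's BenchArgs, shown only to illustrate
# how the CLI default flows out of the dataclass field.
import argparse
from dataclasses import dataclass


@dataclass
class BenchArgsSketch:
    dataset_path: str = ""  # empty string keeps the old "no dataset file" behavior

    @staticmethod
    def add_cli_args(parser: argparse.ArgumentParser) -> None:
        parser.add_argument(
            "--dataset-path",
            type=str,
            default=BenchArgsSketch.dataset_path,
            help="Path to the dataset.",
        )

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "BenchArgsSketch":
        return cls(dataset_path=args.dataset_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    BenchArgsSketch.add_cli_args(parser)
    bench_args = BenchArgsSketch.from_cli_args(
        parser.parse_args(["--dataset-path", "/data/sharegpt.json"])
    )
    print(bench_args.dataset_path)  # /data/sharegpt.json
```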
sglang/bench_serving.py
CHANGED
@@ -75,6 +75,7 @@ class RequestFuncInput:
     lora_name: str
     image_data: Optional[List[str]]
     extra_request_body: Dict[str, Any]
+    timestamp: Optional[float] = None
 
 
 @dataclass
@@ -104,10 +105,13 @@ def remove_suffix(text: str, suffix: str) -> str:
 
 
 def get_auth_headers() -> Dict[str, str]:
-
-    if
-    return {"Authorization": f"Bearer {
+    openai_api_key = os.environ.get("OPENAI_API_KEY")
+    if openai_api_key:
+        return {"Authorization": f"Bearer {openai_api_key}"}
     else:
+        api_key = os.environ.get("API_KEY")
+        if api_key:
+            return {"Authorization": f"{api_key}"}
         return {}
 
 
@@ -696,6 +700,24 @@ def get_dataset(args, tokenizer):
             apply_chat_template=args.apply_chat_template,
             random_sample=True,
         )
+    elif args.dataset_name == "mooncake":
+        # For mooncake, we don't generate the prompts here.
+        # We just load the raw trace data. The async generator will handle the rest.
+        if not args.dataset_path:
+            local_path = os.path.join("/tmp", args.mooncake_workload + "_trace.jsonl")
+        else:
+            local_path = args.dataset_path
+
+        if not os.path.exists(local_path):
+            download_and_cache_file(
+                MOONCAKE_DATASET_URL[args.mooncake_workload], local_path
+            )
+
+        with open(local_path, "r") as f:
+            all_requests_data = [json.loads(line) for line in f if line.strip()]
+
+        # Limit the number of requests based on --num-prompts
+        input_requests = all_requests_data[: args.num_prompts]
     else:
         raise ValueError(f"Unknown dataset: {args.dataset_name}")
     return input_requests
@@ -750,6 +772,12 @@ class BenchmarkMetrics:
 
 
 SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
+MOONCAKE_DATASET_URL = {
+    "mooncake": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/arxiv-trace/mooncake_trace.jsonl",
+    "conversation": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/conversation_trace.jsonl",
+    "synthetic": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/synthetic_trace.jsonl",
+    "toolagent": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/toolagent_trace.jsonl",
+}
 
 
 def download_and_cache_file(url: str, filename: Optional[str] = None):
@@ -808,6 +836,80 @@ class DatasetRow:
     prompt_len: int
     output_len: int
     image_data: Optional[List[str]] = None
+    timestamp: Optional[float] = None
+
+
+async def get_mooncake_request_over_time(
+    input_requests: List[Dict],
+    tokenizer: PreTrainedTokenizerBase,
+    slowdown_factor: float,
+    num_rounds: int,
+) -> AsyncGenerator[DatasetRow, None]:
+    """
+    An async generator that yields requests based on the timestamps in the Mooncake trace file,
+    with support for multi-round sessions.
+    """
+    if not input_requests:
+        return
+
+    input_requests.sort(key=lambda r: r["timestamp"])
+
+    start_time = time.perf_counter()
+    trace_start_time_ms = input_requests[0]["timestamp"]
+
+    for record in input_requests:
+        # Calculate when this entire session should start
+        relative_arrival_time_s = (record["timestamp"] - trace_start_time_ms) / 1000.0
+        target_arrival_time_s = relative_arrival_time_s * slowdown_factor
+
+        current_elapsed_time_s = time.perf_counter() - start_time
+        sleep_duration_s = target_arrival_time_s - current_elapsed_time_s
+        if sleep_duration_s > 0:
+            await asyncio.sleep(sleep_duration_s)
+
+        # Once the session starts, generate all rounds for it as a burst
+        # This simulates a user engaging in a multi-turn conversation
+
+        # Base user query constructed from hash_ids
+        user_query_base = ""
+        hash_ids = record.get("hash_ids", [])
+        for hash_id in hash_ids:
+            user_query_base += f"{hash_id}" + " ".join(
+                ["hi"] * 128
+            )  # Shorter for multi-round
+        user_query_base += "Tell me a story based on this context."
+
+        output_len_per_round = record.get("output_length", 256)
+        chat_history = []
+
+        for i in range(num_rounds):
+            # Add user query for the current round
+            chat_history.append(
+                {"role": "user", "content": f"Round {i+1}: {user_query_base}"}
+            )
+
+            # Form the full prompt from history
+            try:
+                full_prompt_text = tokenizer.apply_chat_template(
+                    chat_history, tokenize=False, add_generation_prompt=True
+                )
+            except Exception:
+                full_prompt_text = "\n".join(
+                    [f"{msg['role']}: {msg['content']}" for msg in chat_history]
+                )
+
+            prompt_len = len(tokenizer.encode(full_prompt_text))
+
+            yield DatasetRow(
+                prompt=full_prompt_text,
+                prompt_len=prompt_len,
+                output_len=output_len_per_round,
+            )
+
+            # Add a placeholder assistant response for the next round's context
+            # We use a placeholder because we don't know the real response
+            placeholder_response = " ".join(["story"] * output_len_per_round)
+            chat_history.append({"role": "assistant", "content": placeholder_response})
 
 
 def sample_mmmu_requests(
@@ -896,17 +998,25 @@ def sample_mmmu_requests(
         prompt = f"Question: {question}\n\nAnswer: "
         if apply_chat_template:
             try:
+                is_phi4_multimodal = (
+                    "phi-4-multimodal" in tokenizer.name_or_path.lower()
+                )
+                if is_phi4_multimodal:
+                    # <|endoftext10|> is the image token used in the phi-4-multimodal model.
+                    content = prompt.replace("image 1", "<|endoftext10|>")
+                else:
+                    content = [
+                        {
+                            "type": "image_url",
+                            "image_url": {"url": image_data},
+                        },
+                        {"type": "text", "text": prompt},
+                    ]
                 prompt = tokenizer.apply_chat_template(
                     [
                         {
                             "role": "user",
-                            "content":
-                            {
-                                "type": "image_url",
-                                "image_url": {"url": image_data},
-                            },
-                            {"type": "text", "text": prompt},
-                            ],
+                            "content": content,
                         }
                     ],
                     add_generation_prompt=True,
@@ -1359,19 +1469,41 @@ def sample_generated_shared_prefix_requests(
 async def get_request(
     input_requests: List[DatasetRow],
     request_rate: float,
+    use_trace_timestamps: bool = False,
+    slowdown_factor: float = 1.0,
 ) -> AsyncGenerator[DatasetRow, None]:
-
-
-
+    if use_trace_timestamps:
+        print(
+            f"Using trace timestamps for request generation with slowdown factor {slowdown_factor}."
+        )
+        # Sort requests by timestamp for correct replay
+        input_requests.sort(key=lambda r: r.timestamp)
 
-
-
-
+        start_time = time.perf_counter()
+        trace_start_time_ms = input_requests[0].timestamp if input_requests else 0
+
+        for request in input_requests:
+            trace_time_s = (request.timestamp - trace_start_time_ms) / 1000.0
+            target_arrival_time = start_time + (trace_time_s * slowdown_factor)
+
+            sleep_duration = target_arrival_time - time.perf_counter()
+            if sleep_duration > 0:
+                await asyncio.sleep(sleep_duration)
+
+            yield request
+    else:
+        input_requests_iter = iter(input_requests)
+        for request in input_requests_iter:
+            yield request
 
-
-
-
-
+            if request_rate == float("inf"):
+                # If the request rate is infinity, then we don't need to wait.
+                continue
+
+            # Sample the request interval from the exponential distribution.
+            interval = np.random.exponential(1.0 / request_rate)
+            # The next request will be sent after the interval.
+            await asyncio.sleep(interval)
 
 
 def calculate_metrics(
@@ -1397,7 +1529,7 @@ def calculate_metrics(
             tokenizer.encode(outputs[i].generated_text, add_special_tokens=False)
         )
         retokenized_output_lens.append(retokenized_output_len)
-        total_input +=
+        total_input += outputs[i].prompt_len
         if output_len > 1:
             tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
             itls += outputs[i].itl
@@ -1469,6 +1601,9 @@ async def benchmark(
     pd_separated: bool = False,
     flush_cache: bool = False,
     warmup_requests: int = 1,
+    use_trace_timestamps: bool = False,
+    mooncake_slowdown_factor=1.0,
+    mooncake_num_rounds=1,
 ):
     if backend in ASYNC_REQUEST_FUNCS:
         request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -1488,8 +1623,32 @@
     # Warmup
     print(f"Starting warmup with {warmup_requests} sequences...")
 
-    #
-
+    # Handle the data structure difference for the warmup request
+    if args.dataset_name == "mooncake":
+        # For mooncake, input_requests is a list of dicts.
+        # We need to build a temporary DatasetRow for the warmup phase.
+        warmup_record = input_requests[0]
+
+        # Build prompt from hash_ids, just like in the async generator
+        hash_ids = warmup_record.get("hash_ids", [])
+        prompt_text = ""
+        for hash_id in hash_ids:
+            prompt_text += f"{hash_id}" + " ".join(["hi"] * 512)
+        prompt_text += "Can you tell me a detailed story in 1000 words?"
+
+        output_len = warmup_record.get("output_length", 32)
+        prompt_len = len(tokenizer.encode(prompt_text))
+
+        # Create a temporary DatasetRow object for warmup
+        test_request = DatasetRow(
+            prompt=prompt_text,
+            prompt_len=prompt_len,
+            output_len=output_len,
+            image_data=None,  # Mooncake doesn't have image data
+        )
+    else:
+        # For all other datasets, input_requests is a list of DatasetRow objects
+        test_request = input_requests[0]
 
     if lora_names is not None and len(lora_names) != 0:
         lora_name = lora_names[0]
@@ -1543,12 +1702,26 @@
         if profile_output.success:
             print("Profiler started")
 
-    pbar = None if disable_tqdm else tqdm(total=len(input_requests))
-
     # Run all requests
     benchmark_start_time = time.perf_counter()
     tasks: List[asyncio.Task] = []
-
+    pbar_total = len(input_requests)
+    if (
+        backend == "sglang" and args.dataset_name == "mooncake"
+    ):  # Assuming mooncake is mainly for sglang or similar backends
+        print("Using time-based Mooncake request scheduler, ignoring --request-rate.")
+        request_generator = get_mooncake_request_over_time(
+            input_requests, tokenizer, mooncake_slowdown_factor, mooncake_num_rounds
+        )
+        print(
+            f"Starting Mooncake trace replay. Sessions: {len(input_requests)}, Rounds per session: {mooncake_num_rounds}. Slowdown factor: {mooncake_slowdown_factor}"
+        )
+        pbar_total *= args.mooncake_num_rounds
+    else:
+        request_generator = get_request(input_requests, request_rate)
+
+    pbar = None if disable_tqdm else tqdm(total=pbar_total)
+    async for request in request_generator:
         if lora_names is not None and len(lora_names) != 0:
             idx = random.randint(0, len(lora_names) - 1)
             lora_name = lora_names[idx]
@@ -1564,6 +1737,7 @@
             lora_name=lora_name,
             image_data=request.image_data,
             extra_request_body=extra_request_body,
+            timestamp=request.timestamp,
         )
 
         tasks.append(
@@ -1609,7 +1783,11 @@
 
     print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
     print("{:<40} {:<10}".format("Backend:", backend))
-    print(
+    print(
+        "{:<40} {:<10}".format(
+            "Traffic request rate:", "trace" if use_trace_timestamps else request_rate
+        )
+    )
     print(
         "{:<40} {:<10}".format(
             "Max request concurrency:",
@@ -1678,7 +1856,7 @@
             # Arguments
             "backend": args.backend,
             "dataset_name": args.dataset_name,
-            "request_rate": request_rate,
+            "request_rate": "trace" if use_trace_timestamps else request_rate,
             "max_concurrency": max_concurrency,
             "sharegpt_output_len": args.sharegpt_output_len,
             "random_input_len": args.random_input_len,
@@ -1731,7 +1909,9 @@
         elif args.dataset_name.startswith("random"):
             output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
         else:
-            output_file_name =
+            output_file_name = (
+                f"{args.backend}_{now}_{args.num_prompts}_{args.dataset_name}.jsonl"
+            )
 
     result_details = {
         "input_lens": [output.prompt_len for output in outputs],
@@ -1786,6 +1966,17 @@ def run_benchmark(args_: argparse.Namespace):
     if not hasattr(args, "tokenize_prompt"):
         args.tokenize_prompt = False
 
+    if not hasattr(args, "use_trace_timestamps"):
+        args.use_trace_timestamps = False
+    if not hasattr(args, "mooncake_slowdown_factor"):
+        args.mooncake_slowdown_factor = 1.0
+
+    if not hasattr(args, "mooncake_slowdown_factor"):
+        args.mooncake_slowdown_factor = 1.0
+
+    if not hasattr(args, "mooncake_num_rounds"):
+        args.mooncake_num_rounds = 1
+
     print(f"benchmark_args={args}")
 
     # Set global environments
@@ -1919,6 +2110,9 @@
             pd_separated=args.pd_separated,
             flush_cache=args.flush_cache,
             warmup_requests=args.warmup_requests,
+            use_trace_timestamps=args.use_trace_timestamps,
+            mooncake_slowdown_factor=args.mooncake_slowdown_factor,
+            mooncake_num_rounds=args.mooncake_num_rounds,
         )
     )
 
@@ -1975,6 +2169,7 @@ if __name__ == "__main__":
             "generated-shared-prefix",
             "mmmu",
             "random-image",
+            "mooncake",
        ],
        help="Name of the dataset to benchmark on.",
    )
@@ -2051,6 +2246,11 @@
        help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
        "Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.",
    )
+    parser.add_argument(
+        "--use-trace-timestamps",
+        action="store_true",
+        help="Use timestamps from the trace file for request scheduling. Only valid for 'mooncake' dataset.",
+    )
    parser.add_argument(
        "--max-concurrency",
        type=int,
@@ -2174,5 +2374,33 @@
        default=256,
        help="Target length in tokens for outputs in generated-shared-prefix dataset",
    )
+    mooncake_group = parser.add_argument_group("mooncake dataset arguments")
+    mooncake_group.add_argument(
+        "--mooncake-slowdown-factor",
+        type=float,
+        default=1.0,
+        help="Slowdown factor for replaying the mooncake trace. "
+        "A value of 2.0 means the replay is twice as slow. "
+        "NOTE: --request-rate is IGNORED in mooncake mode.",
+    )
+    mooncake_group.add_argument(
+        "--mooncake-num-rounds",
+        type=int,
+        default=1,
+        help="Number of conversation rounds for each session in the mooncake dataset. "
+        "A value > 1 will enable true multi-turn session benchmarking.",
+    )
+    mooncake_group.add_argument(
+        "--mooncake-workload",
+        type=str,
+        default="conversation",
+        choices=[
+            "mooncake",
+            "conversation",
+            "synthetic",
+            "toolagent",
+        ],
+        help="Underlying workload for the mooncake dataset.",
+    )
     args = parser.parse_args()
     run_benchmark(args)
sglang/lang/interpreter.py
CHANGED
@@ -740,7 +740,7 @@ class StreamExecutor:
         # Execute the stored lazy generation calls
         self.backend.role_end_generate(self)
 
-        from sglang.srt.reasoning_parser import ReasoningParser
+        from sglang.srt.parser.reasoning_parser import ReasoningParser
 
         reasoning_parser = ReasoningParser(expr.model_type)
         other = expr.expr
sglang/srt/configs/__init__.py
CHANGED
@@ -1,11 +1,13 @@
 from sglang.srt.configs.chatglm import ChatGLMConfig
 from sglang.srt.configs.dbrx import DbrxConfig
 from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
+from sglang.srt.configs.dots_vlm import DotsVLMConfig
 from sglang.srt.configs.exaone import ExaoneConfig
 from sglang.srt.configs.janus_pro import MultiModalityConfig
 from sglang.srt.configs.kimi_vl import KimiVLConfig
 from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
 from sglang.srt.configs.longcat_flash import LongcatFlashConfig
+from sglang.srt.configs.qwen3_next import Qwen3NextConfig
 from sglang.srt.configs.step3_vl import (
     Step3TextConfig,
     Step3VisionEncoderConfig,
@@ -24,4 +26,6 @@ __all__ = [
     "Step3VLConfig",
     "Step3TextConfig",
     "Step3VisionEncoderConfig",
+    "Qwen3NextConfig",
+    "DotsVLMConfig",
 ]
sglang/srt/configs/device_config.py
CHANGED
@@ -8,10 +8,12 @@ logger = logging.getLogger(__name__)
 
 class DeviceConfig:
     device: Optional[torch.device]
+    gpu_id: Optional[int]
 
-    def __init__(self, device: str = "cuda") -> None:
+    def __init__(self, device: str = "cuda", gpu_id: int = -1) -> None:
         if device in ["cuda", "xpu", "hpu", "cpu", "npu"]:
             self.device_type = device
         else:
             raise RuntimeError(f"Not supported device type: {device}")
         self.device = torch.device(self.device_type)
+        self.gpu_id = gpu_id
sglang/srt/configs/dots_vlm.py
ADDED
@@ -0,0 +1,139 @@
+from typing import Any, List, Optional, Union
+
+from transformers import AutoProcessor, LlamaTokenizerFast, PretrainedConfig
+from transformers.feature_extraction_utils import BatchFeature
+from transformers.image_utils import ImageInput
+from transformers.processing_utils import ProcessingKwargs, Unpack
+from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
+
+try:
+    from transformers import Qwen2_5_VLProcessor
+except ImportError:
+    raise ImportError(
+        "Qwen2_5_VLProcessor can not be found. Please upgrade your transformers version."
+    )
+
+from sglang.srt.configs.deepseekvl2 import DeepseekV2Config
+
+
+class DotsVisionConfig(PretrainedConfig):
+    model_type: str = "dots_vit"
+
+    def __init__(
+        self,
+        embed_dim: int = 1536,  # vision encoder embed size
+        hidden_size: int = 1536,  # after merger hidden size
+        intermediate_size: int = 4224,
+        num_hidden_layers: int = 42,
+        num_attention_heads: int = 12,
+        num_channels: int = 3,
+        patch_size: int = 14,
+        spatial_merge_size: int = 2,
+        temporal_patch_size: int = 1,
+        rms_norm_eps: float = 1e-5,
+        use_bias: bool = False,
+        attn_implementation="flash_attention_2",  # "eager","sdpa","flash_attention_2"
+        initializer_range=0.02,
+        init_merger_std=0.02,
+        is_causal=False,  # ve causal forward
+        post_norm=True,
+        gradient_checkpointing=False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.embed_dim = embed_dim
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_channels = num_channels
+        self.patch_size = patch_size
+        self.spatial_merge_size = spatial_merge_size
+        self.temporal_patch_size = temporal_patch_size
+        self.rms_norm_eps = rms_norm_eps
+        self.use_bias = use_bias
+        self.attn_implementation = attn_implementation
+        self.initializer_range = initializer_range
+        self.init_merger_std = init_merger_std
+        self.is_causal = is_causal
+        self.post_norm = post_norm
+        self.gradient_checkpointing = gradient_checkpointing
+
+
+class DotsVLMConfig(PretrainedConfig):
+    model_type = "dots_vlm"
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        vision_config = kwargs.get("vision_config", {})
+        self.im_span_id = kwargs.get("image_token_id", 128815)
+        self.video_span_id = kwargs.get("video_token_id", 128836)
+        self.vision_config = DotsVisionConfig(**vision_config)
+        self.language_config = DeepseekV2Config(**kwargs)
+        self.architectures = ["DotsVLMForCausalLM"]
+
+
+class DotsVLMProcessorKwargs(ProcessingKwargs, total=False):
+    _defaults = {
+        "text_kwargs": {
+            "padding": False,
+        },
+    }
+
+
+class DotsVLMProcessor(Qwen2_5_VLProcessor):
+    r"""
+    Constructs a DotsVLM processor which derives from Qwen2_5_VLProcessor, but overrides the image and video token ids.
+    Besides, its tokenizer is a LlamaTokenizerFast instead of Qwen2TokenizerFast.
+    [`DotsVLMProcessor`] offers all the functionalities of [`DotsVisionConfig`] and [`LlamaTokenizerFast`]. See the
+    [`~DotsVLMProcessor.__call__`] and [`~DotsVLMProcessor.decode`] for more information.
+    Args:
+        image_processor ([`Qwen2VLImageProcessor`], *optional*):
+            The image processor is a required input.
+        tokenizer ([`LlamaTokenizerFast`], *optional*):
+            The tokenizer is a required input.
+        chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
+            in a chat into a tokenizable string.
+    """
+
+    attributes = ["image_processor", "tokenizer"]
+
+    valid_kwargs = ["chat_template"]
+
+    tokenizer_class = ("LlamaTokenizer", "LlamaTokenizerFast")
+
+    def __init__(
+        self, image_processor=None, tokenizer=None, chat_template=None, **kwargs
+    ):
+        super().__init__(image_processor, tokenizer, chat_template=chat_template)
+        self.image_token = (
+            "<|imgpad|>"
+            if not hasattr(tokenizer, "image_token")
+            else tokenizer.image_token
+        )
+        self.video_token = (
+            "<|video_pad|>"
+            if not hasattr(tokenizer, "video_token")
+            else tokenizer.video_token
+        )
+        self.img_token = (
+            "<|img|>" if not hasattr(tokenizer, "img_token") else tokenizer.img_token
+        )
+        self.endofimg_token = (
+            "<|endofimg|>"
+            if not hasattr(tokenizer, "endofimg_token")
+            else tokenizer.endofimg_token
+        )
+        self.image_token_id = (
+            tokenizer.image_token_id
+            if getattr(tokenizer, "image_token_id", None)
+            else tokenizer.encode(self.image_token)[0]
+        )
+        self.video_token_id = (
+            tokenizer.video_token_id
+            if getattr(tokenizer, "video_token_id", None)
+            else tokenizer.encode(self.video_token)[0]
+        )
+
+
+AutoProcessor.register(DotsVLMConfig, DotsVLMProcessor)
sglang/srt/configs/internvl.py
CHANGED
@@ -6,11 +6,13 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import sentencepiece as spm
 from transformers import (
     TOKENIZER_MAPPING,
+    GptOssConfig,
     LlamaConfig,
     PretrainedConfig,
     PreTrainedTokenizer,
     Qwen2Config,
     Qwen3Config,
+    Qwen3MoeConfig,
 )
 
 from sglang.utils import logger
@@ -316,7 +318,11 @@ class InternVLChatConfig(PretrainedConfig):
         elif llm_config.get("architectures")[0] == "Qwen2ForCausalLM":
             self.llm_config = Qwen2Config(**llm_config)
         elif llm_config.get("architectures")[0] == "Qwen3MoeForCausalLM":
+            self.llm_config = Qwen3MoeConfig(**llm_config)
+        elif llm_config.get("architectures")[0] == "Qwen3ForCausalLM":
             self.llm_config = Qwen3Config(**llm_config)
+        elif llm_config.get("architectures")[0] == "GptOssForCausalLM":
+            self.llm_config = GptOssConfig(**llm_config)
         else:
             raise ValueError(
                 "Unsupported architecture: {}".format(