sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +251 -26
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +63 -3
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +34 -19
- sglang/srt/entrypoints/openai/serving_completions.py +10 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +12 -0
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +250 -112
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +110 -49
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +43 -29
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -45
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +242 -278
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +13 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +160 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +90 -115
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +41 -477
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +24 -22
- sglang/srt/mem_cache/hiradix_cache.py +184 -101
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +324 -41
- sglang/srt/mem_cache/memory_pool_host.py +25 -18
- sglang/srt/mem_cache/radix_cache.py +5 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +189 -31
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +311 -50
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +90 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +297 -79
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/utils.py +37 -2
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py
CHANGED
@@ -61,6 +61,7 @@ from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed.parallel_state import destroy_distributed_environment
 from sglang.srt.entrypoints.engine import _set_envs_and_config
 from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.layers.moe import initialize_moe_config
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.managers.scheduler import Scheduler
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -509,6 +510,8 @@ def latency_test(
     bench_args,
     tp_rank,
 ):
+    initialize_moe_config(server_args)
+
     # Set CPU affinity
     if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
         set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, tp_rank)
sglang/bench_one_batch_server.py
CHANGED
@@ -47,6 +47,7 @@ class BenchArgs:
     profile: bool = False
     profile_steps: int = 3
     profile_by_stage: bool = False
+    dataset_path: str = ""
 
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
@@ -83,6 +84,12 @@ class BenchArgs:
             "--profile-steps", type=int, default=BenchArgs.profile_steps
         )
         parser.add_argument("--profile-by-stage", action="store_true")
+        parser.add_argument(
+            "--dataset-path",
+            type=str,
+            default=BenchArgs.dataset_path,
+            help="Path to the dataset.",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -138,6 +145,7 @@ def run_one_case(
     profile: bool = False,
     profile_steps: int = 3,
     profile_by_stage: bool = False,
+    dataset_path: str = "",
 ):
     requests.post(url + "/flush_cache")
     input_requests = sample_random_requests(
@@ -146,7 +154,7 @@
         num_prompts=batch_size,
         range_ratio=1.0,
         tokenizer=tokenizer,
-        dataset_path=
+        dataset_path=dataset_path,
         random_sample=True,
         return_text=False,
     )
@@ -345,6 +353,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
         run_name="",
         result_filename="",
         tokenizer=tokenizer,
+        dataset_path=bench_args.dataset_path,
     )
     print("=" * 8 + " Warmup End " + "=" * 8 + "\n")
 
sglang/bench_serving.py
CHANGED
@@ -75,6 +75,7 @@ class RequestFuncInput:
     lora_name: str
     image_data: Optional[List[str]]
     extra_request_body: Dict[str, Any]
+    timestamp: Optional[float] = None
 
 
 @dataclass
@@ -696,6 +697,24 @@ def get_dataset(args, tokenizer):
             apply_chat_template=args.apply_chat_template,
             random_sample=True,
         )
+    elif args.dataset_name == "mooncake":
+        # For mooncake, we don't generate the prompts here.
+        # We just load the raw trace data. The async generator will handle the rest.
+        if not args.dataset_path:
+            local_path = os.path.join("/tmp", args.mooncake_workload + "_trace.jsonl")
+        else:
+            local_path = args.dataset_path
+
+        if not os.path.exists(local_path):
+            download_and_cache_file(
+                MOONCAKE_DATASET_URL[args.mooncake_workload], local_path
+            )
+
+        with open(local_path, "r") as f:
+            all_requests_data = [json.loads(line) for line in f if line.strip()]
+
+        # Limit the number of requests based on --num-prompts
+        input_requests = all_requests_data[: args.num_prompts]
     else:
         raise ValueError(f"Unknown dataset: {args.dataset_name}")
     return input_requests
@@ -750,6 +769,12 @@ class BenchmarkMetrics:
 
 
 SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
+MOONCAKE_DATASET_URL = {
+    "mooncake": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/arxiv-trace/mooncake_trace.jsonl",
+    "conversation": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/conversation_trace.jsonl",
+    "synthetic": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/synthetic_trace.jsonl",
+    "toolagent": "https://raw.githubusercontent.com/kvcache-ai/Mooncake/main/FAST25-release/traces/toolagent_trace.jsonl",
+}
 
 
 def download_and_cache_file(url: str, filename: Optional[str] = None):
@@ -808,6 +833,80 @@ class DatasetRow:
     prompt_len: int
     output_len: int
     image_data: Optional[List[str]] = None
+    timestamp: Optional[float] = None
+
+
+async def get_mooncake_request_over_time(
+    input_requests: List[Dict],
+    tokenizer: PreTrainedTokenizerBase,
+    slowdown_factor: float,
+    num_rounds: int,
+) -> AsyncGenerator[DatasetRow, None]:
+    """
+    An async generator that yields requests based on the timestamps in the Mooncake trace file,
+    with support for multi-round sessions.
+    """
+    if not input_requests:
+        return
+
+    input_requests.sort(key=lambda r: r["timestamp"])
+
+    start_time = time.perf_counter()
+    trace_start_time_ms = input_requests[0]["timestamp"]
+
+    for record in input_requests:
+        # Calculate when this entire session should start
+        relative_arrival_time_s = (record["timestamp"] - trace_start_time_ms) / 1000.0
+        target_arrival_time_s = relative_arrival_time_s * slowdown_factor
+
+        current_elapsed_time_s = time.perf_counter() - start_time
+        sleep_duration_s = target_arrival_time_s - current_elapsed_time_s
+        if sleep_duration_s > 0:
+            await asyncio.sleep(sleep_duration_s)
+
+        # Once the session starts, generate all rounds for it as a burst
+        # This simulates a user engaging in a multi-turn conversation
+
+        # Base user query constructed from hash_ids
+        user_query_base = ""
+        hash_ids = record.get("hash_ids", [])
+        for hash_id in hash_ids:
+            user_query_base += f"{hash_id}" + " ".join(
+                ["hi"] * 128
+            )  # Shorter for multi-round
+        user_query_base += "Tell me a story based on this context."
+
+        output_len_per_round = record.get("output_length", 256)
+        chat_history = []
+
+        for i in range(num_rounds):
+            # Add user query for the current round
+            chat_history.append(
+                {"role": "user", "content": f"Round {i+1}: {user_query_base}"}
+            )
+
+            # Form the full prompt from history
+            try:
+                full_prompt_text = tokenizer.apply_chat_template(
+                    chat_history, tokenize=False, add_generation_prompt=True
+                )
+            except Exception:
+                full_prompt_text = "\n".join(
+                    [f"{msg['role']}: {msg['content']}" for msg in chat_history]
+                )
+
+            prompt_len = len(tokenizer.encode(full_prompt_text))
+
+            yield DatasetRow(
+                prompt=full_prompt_text,
+                prompt_len=prompt_len,
+                output_len=output_len_per_round,
+            )
+
+            # Add a placeholder assistant response for the next round's context
+            # We use a placeholder because we don't know the real response
+            placeholder_response = " ".join(["story"] * output_len_per_round)
+            chat_history.append({"role": "assistant", "content": placeholder_response})
 
 
 def sample_mmmu_requests(
@@ -896,17 +995,25 @@ def sample_mmmu_requests(
         prompt = f"Question: {question}\n\nAnswer: "
         if apply_chat_template:
             try:
+                is_phi4_multimodal = (
+                    "phi-4-multimodal" in tokenizer.name_or_path.lower()
+                )
+                if is_phi4_multimodal:
+                    # <|endoftext10|> is the image token used in the phi-4-multimodal model.
+                    content = prompt.replace("image 1", "<|endoftext10|>")
+                else:
+                    content = [
+                        {
+                            "type": "image_url",
+                            "image_url": {"url": image_data},
+                        },
+                        {"type": "text", "text": prompt},
+                    ]
                 prompt = tokenizer.apply_chat_template(
                     [
                         {
                             "role": "user",
-                            "content": [
-                                {
-                                    "type": "image_url",
-                                    "image_url": {"url": image_data},
-                                },
-                                {"type": "text", "text": prompt},
-                            ],
+                            "content": content,
                         }
                     ],
                     add_generation_prompt=True,
@@ -1359,19 +1466,41 @@
 async def get_request(
     input_requests: List[DatasetRow],
     request_rate: float,
+    use_trace_timestamps: bool = False,
+    slowdown_factor: float = 1.0,
 ) -> AsyncGenerator[DatasetRow, None]:
-    input_requests = iter(input_requests)
-    for request in input_requests:
-        yield request
+    if use_trace_timestamps:
+        print(
+            f"Using trace timestamps for request generation with slowdown factor {slowdown_factor}."
+        )
+        # Sort requests by timestamp for correct replay
+        input_requests.sort(key=lambda r: r.timestamp)
 
-        if request_rate == float("inf"):
-            # If the request rate is infinity, then we don't need to wait.
-            continue
+        start_time = time.perf_counter()
+        trace_start_time_ms = input_requests[0].timestamp if input_requests else 0
+
+        for request in input_requests:
+            trace_time_s = (request.timestamp - trace_start_time_ms) / 1000.0
+            target_arrival_time = start_time + (trace_time_s * slowdown_factor)
+
+            sleep_duration = target_arrival_time - time.perf_counter()
+            if sleep_duration > 0:
+                await asyncio.sleep(sleep_duration)
+
+            yield request
+    else:
+        input_requests_iter = iter(input_requests)
+        for request in input_requests_iter:
+            yield request
 
-        # Sample the request interval from the exponential distribution.
-        interval = np.random.exponential(1.0 / request_rate)
-        # The next request will be sent after the interval.
-        await asyncio.sleep(interval)
+            if request_rate == float("inf"):
+                # If the request rate is infinity, then we don't need to wait.
+                continue
+
+            # Sample the request interval from the exponential distribution.
+            interval = np.random.exponential(1.0 / request_rate)
+            # The next request will be sent after the interval.
+            await asyncio.sleep(interval)
 
 
 def calculate_metrics(
@@ -1397,7 +1526,7 @@ def calculate_metrics(
             tokenizer.encode(outputs[i].generated_text, add_special_tokens=False)
         )
         retokenized_output_lens.append(retokenized_output_len)
-        total_input +=
+        total_input += outputs[i].prompt_len
        if output_len > 1:
             tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
             itls += outputs[i].itl
@@ -1469,6 +1598,9 @@ async def benchmark(
     pd_separated: bool = False,
     flush_cache: bool = False,
     warmup_requests: int = 1,
+    use_trace_timestamps: bool = False,
+    mooncake_slowdown_factor=1.0,
+    mooncake_num_rounds=1,
 ):
     if backend in ASYNC_REQUEST_FUNCS:
         request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -1488,8 +1620,32 @@
     # Warmup
     print(f"Starting warmup with {warmup_requests} sequences...")
 
-    #
-    test_request = input_requests[0]
+    # Handle the data structure difference for the warmup request
+    if args.dataset_name == "mooncake":
+        # For mooncake, input_requests is a list of dicts.
+        # We need to build a temporary DatasetRow for the warmup phase.
+        warmup_record = input_requests[0]
+
+        # Build prompt from hash_ids, just like in the async generator
+        hash_ids = warmup_record.get("hash_ids", [])
+        prompt_text = ""
+        for hash_id in hash_ids:
+            prompt_text += f"{hash_id}" + " ".join(["hi"] * 512)
+        prompt_text += "Can you tell me a detailed story in 1000 words?"
+
+        output_len = warmup_record.get("output_length", 32)
+        prompt_len = len(tokenizer.encode(prompt_text))
+
+        # Create a temporary DatasetRow object for warmup
+        test_request = DatasetRow(
+            prompt=prompt_text,
+            prompt_len=prompt_len,
+            output_len=output_len,
+            image_data=None,  # Mooncake doesn't have image data
+        )
+    else:
+        # For all other datasets, input_requests is a list of DatasetRow objects
+        test_request = input_requests[0]
 
     if lora_names is not None and len(lora_names) != 0:
         lora_name = lora_names[0]
@@ -1543,12 +1699,26 @@
         if profile_output.success:
             print("Profiler started")
 
-    pbar = None if disable_tqdm else tqdm(total=len(input_requests))
-
     # Run all requests
     benchmark_start_time = time.perf_counter()
     tasks: List[asyncio.Task] = []
-    async for request in get_request(input_requests, request_rate):
+    pbar_total = len(input_requests)
+    if (
+        backend == "sglang" and args.dataset_name == "mooncake"
+    ):  # Assuming mooncake is mainly for sglang or similar backends
+        print("Using time-based Mooncake request scheduler, ignoring --request-rate.")
+        request_generator = get_mooncake_request_over_time(
+            input_requests, tokenizer, mooncake_slowdown_factor, mooncake_num_rounds
+        )
+        print(
+            f"Starting Mooncake trace replay. Sessions: {len(input_requests)}, Rounds per session: {mooncake_num_rounds}. Slowdown factor: {mooncake_slowdown_factor}"
+        )
+        pbar_total *= args.mooncake_num_rounds
+    else:
+        request_generator = get_request(input_requests, request_rate)
+
+    pbar = None if disable_tqdm else tqdm(total=pbar_total)
+    async for request in request_generator:
         if lora_names is not None and len(lora_names) != 0:
             idx = random.randint(0, len(lora_names) - 1)
             lora_name = lora_names[idx]
@@ -1564,6 +1734,7 @@ async def benchmark(
             lora_name=lora_name,
             image_data=request.image_data,
             extra_request_body=extra_request_body,
+            timestamp=request.timestamp,
         )
 
         tasks.append(
@@ -1609,7 +1780,11 @@
 
     print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
     print("{:<40} {:<10}".format("Backend:", backend))
-    print(
+    print(
+        "{:<40} {:<10}".format(
+            "Traffic request rate:", "trace" if use_trace_timestamps else request_rate
+        )
+    )
     print(
         "{:<40} {:<10}".format(
             "Max request concurrency:",
@@ -1678,7 +1853,7 @@
         # Arguments
         "backend": args.backend,
         "dataset_name": args.dataset_name,
-        "request_rate": request_rate,
+        "request_rate": "trace" if use_trace_timestamps else request_rate,
         "max_concurrency": max_concurrency,
         "sharegpt_output_len": args.sharegpt_output_len,
        "random_input_len": args.random_input_len,
@@ -1731,7 +1906,9 @@
     elif args.dataset_name.startswith("random"):
         output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
     else:
-        output_file_name =
+        output_file_name = (
+            f"{args.backend}_{now}_{args.num_prompts}_{args.dataset_name}.jsonl"
+        )
 
     result_details = {
         "input_lens": [output.prompt_len for output in outputs],
@@ -1786,6 +1963,17 @@ def run_benchmark(args_: argparse.Namespace):
     if not hasattr(args, "tokenize_prompt"):
         args.tokenize_prompt = False
 
+    if not hasattr(args, "use_trace_timestamps"):
+        args.use_trace_timestamps = False
+    if not hasattr(args, "mooncake_slowdown_factor"):
+        args.mooncake_slowdown_factor = 1.0
+
+    if not hasattr(args, "mooncake_slowdown_factor"):
+        args.mooncake_slowdown_factor = 1.0
+
+    if not hasattr(args, "mooncake_num_rounds"):
+        args.mooncake_num_rounds = 1
+
     print(f"benchmark_args={args}")
 
     # Set global environments
@@ -1919,6 +2107,9 @@ def run_benchmark(args_: argparse.Namespace):
             pd_separated=args.pd_separated,
             flush_cache=args.flush_cache,
             warmup_requests=args.warmup_requests,
+            use_trace_timestamps=args.use_trace_timestamps,
+            mooncake_slowdown_factor=args.mooncake_slowdown_factor,
+            mooncake_num_rounds=args.mooncake_num_rounds,
         )
     )
 
@@ -1975,6 +2166,7 @@ if __name__ == "__main__":
             "generated-shared-prefix",
             "mmmu",
             "random-image",
+            "mooncake",
         ],
         help="Name of the dataset to benchmark on.",
     )
@@ -2051,6 +2243,11 @@
         help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
        "Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.",
     )
+    parser.add_argument(
+        "--use-trace-timestamps",
+        action="store_true",
+        help="Use timestamps from the trace file for request scheduling. Only valid for 'mooncake' dataset.",
+    )
     parser.add_argument(
         "--max-concurrency",
         type=int,
@@ -2174,5 +2371,33 @@
         default=256,
         help="Target length in tokens for outputs in generated-shared-prefix dataset",
     )
+    mooncake_group = parser.add_argument_group("mooncake dataset arguments")
+    mooncake_group.add_argument(
+        "--mooncake-slowdown-factor",
+        type=float,
+        default=1.0,
+        help="Slowdown factor for replaying the mooncake trace. "
+        "A value of 2.0 means the replay is twice as slow. "
+        "NOTE: --request-rate is IGNORED in mooncake mode.",
+    )
+    mooncake_group.add_argument(
+        "--mooncake-num-rounds",
+        type=int,
+        default=1,
+        help="Number of conversation rounds for each session in the mooncake dataset. "
+        "A value > 1 will enable true multi-turn session benchmarking.",
+    )
+    mooncake_group.add_argument(
+        "--mooncake-workload",
+        type=str,
+        default="conversation",
+        choices=[
+            "mooncake",
+            "conversation",
+            "synthetic",
+            "toolagent",
+        ],
+        help="Underlying workload for the mooncake dataset.",
+    )
     args = parser.parse_args()
     run_benchmark(args)
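Note on the new trace-replay mode: with the mooncake dataset (or --use-trace-timestamps), requests are released according to the millisecond timestamps recorded in the trace instead of a Poisson process. The following is a minimal standalone sketch of that pacing idea, assuming each record carries a millisecond "timestamp" field as in the Mooncake traces; it is illustrative only, not the shipped code.

import asyncio
import time

async def replay(records, slowdown_factor=1.0):
    # Hypothetical helper: release each record once its (scaled) offset
    # from the first trace timestamp has elapsed on the wall clock.
    if not records:
        return
    records = sorted(records, key=lambda r: r["timestamp"])
    start = time.perf_counter()
    t0_ms = records[0]["timestamp"]
    for rec in records:
        # Trace timestamps are in milliseconds; slowdown_factor stretches the replay.
        target_s = (rec["timestamp"] - t0_ms) / 1000.0 * slowdown_factor
        delay = target_s - (time.perf_counter() - start)
        if delay > 0:
            await asyncio.sleep(delay)
        yield rec

Because arrival times come from the trace, --request-rate is ignored in this mode; --mooncake-slowdown-factor stretches or compresses the replay instead.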
sglang/lang/interpreter.py
CHANGED
@@ -740,7 +740,7 @@ class StreamExecutor:
         # Execute the stored lazy generation calls
         self.backend.role_end_generate(self)
 
-        from sglang.srt.reasoning_parser import ReasoningParser
+        from sglang.srt.parser.reasoning_parser import ReasoningParser
 
         reasoning_parser = ReasoningParser(expr.model_type)
         other = expr.expr
sglang/srt/configs/__init__.py
CHANGED
@@ -5,6 +5,8 @@ from sglang.srt.configs.exaone import ExaoneConfig
 from sglang.srt.configs.janus_pro import MultiModalityConfig
 from sglang.srt.configs.kimi_vl import KimiVLConfig
 from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
+from sglang.srt.configs.longcat_flash import LongcatFlashConfig
+from sglang.srt.configs.qwen3_next import Qwen3NextConfig
 from sglang.srt.configs.step3_vl import (
     Step3TextConfig,
     Step3VisionEncoderConfig,
@@ -16,10 +18,12 @@ __all__ = [
     "ChatGLMConfig",
     "DbrxConfig",
     "DeepseekVL2Config",
+    "LongcatFlashConfig",
     "MultiModalityConfig",
     "KimiVLConfig",
     "MoonViTConfig",
     "Step3VLConfig",
     "Step3TextConfig",
     "Step3VisionEncoderConfig",
+    "Qwen3NextConfig",
 ]
sglang/srt/configs/internvl.py
CHANGED
@@ -6,11 +6,13 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import sentencepiece as spm
 from transformers import (
     TOKENIZER_MAPPING,
+    GptOssConfig,
     LlamaConfig,
     PretrainedConfig,
     PreTrainedTokenizer,
     Qwen2Config,
     Qwen3Config,
+    Qwen3MoeConfig,
 )
 
 from sglang.utils import logger
@@ -316,7 +318,11 @@ class InternVLChatConfig(PretrainedConfig):
         elif llm_config.get("architectures")[0] == "Qwen2ForCausalLM":
             self.llm_config = Qwen2Config(**llm_config)
         elif llm_config.get("architectures")[0] == "Qwen3MoeForCausalLM":
+            self.llm_config = Qwen3MoeConfig(**llm_config)
+        elif llm_config.get("architectures")[0] == "Qwen3ForCausalLM":
             self.llm_config = Qwen3Config(**llm_config)
+        elif llm_config.get("architectures")[0] == "GptOssForCausalLM":
+            self.llm_config = GptOssConfig(**llm_config)
         else:
             raise ValueError(
                 "Unsupported architecture: {}".format(
sglang/srt/configs/longcat_flash.py
ADDED
@@ -0,0 +1,104 @@
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+FLASH_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
+
+
+class LongcatFlashConfig(PretrainedConfig):
+    model_type = "longcat_flash"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=131072,
+        hidden_size=6144,
+        intermediate_size=None,
+        ffn_hidden_size=12288,
+        expert_ffn_hidden_size=2048,
+        num_layers=28,
+        num_hidden_layers=None,
+        num_attention_heads=64,
+        ep_size=1,
+        kv_lora_rank=512,
+        q_lora_rank=1536,
+        qk_rope_head_dim=128,
+        qk_nope_head_dim=128,
+        v_head_dim=128,
+        n_routed_experts=512,
+        moe_topk=12,
+        norm_topk_prob=False,
+        max_position_embeddings=131072,
+        rms_norm_eps=1e-05,
+        use_cache=True,
+        pad_token_id=None,
+        bos_token_id=1,
+        eos_token_id=2,
+        pretraining_tp=1,
+        tie_word_embeddings=False,
+        rope_theta=10000000.0,
+        rope_scaling=None,
+        attention_bias=False,
+        attention_dropout=0.0,
+        mla_scale_q_lora=True,
+        mla_scale_kv_lora=True,
+        torch_dtype="bfloat16",
+        params_dtype="bfloat16",
+        rounter_params_dtype="float32",
+        router_bias=False,
+        topk_method=None,
+        routed_scaling_factor=6.0,
+        zero_expert_num=256,
+        zero_expert_type="identity",
+        nextn_use_scmoe=False,
+        num_nextn_predict_layers=1,
+        **kwargs,
+    ):
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            torch_dtype=torch_dtype,
+            params_dtype=params_dtype,
+            rounter_params_dtype=rounter_params_dtype,
+            topk_method=topk_method,
+            router_bias=router_bias,
+            nextn_use_scmoe=nextn_use_scmoe,
+            num_nextn_predict_layers=num_nextn_predict_layers,
+            **kwargs,
+        )
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = (
+            num_hidden_layers if num_hidden_layers is not None else num_layers
+        )
+        self.intermediate_size = (
+            intermediate_size if intermediate_size is not None else ffn_hidden_size
+        )
+        self.moe_intermediate_size = expert_ffn_hidden_size
+        self.num_attention_heads = num_attention_heads
+        self.ep_size = ep_size
+        self.kv_lora_rank = kv_lora_rank
+        self.q_lora_rank = q_lora_rank
+        self.qk_rope_head_dim = qk_rope_head_dim
+        self.v_head_dim = v_head_dim
+        self.qk_nope_head_dim = qk_nope_head_dim
+        self.n_routed_experts = n_routed_experts
+        self.moe_topk = moe_topk
+        self.norm_topk_prob = norm_topk_prob
+        self.rms_norm_eps = rms_norm_eps
+        self.pretraining_tp = pretraining_tp
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+        self.mla_scale_q_lora = mla_scale_q_lora
+        self.mla_scale_kv_lora = mla_scale_kv_lora
+        self.zero_expert_num = zero_expert_num
+        self.zero_expert_type = zero_expert_type
+        self.routed_scaling_factor = routed_scaling_factor
+        self.hidden_act = "silu"