sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +149 -34
- sglang/bench_serving.py +18 -3
- sglang/compile_deep_gemm.py +13 -7
- sglang/srt/batch_invariant_ops/__init__.py +2 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
- sglang/srt/checkpoint_engine/__init__.py +9 -0
- sglang/srt/checkpoint_engine/update.py +317 -0
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/deepseek_ocr.py +542 -10
- sglang/srt/configs/deepseekvl2.py +95 -194
- sglang/srt/configs/kimi_linear.py +160 -0
- sglang/srt/configs/mamba_utils.py +66 -0
- sglang/srt/configs/model_config.py +25 -2
- sglang/srt/constants.py +7 -0
- sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
- sglang/srt/disaggregation/decode.py +34 -6
- sglang/srt/disaggregation/nixl/conn.py +2 -2
- sglang/srt/disaggregation/prefill.py +25 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
- sglang/srt/distributed/parallel_state.py +9 -5
- sglang/srt/entrypoints/engine.py +13 -5
- sglang/srt/entrypoints/http_server.py +22 -3
- sglang/srt/entrypoints/openai/protocol.py +7 -1
- sglang/srt/entrypoints/openai/serving_chat.py +42 -0
- sglang/srt/entrypoints/openai/serving_completions.py +10 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/environ.py +7 -0
- sglang/srt/eplb/expert_distribution.py +34 -1
- sglang/srt/eplb/expert_location.py +106 -36
- sglang/srt/grpc/compile_proto.py +3 -0
- sglang/srt/layers/attention/ascend_backend.py +233 -5
- sglang/srt/layers/attention/attention_registry.py +3 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
- sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
- sglang/srt/layers/attention/fla/kda.py +1359 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
- sglang/srt/layers/attention/flashattention_backend.py +7 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
- sglang/srt/layers/attention/mamba/mamba.py +20 -11
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
- sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
- sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
- sglang/srt/layers/attention/nsa/transform_index.py +1 -1
- sglang/srt/layers/attention/nsa_backend.py +157 -23
- sglang/srt/layers/attention/triton_backend.py +4 -1
- sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
- sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
- sglang/srt/layers/communicator.py +23 -1
- sglang/srt/layers/layernorm.py +16 -2
- sglang/srt/layers/logits_processor.py +4 -20
- sglang/srt/layers/moe/ep_moe/layer.py +0 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
- sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
- sglang/srt/layers/moe/topk.py +31 -6
- sglang/srt/layers/pooler.py +21 -2
- sglang/srt/layers/quantization/__init__.py +9 -78
- sglang/srt/layers/quantization/auto_round.py +394 -0
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/modelopt_quant.py +168 -11
- sglang/srt/layers/rotary_embedding.py +117 -45
- sglang/srt/lora/lora_registry.py +9 -0
- sglang/srt/managers/async_mm_data_processor.py +122 -0
- sglang/srt/managers/data_parallel_controller.py +30 -3
- sglang/srt/managers/detokenizer_manager.py +3 -0
- sglang/srt/managers/io_struct.py +26 -4
- sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
- sglang/srt/managers/schedule_batch.py +74 -15
- sglang/srt/managers/scheduler.py +164 -129
- sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
- sglang/srt/managers/scheduler_pp_mixin.py +7 -2
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
- sglang/srt/managers/session_controller.py +6 -5
- sglang/srt/managers/tokenizer_manager.py +154 -59
- sglang/srt/managers/tp_worker.py +24 -1
- sglang/srt/mem_cache/base_prefix_cache.py +23 -4
- sglang/srt/mem_cache/common.py +1 -0
- sglang/srt/mem_cache/memory_pool.py +171 -57
- sglang/srt/mem_cache/memory_pool_host.py +12 -5
- sglang/srt/mem_cache/radix_cache.py +4 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
- sglang/srt/metrics/collector.py +46 -3
- sglang/srt/model_executor/cuda_graph_runner.py +15 -3
- sglang/srt/model_executor/forward_batch_info.py +11 -11
- sglang/srt/model_executor/model_runner.py +76 -21
- sglang/srt/model_executor/npu_graph_runner.py +7 -3
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/bailing_moe.py +9 -2
- sglang/srt/models/deepseek_nextn.py +11 -2
- sglang/srt/models/deepseek_v2.py +149 -34
- sglang/srt/models/glm4.py +391 -77
- sglang/srt/models/glm4v.py +196 -55
- sglang/srt/models/glm4v_moe.py +0 -1
- sglang/srt/models/gpt_oss.py +1 -10
- sglang/srt/models/kimi_linear.py +678 -0
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/llama_eagle3.py +11 -1
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minimax_m2.py +1 -1
- sglang/srt/models/qwen2.py +1 -1
- sglang/srt/models/qwen2_moe.py +30 -15
- sglang/srt/models/qwen3.py +1 -1
- sglang/srt/models/qwen3_moe.py +16 -8
- sglang/srt/models/qwen3_next.py +7 -0
- sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
- sglang/srt/multiplex/multiplexing_mixin.py +209 -0
- sglang/srt/multiplex/pdmux_context.py +164 -0
- sglang/srt/parser/conversation.py +7 -1
- sglang/srt/sampling/custom_logit_processor.py +67 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
- sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
- sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
- sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
- sglang/srt/server_args.py +103 -22
- sglang/srt/single_batch_overlap.py +4 -1
- sglang/srt/speculative/draft_utils.py +16 -0
- sglang/srt/speculative/eagle_info.py +42 -36
- sglang/srt/speculative/eagle_info_v2.py +68 -25
- sglang/srt/speculative/eagle_utils.py +261 -16
- sglang/srt/speculative/eagle_worker.py +11 -3
- sglang/srt/speculative/eagle_worker_v2.py +15 -9
- sglang/srt/speculative/spec_info.py +305 -31
- sglang/srt/speculative/spec_utils.py +44 -8
- sglang/srt/tracing/trace.py +121 -12
- sglang/srt/utils/common.py +55 -32
- sglang/srt/utils/hf_transformers_utils.py +38 -16
- sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
- sglang/test/kits/radix_cache_server_kit.py +50 -0
- sglang/test/runners.py +31 -7
- sglang/test/simple_eval_common.py +5 -3
- sglang/test/simple_eval_humaneval.py +1 -0
- sglang/test/simple_eval_math.py +1 -0
- sglang/test/simple_eval_mmlu.py +1 -0
- sglang/test/simple_eval_mmmu_vlm.py +1 -0
- sglang/test/test_utils.py +7 -1
- sglang/version.py +1 -1
- {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
- {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
- /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
- {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py
CHANGED
@@ -11,6 +11,11 @@ python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruc
 python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run
 ## run with profiling:
 python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --profile
+## run with profiling to custom directory:
+export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log
+python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 --input-len 256 --profile
+## run with CUDA profiler (nsys):
+nsys profile --force-overwrite=true -o bench_one_batch python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 --input-len 256 --profile --profiler_activities CUDA_PROFILER
 # Usage (correctness test):
 python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct
 
@@ -93,6 +98,68 @@ profile_activities = [torch.profiler.ProfilerActivity.CPU] + [
 ]
 
 
+def start_profile(profiler_activities, profile_record_shapes=False, rank_print=print):
+    """
+    Abstracted function to start profiling based on profiler_activities.
+    Returns profiler object (or None).
+    """
+    if "CUDA_PROFILER" in profiler_activities:
+        try:
+            torch.cuda.cudart().cudaProfilerStart()
+            rank_print("CUDA Profiler started (nsys will begin capturing)")
+        except Exception as e:
+            rank_print(f"Failed to start CUDA profiler: {e}")
+        return None
+    else:
+        activities = []
+        if "CPU" in profiler_activities:
+            activities.append(torch.profiler.ProfilerActivity.CPU)
+        if "GPU" in profiler_activities:
+            activities.append(torch.profiler.ProfilerActivity.CUDA)
+        if activities:
+            profiler = torch.profiler.profile(
+                activities=activities,
+                with_stack=True,
+                record_shapes=profile_record_shapes,
+            )
+            profiler.start()
+            return profiler
+        return None
+
+
+def stop_profile(
+    profiler,
+    profiler_activities,
+    rank_print=print,
+    save_trace=False,
+    trace_filename=None,
+    stage=None,
+):
+    """
+    Abstracted function to stop profiling based on profiler_activities.
+    Optionally saves trace results and prints completion messages.
+    """
+    if "CUDA_PROFILER" in profiler_activities:
+        try:
+            torch.cuda.cudart().cudaProfilerStop()
+            rank_print("CUDA Profiler stopped (nsys should dump traces)")
+        except Exception as e:
+            rank_print(f"Failed to stop CUDA profiler: {e}")
+    elif profiler is not None:
+        profiler.stop()
+
+    if save_trace:
+        if profiler is not None:
+            if trace_filename:
+                _save_profile_trace_results(profiler, trace_filename)
+                stage_desc = f"for {stage}" if stage else ""
+                rank_print(
+                    f"torch profiler chrome trace {stage_desc} saved to {trace_filename}"
+                )
+        if "CUDA_PROFILER" in profiler_activities:
+            rank_print(f"CUDA profiler trace for {stage} completed")
+
+
 @dataclasses.dataclass
 class BenchArgs:
     run_name: str = "default"
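For orientation, a minimal usage sketch of the two helpers added above, pairing them around a timed region; run_prefill and the trace path are placeholders, not names from this diff:

profiler = start_profile(("CPU", "GPU"), profile_record_shapes=True)
run_prefill()  # placeholder for the region being profiled
stop_profile(
    profiler,
    ("CPU", "GPU"),
    save_trace=True,
    trace_filename="/tmp/profile_prefill.trace.json.gz",  # placeholder path
    stage="prefill",
)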
@@ -107,6 +174,8 @@ class BenchArgs:
     log_decode_step: int = 0
     profile: bool = False
     profile_record_shapes: bool = False
+    profiler_activities: Tuple[str] = ("CPU", "GPU")
+    profile_stage: str = "all"
     profile_filename_prefix: str = "profile"
 
     @staticmethod
@@ -135,14 +204,27 @@ class BenchArgs:
             default=BenchArgs.log_decode_step,
             help="Log decode latency by step, default is set to zero to disable.",
         )
-        parser.add_argument(
-            "--profile", action="store_true", help="Use Torch Profiler."
-        )
+        parser.add_argument("--profile", action="store_true", help="Enable profiling.")
         parser.add_argument(
             "--profile-record-shapes",
             action="store_true",
             help="Record tensor shapes in profiling results.",
         )
+        parser.add_argument(
+            "--profiler_activities",
+            type=str,
+            nargs="+",
+            default=["CPU", "GPU"],
+            choices=["CPU", "GPU", "CUDA_PROFILER"],
+            help="Profiler activities: CPU, GPU, CUDA_PROFILER. If CPU/GPU, use torch profiler. If CUDA_PROFILER, use CUDA profiler.",
+        )
+        parser.add_argument(
+            "--profile-stage",
+            type=str,
+            default=BenchArgs.profile_stage,
+            choices=["all", "prefill", "decode"],
+            help="Which stage to profile: all, prefill, or decode only.",
+        )
         parser.add_argument(
             "--profile-filename-prefix",
             type=str,
@@ -337,6 +419,18 @@ def _read_prompts_from_file(prompt_file, rank_print):
         return pf.readlines()
 
 
+def _get_torch_profiler_output_dir():
+    return os.environ.get("SGLANG_TORCH_PROFILER_DIR", "/tmp")
+
+
+def _create_torch_profiler_filename(
+    profile_filename_prefix, batch_size, input_len, output_len, stage
+):
+    output_dir = _get_torch_profiler_output_dir()
+    filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_{stage}.trace.json.gz"
+    return os.path.join(output_dir, filename)
+
+
 def _save_profile_trace_results(profiler, filename):
     parent_dir = os.path.dirname(os.path.abspath(filename))
     os.makedirs(parent_dir, exist_ok=True)
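As an illustration of the two helpers above (argument values are made up), the resulting trace path looks like this:

# Illustrative only: _create_torch_profiler_filename is the helper defined above;
# the directory falls back to /tmp unless SGLANG_TORCH_PROFILER_DIR is set.
path = _create_torch_profiler_filename("profile", 1, 256, 32, "prefill")
# path == "/tmp/profile_batch1_input256_output32_prefill.trace.json.gz"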
@@ -413,7 +507,10 @@ def latency_test_run_once(
     log_decode_step,
     profile,
     profile_record_shapes,
+    profiler_activities,
     profile_filename_prefix,
+    profile_stage,
+    tp_rank,
 ):
     max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
     if batch_size > max_batch_size:
@@ -422,7 +519,6 @@ def latency_test_run_once(
         )
         return
 
-    # Clear the pools.
     model_runner.req_to_token_pool.clear()
     model_runner.token_to_kv_pool_allocator.clear()
 
@@ -436,20 +532,33 @@ def latency_test_run_once(
     tot_latency = 0
 
     profiler = None
-
-
-
-
-
+    enable_profile_prefill = profile and profile_stage in ["all", "prefill"]
+    if enable_profile_prefill:
+        profiler = start_profile(
+            profiler_activities,
+            profile_record_shapes=profile_record_shapes,
+            rank_print=rank_print,
         )
-        profiler.start()
 
-    # Prefill
     synchronize(device)
     tic = time.perf_counter()
     next_token_ids, _, batch = extend(reqs, model_runner)
     synchronize(device)
     prefill_latency = time.perf_counter() - tic
+
+    if enable_profile_prefill:
+        trace_filename = _create_torch_profiler_filename(
+            profile_filename_prefix, batch_size, input_len, output_len, "prefill"
+        )
+        stop_profile(
+            profiler,
+            profiler_activities,
+            rank_print=rank_print,
+            save_trace=True,
+            trace_filename=trace_filename,
+            stage="prefill",
+        )
+
     tot_latency += prefill_latency
     throughput = input_len * batch_size / prefill_latency
     rank_print(
@@ -458,29 +567,37 @@ def latency_test_run_once(
     measurement_results["prefill_latency"] = prefill_latency
     measurement_results["prefill_throughput"] = throughput
 
-    if profile:
-        profiler.stop()
-        trace_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
-        _save_profile_trace_results(profiler, trace_filename)
-        rank_print(f"torch profiler chrome trace for prefill saved to {trace_filename}")
-
-    # Decode
     decode_latencies = []
+    profile_step_of_interest = output_len // 2
+    enable_profile_decode = profile and profile_stage in ["all", "decode"]
     for i in range(output_len - 1):
         synchronize(device)
-
-
-        profiler =
-
-
-
+        profiler = None
+        if enable_profile_decode and i == profile_step_of_interest:
+            profiler = start_profile(
+                profiler_activities,
+                profile_record_shapes=profile_record_shapes,
+                rank_print=rank_print,
             )
-            profiler.start()
 
         tic = time.perf_counter()
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
         synchronize(device)
         latency = time.perf_counter() - tic
+
+        if enable_profile_decode and i == profile_step_of_interest:
+            trace_filename = _create_torch_profiler_filename(
+                profile_filename_prefix, batch_size, input_len, output_len, "decode"
+            )
+            stop_profile(
+                profiler,
+                profiler_activities,
+                rank_print=rank_print,
+                save_trace=True,
+                trace_filename=trace_filename,
+                stage="decode",
+            )
+
         tot_latency += latency
         throughput = batch_size / latency
         decode_latencies.append(latency)
@@ -489,14 +606,6 @@ def latency_test_run_once(
                 f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
             )
 
-        if profile and i == output_len / 2:
-            profiler.stop()
-            trace_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
-            _save_profile_trace_results(profiler, trace_filename)
-            rank_print(
-                f"torch profiler chrome trace for decoding 1 token saved to {trace_filename}"
-            )
-
     # Record decode timing from 2nd output
     if output_len > 1:
         med_decode_latency = np.median(decode_latencies)
@@ -557,7 +666,10 @@ def latency_test(
         log_decode_step=0,
         profile=False,
         profile_record_shapes=False,
-
+        profiler_activities=("CPU", "GPU"),
+        profile_filename_prefix="",
+        profile_stage="all",
+        tp_rank=tp_rank,
     )
 
     rank_print("Benchmark ...")
@@ -604,7 +716,10 @@ def latency_test(
             bench_args.log_decode_step,
             bench_args.profile if tp_rank == 0 else None,
             bench_args.profile_record_shapes if tp_rank == 0 else None,
+            bench_args.profiler_activities,
             bench_args.profile_filename_prefix,
+            bench_args.profile_stage,
+            tp_rank,
         )
         if ret is not None:
             result_list.append(ret)
sglang/bench_serving.py
CHANGED
@@ -1014,7 +1014,7 @@ async def get_mooncake_request_over_time(
 def sample_mmmu_requests(
     num_requests: int,
     processor: AutoProcessor | AutoTokenizer,
-    backend: str,
+    backend: str = "sglang",
     fixed_output_len: Optional[int] = None,
     random_sample: bool = True,
 ) -> List[DatasetRow]:
@@ -1369,7 +1369,10 @@ def create_mm_data_row(
         )["input_ids"].numel()
     except Exception:
         # Fallback: just tokenize the text prompt directly
-
+        tokenizer_to_use = (
+            processor.tokenizer if hasattr(processor, "tokenizer") else processor
+        )
+        text_prompt_len = len(tokenizer_to_use.encode(text_prompt))
 
     # Vision tokens = total tokens - text tokens
     vision_prompt_len = prompt_len - text_prompt_len
@@ -2033,6 +2036,7 @@ async def benchmark(
 ):
     result = {
         # Arguments
+        "tag": getattr(args, "tag", None),
         "backend": args.backend,
         "dataset_name": args.dataset_name,
         "request_rate": "trace" if use_trace_timestamps else request_rate,
@@ -2158,6 +2162,9 @@ def run_benchmark(args_: argparse.Namespace):
     if not hasattr(args, "mooncake_num_rounds"):
         args.mooncake_num_rounds = 1
 
+    if not hasattr(args, "served_model_name"):
+        args.served_model_name = None
+
     print(f"benchmark_args={args}")
 
     # Set global environments
@@ -2271,7 +2278,7 @@ def run_benchmark(args_: argparse.Namespace):
 
     # Read dataset
     backend = args.backend
-    model_id = args.model
+    model_id = args.served_model_name or args.model
     tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
     tokenizer = get_tokenizer(tokenizer_id)
     input_requests = get_dataset(args, tokenizer, model_id)
@@ -2370,6 +2377,11 @@ if __name__ == "__main__":
         type=str,
         help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
     )
+    parser.add_argument(
+        "--served-model-name",
+        type=str,
+        help="The name of the model as served by the serving service. If not set, this defaults to the value of --model.",
+    )
     parser.add_argument(
         "--tokenizer",
         type=str,
@@ -2627,5 +2639,8 @@ if __name__ == "__main__":
         ],
         help="Underlying workload for the mooncake dataset.",
     )
+    parser.add_argument(
+        "--tag", type=str, default=None, help="The tag to be dumped to output."
+    )
     args = parser.parse_args()
     run_benchmark(args)
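A small sketch of how the two new bench_serving options behave; the namespace below is hand-built with made-up values for illustration, whereas real runs build args via parser.parse_args():

import argparse

args = argparse.Namespace(
    model="meta-llama/Meta-Llama-3-8B-Instruct",  # made-up example values
    served_model_name="llama3-8b",
    tag="nightly",
)
model_id = args.served_model_name or args.model  # "llama3-8b" is used as the reported model id
tag = getattr(args, "tag", None)  # "nightly" is dumped into the benchmark result dict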
sglang/compile_deep_gemm.py
CHANGED
@@ -104,15 +104,21 @@ def launch_server_process_and_send_one_request(
     if response.status_code == 200:
         # Rank-0 node send a request to sync with other node and then return.
         if server_args.node_rank == 0:
+            payload = {
+                "input_ids": [0, 1, 2, 3],
+                "sampling_params": {
+                    "max_new_tokens": 8,
+                    "temperature": 0,
+                },
+            }
+            # In PD mode, include fake bootstrap fields so workers don't assert
+            if server_args.disaggregation_mode != "null":
+                payload["bootstrap_host"] = FAKE_BOOTSTRAP_HOST
+                payload["bootstrap_room"] = 0
+
             response = requests.post(
                 f"{base_url}/generate",
-                json=
-                    "input_ids": [0, 1, 2, 3],
-                    "sampling_params": {
-                        "max_new_tokens": 8,
-                        "temperature": 0,
-                    },
-                },
+                json=payload,
                 timeout=600,
             )
             if response.status_code != 200:
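For reference, a sketch of the warm-up request body in PD (disaggregation) mode, assuming FAKE_BOOTSTRAP_HOST is the constant the new code references:

# Sketch only: equivalent payload after the PD-mode branch above has run.
payload = {
    "input_ids": [0, 1, 2, 3],
    "sampling_params": {"max_new_tokens": 8, "temperature": 0},
    "bootstrap_host": FAKE_BOOTSTRAP_HOST,  # constant referenced by this module
    "bootstrap_room": 0,
}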
sglang/srt/batch_invariant_ops/__init__.py
CHANGED
@@ -9,6 +9,7 @@ from .batch_invariant_ops import (
     log_softmax,
     matmul_persistent,
     mean_dim,
+    rms_norm_batch_invariant,
     set_batch_invariant_mode,
 )
 
@@ -24,4 +25,5 @@ __all__ = [
     "mean_dim",
     "get_batch_invariant_attention_block_size",
     "AttentionBlockSize",
+    "rms_norm_batch_invariant",
 ]
sglang/srt/batch_invariant_ops/batch_invariant_ops.py
CHANGED
@@ -579,6 +579,126 @@ def bmm_batch_invariant(a, b, *, out=None):
     )
 
 
+@triton.jit
+def _rms_norm_kernel(
+    input_ptr,
+    weight_ptr,
+    output_ptr,
+    input_row_stride,
+    output_row_stride,
+    n_cols,
+    eps,
+    BLOCK_SIZE: tl.constexpr,
+):
+    """
+    Compute RMS normalization along the last dimension of a 2D tensor.
+    RMS Norm: y = x / sqrt(mean(x^2) + eps) * weight
+    Each block handles one row of the input tensor.
+    """
+    row_idx = tl.program_id(0).to(tl.int64)
+    row_start_ptr = input_ptr + row_idx * input_row_stride
+    output_row_start_ptr = output_ptr + row_idx * output_row_stride
+
+    # Step 1: Compute sum of squares in float32 to avoid overflow
+    sum_sq = tl.zeros([1], dtype=tl.float32)
+    for col_offset in range(0, n_cols, BLOCK_SIZE):
+        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
+        mask = col_idx < n_cols
+
+        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
+        # Convert to float32 for accumulation to prevent overflow
+        vals_f32 = vals.to(tl.float32)
+        sq_vals = vals_f32 * vals_f32
+        sum_sq += tl.sum(tl.where(mask, sq_vals, 0.0))
+
+    # Step 2: Compute RMS (root mean square) in float32
+    mean_sq = sum_sq / n_cols
+    rms = tl.sqrt(mean_sq + eps)
+    inv_rms = 1.0 / rms
+
+    # Step 3: Normalize and apply weight
+    for col_offset in range(0, n_cols, BLOCK_SIZE):
+        col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
+        mask = col_idx < n_cols
+        vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
+        weight = tl.load(weight_ptr + col_idx, mask=mask, other=1.0)
+        # Compute in float32 then convert back to input dtype
+        vals_f32 = vals.to(tl.float32)
+        weight_f32 = weight.to(tl.float32)
+        output_f32 = vals_f32 * inv_rms * weight_f32
+        output = output_f32.to(vals.dtype)
+        tl.store(output_row_start_ptr + col_idx, output, mask=mask)
+
+
+def rms_norm(
+    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+) -> torch.Tensor:
+    """
+    Compute RMS normalization using Triton kernel.
+
+    RMS Norm normalizes the input by the root mean square and scales by weight:
+        output = input / sqrt(mean(input^2) + eps) * weight
+
+    Args:
+        input: Input tensor of shape (..., hidden_size)
+        weight: Weight tensor of shape (hidden_size,)
+        eps: Small constant for numerical stability
+
+    Returns:
+        Tensor with RMS normalization applied along the last dimension
+    """
+    assert weight.dim() == 1, "Weight must be 1-dimensional"
+    assert input.shape[-1] == weight.shape[0], (
+        f"Input last dimension ({input.shape[-1]}) must match "
+        f"weight dimension ({weight.shape[0]})"
+    )
+
+    # Flatten all dimensions except the last one
+    original_shape = input.shape
+    input_2d = input.reshape(-1, input.shape[-1])
+    input_2d = input_2d.contiguous()
+    weight = weight.contiguous()
+
+    n_rows, n_cols = input_2d.shape
+
+    output = torch.empty_like(input_2d)
+    BLOCK_SIZE = 1024
+    grid = (n_rows,)
+    _rms_norm_kernel[grid](
+        input_2d,
+        weight,
+        output,
+        input_2d.stride(0),
+        output.stride(0),
+        n_cols,
+        eps,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+    return output.reshape(original_shape)
+
+
+def rms_norm_batch_invariant(
+    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+) -> torch.Tensor:
+    """
+    Batch-invariant wrapper for RMS normalization.
+
+    This function provides a deterministic, batch-invariant implementation
+    of RMS normalization for use with the batch_invariant mode.
+
+    Adapted from @https://github.com/vllm-project/vllm/blob/66a168a197ba214a5b70a74fa2e713c9eeb3251a/vllm/model_executor/layers/batch_invariant.py#L649
+
+    Args:
+        input: Input tensor of shape (..., hidden_size)
+        weight: Weight tensor of shape (hidden_size,)
+        eps: Small constant for numerical stability
+
+    Returns:
+        RMS normalized tensor
+    """
+    return rms_norm(input, weight, eps=eps)
+
+
 _batch_invariant_MODE = False
 _batch_invariant_LIB = None
 _original_torch_bmm = None