sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (150)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +18 -3
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  5. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
  6. sglang/srt/checkpoint_engine/__init__.py +9 -0
  7. sglang/srt/checkpoint_engine/update.py +317 -0
  8. sglang/srt/configs/__init__.py +2 -0
  9. sglang/srt/configs/deepseek_ocr.py +542 -10
  10. sglang/srt/configs/deepseekvl2.py +95 -194
  11. sglang/srt/configs/kimi_linear.py +160 -0
  12. sglang/srt/configs/mamba_utils.py +66 -0
  13. sglang/srt/configs/model_config.py +25 -2
  14. sglang/srt/constants.py +7 -0
  15. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  16. sglang/srt/disaggregation/decode.py +34 -6
  17. sglang/srt/disaggregation/nixl/conn.py +2 -2
  18. sglang/srt/disaggregation/prefill.py +25 -3
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  20. sglang/srt/distributed/parallel_state.py +9 -5
  21. sglang/srt/entrypoints/engine.py +13 -5
  22. sglang/srt/entrypoints/http_server.py +22 -3
  23. sglang/srt/entrypoints/openai/protocol.py +7 -1
  24. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  27. sglang/srt/environ.py +7 -0
  28. sglang/srt/eplb/expert_distribution.py +34 -1
  29. sglang/srt/eplb/expert_location.py +106 -36
  30. sglang/srt/grpc/compile_proto.py +3 -0
  31. sglang/srt/layers/attention/ascend_backend.py +233 -5
  32. sglang/srt/layers/attention/attention_registry.py +3 -0
  33. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  34. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  35. sglang/srt/layers/attention/fla/kda.py +1359 -0
  36. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  37. sglang/srt/layers/attention/flashattention_backend.py +7 -6
  38. sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
  39. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  40. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  41. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  42. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  43. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  44. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  45. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  46. sglang/srt/layers/attention/nsa_backend.py +157 -23
  47. sglang/srt/layers/attention/triton_backend.py +4 -1
  48. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  49. sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
  50. sglang/srt/layers/communicator.py +23 -1
  51. sglang/srt/layers/layernorm.py +16 -2
  52. sglang/srt/layers/logits_processor.py +4 -20
  53. sglang/srt/layers/moe/ep_moe/layer.py +0 -18
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  57. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  59. sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
  60. sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
  61. sglang/srt/layers/moe/topk.py +31 -6
  62. sglang/srt/layers/pooler.py +21 -2
  63. sglang/srt/layers/quantization/__init__.py +9 -78
  64. sglang/srt/layers/quantization/auto_round.py +394 -0
  65. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  66. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  67. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  68. sglang/srt/layers/rotary_embedding.py +117 -45
  69. sglang/srt/lora/lora_registry.py +9 -0
  70. sglang/srt/managers/async_mm_data_processor.py +122 -0
  71. sglang/srt/managers/data_parallel_controller.py +30 -3
  72. sglang/srt/managers/detokenizer_manager.py +3 -0
  73. sglang/srt/managers/io_struct.py +26 -4
  74. sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
  75. sglang/srt/managers/schedule_batch.py +74 -15
  76. sglang/srt/managers/scheduler.py +164 -129
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  78. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  79. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  80. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  81. sglang/srt/managers/session_controller.py +6 -5
  82. sglang/srt/managers/tokenizer_manager.py +154 -59
  83. sglang/srt/managers/tp_worker.py +24 -1
  84. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  85. sglang/srt/mem_cache/common.py +1 -0
  86. sglang/srt/mem_cache/memory_pool.py +171 -57
  87. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  88. sglang/srt/mem_cache/radix_cache.py +4 -0
  89. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  90. sglang/srt/metrics/collector.py +46 -3
  91. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  92. sglang/srt/model_executor/forward_batch_info.py +11 -11
  93. sglang/srt/model_executor/model_runner.py +76 -21
  94. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  95. sglang/srt/model_loader/weight_utils.py +1 -1
  96. sglang/srt/models/bailing_moe.py +9 -2
  97. sglang/srt/models/deepseek_nextn.py +11 -2
  98. sglang/srt/models/deepseek_v2.py +149 -34
  99. sglang/srt/models/glm4.py +391 -77
  100. sglang/srt/models/glm4v.py +196 -55
  101. sglang/srt/models/glm4v_moe.py +0 -1
  102. sglang/srt/models/gpt_oss.py +1 -10
  103. sglang/srt/models/kimi_linear.py +678 -0
  104. sglang/srt/models/llama4.py +1 -1
  105. sglang/srt/models/llama_eagle3.py +11 -1
  106. sglang/srt/models/longcat_flash.py +2 -2
  107. sglang/srt/models/minimax_m2.py +1 -1
  108. sglang/srt/models/qwen2.py +1 -1
  109. sglang/srt/models/qwen2_moe.py +30 -15
  110. sglang/srt/models/qwen3.py +1 -1
  111. sglang/srt/models/qwen3_moe.py +16 -8
  112. sglang/srt/models/qwen3_next.py +7 -0
  113. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  114. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  115. sglang/srt/multiplex/pdmux_context.py +164 -0
  116. sglang/srt/parser/conversation.py +7 -1
  117. sglang/srt/sampling/custom_logit_processor.py +67 -1
  118. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  119. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  120. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  121. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  122. sglang/srt/server_args.py +103 -22
  123. sglang/srt/single_batch_overlap.py +4 -1
  124. sglang/srt/speculative/draft_utils.py +16 -0
  125. sglang/srt/speculative/eagle_info.py +42 -36
  126. sglang/srt/speculative/eagle_info_v2.py +68 -25
  127. sglang/srt/speculative/eagle_utils.py +261 -16
  128. sglang/srt/speculative/eagle_worker.py +11 -3
  129. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  130. sglang/srt/speculative/spec_info.py +305 -31
  131. sglang/srt/speculative/spec_utils.py +44 -8
  132. sglang/srt/tracing/trace.py +121 -12
  133. sglang/srt/utils/common.py +55 -32
  134. sglang/srt/utils/hf_transformers_utils.py +38 -16
  135. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  136. sglang/test/kits/radix_cache_server_kit.py +50 -0
  137. sglang/test/runners.py +31 -7
  138. sglang/test/simple_eval_common.py +5 -3
  139. sglang/test/simple_eval_humaneval.py +1 -0
  140. sglang/test/simple_eval_math.py +1 -0
  141. sglang/test/simple_eval_mmlu.py +1 -0
  142. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  143. sglang/test/test_utils.py +7 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
  146. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
  147. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  148. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  149. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py CHANGED
@@ -11,6 +11,11 @@ python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruc
  python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --output-len 32 256 --run-name test_run
  ## run with profiling:
  python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 12 14 --input-len 256 512 --profile
+ ## run with profiling to custom directory:
+ export SGLANG_TORCH_PROFILER_DIR=/root/sglang/profile_log
+ python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 --input-len 256 --profile
+ ## run with CUDA profiler (nsys):
+ nsys profile --force-overwrite=true -o bench_one_batch python -m sglang.bench_one_batch --model-path meta-llama/Meta-Llama-3-8B-Instruct --batch 1 --input-len 256 --profile --profiler_activities CUDA_PROFILER
  # Usage (correctness test):
  python -m sglang.bench_one_batch --model-path TinyLlama/TinyLlama-1.1B-Chat-v0.4 --correct

@@ -93,6 +98,68 @@ profile_activities = [torch.profiler.ProfilerActivity.CPU] + [
  ]


+ def start_profile(profiler_activities, profile_record_shapes=False, rank_print=print):
+ """
+ Abstracted function to start profiling based on profiler_activities.
+ Returns profiler object (or None).
+ """
+ if "CUDA_PROFILER" in profiler_activities:
+ try:
+ torch.cuda.cudart().cudaProfilerStart()
+ rank_print("CUDA Profiler started (nsys will begin capturing)")
+ except Exception as e:
+ rank_print(f"Failed to start CUDA profiler: {e}")
+ return None
+ else:
+ activities = []
+ if "CPU" in profiler_activities:
+ activities.append(torch.profiler.ProfilerActivity.CPU)
+ if "GPU" in profiler_activities:
+ activities.append(torch.profiler.ProfilerActivity.CUDA)
+ if activities:
+ profiler = torch.profiler.profile(
+ activities=activities,
+ with_stack=True,
+ record_shapes=profile_record_shapes,
+ )
+ profiler.start()
+ return profiler
+ return None
+
+
+ def stop_profile(
+ profiler,
+ profiler_activities,
+ rank_print=print,
+ save_trace=False,
+ trace_filename=None,
+ stage=None,
+ ):
+ """
+ Abstracted function to stop profiling based on profiler_activities.
+ Optionally saves trace results and prints completion messages.
+ """
+ if "CUDA_PROFILER" in profiler_activities:
+ try:
+ torch.cuda.cudart().cudaProfilerStop()
+ rank_print("CUDA Profiler stopped (nsys should dump traces)")
+ except Exception as e:
+ rank_print(f"Failed to stop CUDA profiler: {e}")
+ elif profiler is not None:
+ profiler.stop()
+
+ if save_trace:
+ if profiler is not None:
+ if trace_filename:
+ _save_profile_trace_results(profiler, trace_filename)
+ stage_desc = f"for {stage}" if stage else ""
+ rank_print(
+ f"torch profiler chrome trace {stage_desc} saved to {trace_filename}"
+ )
+ if "CUDA_PROFILER" in profiler_activities:
+ rank_print(f"CUDA profiler trace for {stage} completed")
+
+
  @dataclasses.dataclass
  class BenchArgs:
  run_name: str = "default"
@@ -107,6 +174,8 @@ class BenchArgs:
  log_decode_step: int = 0
  profile: bool = False
  profile_record_shapes: bool = False
+ profiler_activities: Tuple[str] = ("CPU", "GPU")
+ profile_stage: str = "all"
  profile_filename_prefix: str = "profile"

  @staticmethod
@@ -135,14 +204,27 @@ class BenchArgs:
  default=BenchArgs.log_decode_step,
  help="Log decode latency by step, default is set to zero to disable.",
  )
- parser.add_argument(
- "--profile", action="store_true", help="Use Torch Profiler."
- )
+ parser.add_argument("--profile", action="store_true", help="Enable profiling.")
  parser.add_argument(
  "--profile-record-shapes",
  action="store_true",
  help="Record tensor shapes in profiling results.",
  )
+ parser.add_argument(
+ "--profiler_activities",
+ type=str,
+ nargs="+",
+ default=["CPU", "GPU"],
+ choices=["CPU", "GPU", "CUDA_PROFILER"],
+ help="Profiler activities: CPU, GPU, CUDA_PROFILER. If CPU/GPU, use torch profiler. If CUDA_PROFILER, use CUDA profiler.",
+ )
+ parser.add_argument(
+ "--profile-stage",
+ type=str,
+ default=BenchArgs.profile_stage,
+ choices=["all", "prefill", "decode"],
+ help="Which stage to profile: all, prefill, or decode only.",
+ )
  parser.add_argument(
  "--profile-filename-prefix",
  type=str,
@@ -337,6 +419,18 @@ def _read_prompts_from_file(prompt_file, rank_print):
  return pf.readlines()


+ def _get_torch_profiler_output_dir():
+ return os.environ.get("SGLANG_TORCH_PROFILER_DIR", "/tmp")
+
+
+ def _create_torch_profiler_filename(
+ profile_filename_prefix, batch_size, input_len, output_len, stage
+ ):
+ output_dir = _get_torch_profiler_output_dir()
+ filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_{stage}.trace.json.gz"
+ return os.path.join(output_dir, filename)
+
+
  def _save_profile_trace_results(profiler, filename):
  parent_dir = os.path.dirname(os.path.abspath(filename))
  os.makedirs(parent_dir, exist_ok=True)
@@ -413,7 +507,10 @@ def latency_test_run_once(
  log_decode_step,
  profile,
  profile_record_shapes,
+ profiler_activities,
  profile_filename_prefix,
+ profile_stage,
+ tp_rank,
  ):
  max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
  if batch_size > max_batch_size:
@@ -422,7 +519,6 @@
  )
  return

- # Clear the pools.
  model_runner.req_to_token_pool.clear()
  model_runner.token_to_kv_pool_allocator.clear()

@@ -436,20 +532,33 @@
  tot_latency = 0

  profiler = None
- if profile:
- profiler = torch.profiler.profile(
- activities=profile_activities,
- with_stack=True,
- record_shapes=profile_record_shapes,
+ enable_profile_prefill = profile and profile_stage in ["all", "prefill"]
+ if enable_profile_prefill:
+ profiler = start_profile(
+ profiler_activities,
+ profile_record_shapes=profile_record_shapes,
+ rank_print=rank_print,
  )
- profiler.start()

- # Prefill
  synchronize(device)
  tic = time.perf_counter()
  next_token_ids, _, batch = extend(reqs, model_runner)
  synchronize(device)
  prefill_latency = time.perf_counter() - tic
+
+ if enable_profile_prefill:
+ trace_filename = _create_torch_profiler_filename(
+ profile_filename_prefix, batch_size, input_len, output_len, "prefill"
+ )
+ stop_profile(
+ profiler,
+ profiler_activities,
+ rank_print=rank_print,
+ save_trace=True,
+ trace_filename=trace_filename,
+ stage="prefill",
+ )
+
  tot_latency += prefill_latency
  throughput = input_len * batch_size / prefill_latency
  rank_print(
@@ -458,29 +567,37 @@
  measurement_results["prefill_latency"] = prefill_latency
  measurement_results["prefill_throughput"] = throughput

- if profile:
- profiler.stop()
- trace_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
- _save_profile_trace_results(profiler, trace_filename)
- rank_print(f"torch profiler chrome trace for prefill saved to {trace_filename}")
-
- # Decode
  decode_latencies = []
+ profile_step_of_interest = output_len // 2
+ enable_profile_decode = profile and profile_stage in ["all", "decode"]
  for i in range(output_len - 1):
  synchronize(device)
- if profile and i == output_len / 2:
- profiler = None
- profiler = torch.profiler.profile(
- activities=profile_activities,
- with_stack=True,
- record_shapes=profile_record_shapes,
+ profiler = None
+ if enable_profile_decode and i == profile_step_of_interest:
+ profiler = start_profile(
+ profiler_activities,
+ profile_record_shapes=profile_record_shapes,
+ rank_print=rank_print,
  )
- profiler.start()

  tic = time.perf_counter()
  next_token_ids, _ = decode(next_token_ids, batch, model_runner)
  synchronize(device)
  latency = time.perf_counter() - tic
+
+ if enable_profile_decode and i == profile_step_of_interest:
+ trace_filename = _create_torch_profiler_filename(
+ profile_filename_prefix, batch_size, input_len, output_len, "decode"
+ )
+ stop_profile(
+ profiler,
+ profiler_activities,
+ rank_print=rank_print,
+ save_trace=True,
+ trace_filename=trace_filename,
+ stage="decode",
+ )
+
  tot_latency += latency
  throughput = batch_size / latency
  decode_latencies.append(latency)
@@ -489,14 +606,6 @@
  f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
  )

- if profile and i == output_len / 2:
- profiler.stop()
- trace_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
- _save_profile_trace_results(profiler, trace_filename)
- rank_print(
- f"torch profiler chrome trace for decoding 1 token saved to {trace_filename}"
- )
-
  # Record decode timing from 2nd output
  if output_len > 1:
  med_decode_latency = np.median(decode_latencies)
@@ -557,7 +666,10 @@
  log_decode_step=0,
  profile=False,
  profile_record_shapes=False,
- profile_filename_prefix="", # not used
+ profiler_activities=("CPU", "GPU"),
+ profile_filename_prefix="",
+ profile_stage="all",
+ tp_rank=tp_rank,
  )

  rank_print("Benchmark ...")
@@ -604,7 +716,10 @@
  bench_args.log_decode_step,
  bench_args.profile if tp_rank == 0 else None,
  bench_args.profile_record_shapes if tp_rank == 0 else None,
+ bench_args.profiler_activities,
  bench_args.profile_filename_prefix,
+ bench_args.profile_stage,
+ tp_rank,
  )
  if ret is not None:
  result_list.append(ret)
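Note: the new start_profile/stop_profile helpers above either drive torch.profiler (for the CPU/GPU activities) or call torch.cuda.cudart().cudaProfilerStart/Stop so that an external nsys session captures the region (CUDA_PROFILER). A minimal, self-contained sketch of the torch-profiler path, with a toy workload standing in for prefill/decode and export_chrome_trace standing in for the script's _save_profile_trace_results helper:

import torch

def toy_workload(device):
    # Stand-in for the prefill or decode step being measured.
    a = torch.randn(512, 512, device=device)
    b = torch.randn(512, 512, device=device)
    return a @ b

device = "cuda" if torch.cuda.is_available() else "cpu"
activities = [torch.profiler.ProfilerActivity.CPU]
if device == "cuda":
    activities.append(torch.profiler.ProfilerActivity.CUDA)

# Roughly what start_profile(("CPU", "GPU"), ...) does.
profiler = torch.profiler.profile(activities=activities, with_stack=True)
profiler.start()
toy_workload(device)
# Roughly what stop_profile(..., save_trace=True, trace_filename=...) does.
profiler.stop()
profiler.export_chrome_trace("/tmp/profile_batch1_input256_output32_prefill.trace.json.gz")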
sglang/bench_serving.py CHANGED
@@ -1014,7 +1014,7 @@ async def get_mooncake_request_over_time(
  def sample_mmmu_requests(
  num_requests: int,
  processor: AutoProcessor | AutoTokenizer,
- backend: str,
+ backend: str = "sglang",
  fixed_output_len: Optional[int] = None,
  random_sample: bool = True,
  ) -> List[DatasetRow]:
@@ -1369,7 +1369,10 @@ def create_mm_data_row(
  )["input_ids"].numel()
  except Exception:
  # Fallback: just tokenize the text prompt directly
- text_prompt_len = len(processor.tokenizer.encode(text_prompt))
+ tokenizer_to_use = (
+ processor.tokenizer if hasattr(processor, "tokenizer") else processor
+ )
+ text_prompt_len = len(tokenizer_to_use.encode(text_prompt))

  # Vision tokens = total tokens - text tokens
  vision_prompt_len = prompt_len - text_prompt_len
@@ -2033,6 +2036,7 @@ async def benchmark(
  ):
  result = {
  # Arguments
+ "tag": getattr(args, "tag", None),
  "backend": args.backend,
  "dataset_name": args.dataset_name,
  "request_rate": "trace" if use_trace_timestamps else request_rate,
@@ -2158,6 +2162,9 @@ def run_benchmark(args_: argparse.Namespace):
  if not hasattr(args, "mooncake_num_rounds"):
  args.mooncake_num_rounds = 1

+ if not hasattr(args, "served_model_name"):
+ args.served_model_name = None
+
  print(f"benchmark_args={args}")

  # Set global environments
@@ -2271,7 +2278,7 @@ def run_benchmark(args_: argparse.Namespace):

  # Read dataset
  backend = args.backend
- model_id = args.model
+ model_id = args.served_model_name or args.model
  tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
  tokenizer = get_tokenizer(tokenizer_id)
  input_requests = get_dataset(args, tokenizer, model_id)
@@ -2370,6 +2377,11 @@ if __name__ == "__main__":
  type=str,
  help="Name or path of the model. If not set, the default model will request /v1/models for conf.",
  )
+ parser.add_argument(
+ "--served-model-name",
+ type=str,
+ help="The name of the model as served by the serving service. If not set, this defaults to the value of --model.",
+ )
  parser.add_argument(
  "--tokenizer",
  type=str,
@@ -2627,5 +2639,8 @@ if __name__ == "__main__":
  ],
  help="Underlying workload for the mooncake dataset.",
  )
+ parser.add_argument(
+ "--tag", type=str, default=None, help="The tag to be dumped to output."
+ )
  args = parser.parse_args()
  run_benchmark(args)
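Note: with the fallback above, create_mm_data_row can count text tokens whether it is handed an AutoProcessor (which wraps a tokenizer) or a bare AutoTokenizer, while the new --served-model-name flag only swaps the model_id used by the benchmark; the tokenizer is still resolved from --tokenizer/--model. A small sketch of the same duck-typing fallback, using a hypothetical count_text_tokens helper:

def count_text_tokens(processor_or_tokenizer, text_prompt: str) -> int:
    # AutoProcessor exposes its underlying tokenizer as .tokenizer;
    # a bare AutoTokenizer is used directly.
    tokenizer = (
        processor_or_tokenizer.tokenizer
        if hasattr(processor_or_tokenizer, "tokenizer")
        else processor_or_tokenizer
    )
    return len(tokenizer.encode(text_prompt))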
sglang/compile_deep_gemm.py CHANGED
@@ -104,15 +104,21 @@ def launch_server_process_and_send_one_request(
  if response.status_code == 200:
  # Rank-0 node send a request to sync with other node and then return.
  if server_args.node_rank == 0:
+ payload = {
+ "input_ids": [0, 1, 2, 3],
+ "sampling_params": {
+ "max_new_tokens": 8,
+ "temperature": 0,
+ },
+ }
+ # In PD mode, include fake bootstrap fields so workers don't assert
+ if server_args.disaggregation_mode != "null":
+ payload["bootstrap_host"] = FAKE_BOOTSTRAP_HOST
+ payload["bootstrap_room"] = 0
+
  response = requests.post(
  f"{base_url}/generate",
- json={
- "input_ids": [0, 1, 2, 3],
- "sampling_params": {
- "max_new_tokens": 8,
- "temperature": 0,
- },
- },
+ json=payload,
  timeout=600,
  )
  if response.status_code != 200:
sglang/srt/batch_invariant_ops/__init__.py CHANGED
@@ -9,6 +9,7 @@ from .batch_invariant_ops import (
  log_softmax,
  matmul_persistent,
  mean_dim,
+ rms_norm_batch_invariant,
  set_batch_invariant_mode,
  )

@@ -24,4 +25,5 @@ __all__ = [
  "mean_dim",
  "get_batch_invariant_attention_block_size",
  "AttentionBlockSize",
+ "rms_norm_batch_invariant",
  ]
sglang/srt/batch_invariant_ops/batch_invariant_ops.py CHANGED
@@ -579,6 +579,126 @@ def bmm_batch_invariant(a, b, *, out=None):
  )


+ @triton.jit
+ def _rms_norm_kernel(
+ input_ptr,
+ weight_ptr,
+ output_ptr,
+ input_row_stride,
+ output_row_stride,
+ n_cols,
+ eps,
+ BLOCK_SIZE: tl.constexpr,
+ ):
+ """
+ Compute RMS normalization along the last dimension of a 2D tensor.
+ RMS Norm: y = x / sqrt(mean(x^2) + eps) * weight
+ Each block handles one row of the input tensor.
+ """
+ row_idx = tl.program_id(0).to(tl.int64)
+ row_start_ptr = input_ptr + row_idx * input_row_stride
+ output_row_start_ptr = output_ptr + row_idx * output_row_stride
+
+ # Step 1: Compute sum of squares in float32 to avoid overflow
+ sum_sq = tl.zeros([1], dtype=tl.float32)
+ for col_offset in range(0, n_cols, BLOCK_SIZE):
+ col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
+ mask = col_idx < n_cols
+
+ vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
+ # Convert to float32 for accumulation to prevent overflow
+ vals_f32 = vals.to(tl.float32)
+ sq_vals = vals_f32 * vals_f32
+ sum_sq += tl.sum(tl.where(mask, sq_vals, 0.0))
+
+ # Step 2: Compute RMS (root mean square) in float32
+ mean_sq = sum_sq / n_cols
+ rms = tl.sqrt(mean_sq + eps)
+ inv_rms = 1.0 / rms
+
+ # Step 3: Normalize and apply weight
+ for col_offset in range(0, n_cols, BLOCK_SIZE):
+ col_idx = col_offset + tl.arange(0, BLOCK_SIZE)
+ mask = col_idx < n_cols
+ vals = tl.load(row_start_ptr + col_idx, mask=mask, other=0.0)
+ weight = tl.load(weight_ptr + col_idx, mask=mask, other=1.0)
+ # Compute in float32 then convert back to input dtype
+ vals_f32 = vals.to(tl.float32)
+ weight_f32 = weight.to(tl.float32)
+ output_f32 = vals_f32 * inv_rms * weight_f32
+ output = output_f32.to(vals.dtype)
+ tl.store(output_row_start_ptr + col_idx, output, mask=mask)
+
+
+ def rms_norm(
+ input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+ ) -> torch.Tensor:
+ """
+ Compute RMS normalization using Triton kernel.
+
+ RMS Norm normalizes the input by the root mean square and scales by weight:
+ output = input / sqrt(mean(input^2) + eps) * weight
+
+ Args:
+ input: Input tensor of shape (..., hidden_size)
+ weight: Weight tensor of shape (hidden_size,)
+ eps: Small constant for numerical stability
+
+ Returns:
+ Tensor with RMS normalization applied along the last dimension
+ """
+ assert weight.dim() == 1, "Weight must be 1-dimensional"
+ assert input.shape[-1] == weight.shape[0], (
+ f"Input last dimension ({input.shape[-1]}) must match "
+ f"weight dimension ({weight.shape[0]})"
+ )
+
+ # Flatten all dimensions except the last one
+ original_shape = input.shape
+ input_2d = input.reshape(-1, input.shape[-1])
+ input_2d = input_2d.contiguous()
+ weight = weight.contiguous()
+
+ n_rows, n_cols = input_2d.shape
+
+ output = torch.empty_like(input_2d)
+ BLOCK_SIZE = 1024
+ grid = (n_rows,)
+ _rms_norm_kernel[grid](
+ input_2d,
+ weight,
+ output,
+ input_2d.stride(0),
+ output.stride(0),
+ n_cols,
+ eps,
+ BLOCK_SIZE=BLOCK_SIZE,
+ )
+ return output.reshape(original_shape)
+
+
+ def rms_norm_batch_invariant(
+ input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+ ) -> torch.Tensor:
+ """
+ Batch-invariant wrapper for RMS normalization.
+
+ This function provides a deterministic, batch-invariant implementation
+ of RMS normalization for use with the batch_invariant mode.
+
+ Adapted from @https://github.com/vllm-project/vllm/blob/66a168a197ba214a5b70a74fa2e713c9eeb3251a/vllm/model_executor/layers/batch_invariant.py#L649
+
+ Args:
+ input: Input tensor of shape (..., hidden_size)
+ weight: Weight tensor of shape (hidden_size,)
+ eps: Small constant for numerical stability
+
+ Returns:
+ RMS normalized tensor
+ """
+ return rms_norm(input, weight, eps=eps)
+
+
  _batch_invariant_MODE = False
  _batch_invariant_LIB = None
  _original_torch_bmm = None
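Note: the kernel above computes y = x / sqrt(mean(x^2) + eps) * weight, one row per program with float32 accumulation, and rms_norm_batch_invariant is now re-exported from sglang.srt.batch_invariant_ops. A rough sanity check against a plain PyTorch reference (a sketch; assumes a CUDA device with Triton available):

import torch
from sglang.srt.batch_invariant_ops import rms_norm_batch_invariant

hidden = 4096
x = torch.randn(8, hidden, device="cuda", dtype=torch.bfloat16)
w = torch.randn(hidden, device="cuda", dtype=torch.bfloat16)

out = rms_norm_batch_invariant(x, w, eps=1e-6)

# Reference computed in float32, mirroring the kernel's math.
x32 = x.float()
ref = x32 * torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True) + 1e-6) * w.float()
torch.testing.assert_close(out.float(), ref, rtol=2e-2, atol=2e-2)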
sglang/srt/checkpoint_engine/__init__.py CHANGED
@@ -0,0 +1,9 @@
+ """
+ Checkpoint engine module for SGLang.
+
+ This module provides functionality for updating model weights via checkpoint engine.
+ """
+
+ from sglang.srt.checkpoint_engine.update import main
+
+ __all__ = ["main"]