sglang 0.5.4__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. sglang/bench_serving.py +56 -12
  2. sglang/launch_server.py +2 -0
  3. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +101 -4
  4. sglang/srt/compilation/backend.py +1 -1
  5. sglang/srt/configs/model_config.py +5 -5
  6. sglang/srt/distributed/parallel_state.py +0 -7
  7. sglang/srt/entrypoints/engine.py +18 -15
  8. sglang/srt/entrypoints/grpc_server.py +0 -1
  9. sglang/srt/entrypoints/http_server.py +75 -94
  10. sglang/srt/environ.py +16 -2
  11. sglang/srt/eplb/expert_distribution.py +30 -0
  12. sglang/srt/function_call/function_call_parser.py +2 -0
  13. sglang/srt/function_call/minimax_m2.py +367 -0
  14. sglang/srt/layers/activation.py +6 -0
  15. sglang/srt/layers/attention/flashattention_backend.py +12 -2
  16. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  17. sglang/srt/layers/attention/flashinfer_mla_backend.py +18 -10
  18. sglang/srt/layers/attention/trtllm_mla_backend.py +1 -13
  19. sglang/srt/layers/attention/utils.py +78 -0
  20. sglang/srt/layers/communicator.py +1 -0
  21. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  22. sglang/srt/layers/layernorm.py +19 -4
  23. sglang/srt/layers/logits_processor.py +5 -0
  24. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  25. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  26. sglang/srt/layers/moe/ep_moe/layer.py +79 -272
  27. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  28. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  29. sglang/srt/layers/moe/moe_runner/deep_gemm.py +287 -22
  30. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  31. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  32. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  33. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  34. sglang/srt/layers/moe/token_dispatcher/deepep.py +18 -14
  35. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  36. sglang/srt/layers/moe/topk.py +4 -4
  37. sglang/srt/layers/moe/utils.py +3 -4
  38. sglang/srt/layers/quantization/__init__.py +3 -5
  39. sglang/srt/layers/quantization/awq.py +0 -3
  40. sglang/srt/layers/quantization/base_config.py +7 -0
  41. sglang/srt/layers/quantization/fp8.py +68 -63
  42. sglang/srt/layers/quantization/gguf.py +566 -0
  43. sglang/srt/layers/quantization/mxfp4.py +30 -38
  44. sglang/srt/layers/quantization/unquant.py +23 -45
  45. sglang/srt/layers/quantization/w4afp8.py +38 -2
  46. sglang/srt/layers/radix_attention.py +5 -2
  47. sglang/srt/layers/rotary_embedding.py +13 -1
  48. sglang/srt/layers/sampler.py +12 -1
  49. sglang/srt/managers/io_struct.py +3 -0
  50. sglang/srt/managers/multi_tokenizer_mixin.py +17 -1
  51. sglang/srt/managers/scheduler.py +21 -15
  52. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  53. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  54. sglang/srt/managers/tokenizer_manager.py +11 -19
  55. sglang/srt/mem_cache/hicache_storage.py +7 -1
  56. sglang/srt/mem_cache/memory_pool.py +82 -0
  57. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  58. sglang/srt/model_executor/forward_batch_info.py +44 -3
  59. sglang/srt/model_executor/model_runner.py +1 -149
  60. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  61. sglang/srt/models/deepseek_v2.py +147 -44
  62. sglang/srt/models/glm4_moe.py +322 -354
  63. sglang/srt/models/glm4_moe_nextn.py +4 -14
  64. sglang/srt/models/glm4v_moe.py +29 -196
  65. sglang/srt/models/minimax_m2.py +922 -0
  66. sglang/srt/models/nvila.py +355 -0
  67. sglang/srt/models/nvila_lite.py +184 -0
  68. sglang/srt/models/qwen2.py +22 -1
  69. sglang/srt/models/qwen3.py +34 -4
  70. sglang/srt/models/qwen3_moe.py +2 -4
  71. sglang/srt/multimodal/processors/base_processor.py +1 -0
  72. sglang/srt/multimodal/processors/glm4v.py +1 -1
  73. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  74. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  75. sglang/srt/parser/reasoning_parser.py +28 -1
  76. sglang/srt/server_args.py +365 -186
  77. sglang/srt/single_batch_overlap.py +2 -7
  78. sglang/srt/utils/common.py +87 -42
  79. sglang/srt/utils/hf_transformers_utils.py +7 -3
  80. sglang/test/test_deterministic.py +235 -12
  81. sglang/test/test_deterministic_utils.py +2 -1
  82. sglang/version.py +1 -1
  83. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +7 -6
  84. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +87 -82
  85. sglang/srt/models/vila.py +0 -306
  86. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  87. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  88. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/unquant.py
@@ -115,13 +115,15 @@ class UnquantizedLinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
         if use_intel_amx_backend(layer):
             x_shapes = x.shape
             if len(x_shapes) == 3:
                 x = x.view(-1, x.shape[-1])
             output = torch.ops.sgl_kernel.weight_packed_linear(
-                x, layer.weight, bias, True  # is_vnni
+                x,
+                layer.weight,
+                bias,
+                True,  # is_vnni
             )
             if len(x_shapes) == 3:
                 output = output.view(x_shapes[0], x_shapes[1], -1)
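Note: the hunk above only reflows the `weight_packed_linear` call, but the surrounding pattern is the interesting part: the Intel AMX kernel works on 2D input, so 3D activations are flattened before the call and the leading dimensions are restored afterwards. A minimal sketch of that pattern, with `linear_2d_only` and `apply_linear` as illustrative stand-ins rather than the sgl_kernel API:

```python
import torch

def linear_2d_only(x2d: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Stand-in for a kernel that only accepts 2D input.
    return x2d @ weight.t()

def apply_linear(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Flatten (batch, seq, hidden) to (batch * seq, hidden) for the kernel,
    # then restore the original leading dimensions afterwards.
    x_shapes = x.shape
    if len(x_shapes) == 3:
        x = x.view(-1, x.shape[-1])
    out = linear_2d_only(x, weight)
    if len(x_shapes) == 3:
        out = out.view(x_shapes[0], x_shapes[1], -1)
    return out

x = torch.randn(2, 4, 8)
w = torch.randn(16, 8)
print(apply_linear(x, w).shape)  # torch.Size([2, 4, 16])
```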
@@ -138,19 +140,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         self.use_triton_kernels = use_triton_kernels
         self.with_bias = False
 
-        self.triton_kernel_moe_forward = None
-        self.triton_kernel_moe_with_bias_forward = None
-        if torch.cuda.is_available() and use_triton_kernels:
-            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
-                triton_kernel_moe_forward as _tk_forward,
-            )
-            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
-                triton_kernel_moe_with_bias_forward as _tk_with_bias_forward,
-            )
-
-            self.triton_kernel_moe_forward = _tk_forward
-            self.triton_kernel_moe_with_bias_forward = _tk_with_bias_forward
-
     def create_weights(
         self,
         layer: torch.nn.Module,
@@ -231,14 +220,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
     ):
         self.moe_runner_config = moe_runner_config
-        self.runner = MoeRunner(MoeRunnerBackend.TRITON, moe_runner_config)
+        backend = (
+            MoeRunnerBackend.TRITON_KERNELS
+            if self.use_triton_kernels
+            else MoeRunnerBackend.TRITON
+        )
+        self.runner = MoeRunner(backend, moe_runner_config)
 
     def apply(
         self,
         layer: torch.nn.Module,
         dispatch_output: StandardDispatchOutput,
     ) -> CombineInput:
-
         return self.forward(
             layer=layer,
             dispatch_output=dispatch_output,
@@ -249,7 +242,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         dispatch_output: StandardDispatchOutput,
     ) -> CombineInput:
-
         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
 
         x = dispatch_output.hidden_states
@@ -257,30 +249,19 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
 
         moe_runner_config = self.moe_runner_config
 
-        if self.use_triton_kernels:
-            if self.with_bias:
-                assert self.triton_kernel_moe_with_bias_forward is not None
-                output = self.triton_kernel_moe_with_bias_forward(
-                    hidden_states=x,
-                    w1=layer.w13_weight,
-                    w2=layer.w2_weight,
-                    b1=layer.w13_weight_bias,
-                    b2=layer.w2_weight_bias,
-                    topk_output=topk_output,
-                    moe_runner_config=moe_runner_config,
-                    w1_pcg=None,
-                    w2_pcg=None,
-                )
-            else:
-                assert self.triton_kernel_moe_forward is not None
-                output = self.triton_kernel_moe_forward(
-                    hidden_states=x,
-                    w1=layer.w13_weight,
-                    w2=layer.w2_weight,
-                    topk_output=topk_output,
-                    moe_runner_config=moe_runner_config,
-                )
-            return StandardCombineInput(hidden_states=output)
+        backend = self.runner.runner_backend
+        if backend.is_triton_kernels():
+            from sglang.srt.layers.moe.moe_runner.triton_kernels import (
+                TritonKernelsQuantInfo,
+            )
+
+            quant_info = TritonKernelsQuantInfo(
+                w13_weight=layer.w13_weight,
+                w2_weight=layer.w2_weight,
+                w13_bias=getattr(layer, "w13_weight_bias", None),
+                w2_bias=getattr(layer, "w2_weight_bias", None),
+            )
+            return self.runner.run(dispatch_output, quant_info)
         else:
             if _use_aiter:
                 assert not moe_runner_config.no_combine, "unsupported"
@@ -311,7 +292,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
                 )
                 return StandardCombineInput(hidden_states=output)
             else:
-
                 quant_info = TritonMoeQuantInfo(
                     w13_weight=layer.w13_weight,
                     w2_weight=layer.w2_weight,
@@ -325,7 +305,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         dispatch_output: StandardDispatchOutput,
     ) -> CombineInput:
-
        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
 
         x = dispatch_output.hidden_states
@@ -380,7 +359,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer: torch.nn.Module,
         dispatch_output: StandardDispatchOutput,
     ) -> CombineInput:
-
         import torch_npu
 
         from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
sglang/srt/layers/quantization/w4afp8.py
@@ -23,7 +23,8 @@ if TYPE_CHECKING:
     from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE
     from sglang.srt.layers.moe.token_dispatcher import (
         CombineInput,
-        DeepEPNormalOutput,
+        DeepEPLLDispatchOutput,
+        DeepEPNormalDispatchOutput,
         StandardDispatchOutput,
     )
 
@@ -328,10 +329,45 @@ class W4AFp8MoEMethod(FusedMoEMethodBase):
         output *= self.moe_runner_config.routed_scaling_factor
         return StandardCombineInput(hidden_states=output)
 
+    def apply_deepep_ll(
+        self,
+        layer: DeepEPMoE,
+        dispatch_output: DeepEPLLDispatchOutput,
+    ) -> torch.Tensor:
+
+        from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe_deepep_ll
+
+        hidden_states, _, topk_ids, _, masked_m, _ = dispatch_output
+
+        output = cutlass_w4a8_moe_deepep_ll(
+            hidden_states,
+            layer.w13_weight,
+            layer.w2_weight,
+            layer.w13_weight_scale_inv,
+            layer.w2_weight_scale_inv,
+            topk_ids,
+            masked_m,
+            layer.quant_method.a_strides1,
+            layer.quant_method.b_strides1,
+            layer.quant_method.c_strides1,
+            layer.quant_method.a_strides2,
+            layer.quant_method.b_strides2,
+            layer.quant_method.c_strides2,
+            layer.quant_method.s_strides13,
+            layer.quant_method.s_strides2,
+            layer.quant_method.expert_offsets,
+            layer.quant_method.problem_sizes1,
+            layer.quant_method.problem_sizes2,
+            layer.w13_input_scale,
+            layer.w2_input_scale,
+        )
+
+        return output
+
     def apply_deepep_normal(
         self,
         layer: DeepEPMoE,
-        dispatch_output: DeepEPNormalOutput,
+        dispatch_output: DeepEPNormalDispatchOutput,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.cutlass_w4a8_moe import (
             cutlass_w4a8_moe_deepep_normal,
sglang/srt/layers/radix_attention.py
@@ -142,8 +142,11 @@ def unified_attention_with_output(
     ret = forward_batch.attn_backend.forward(
        query, key, value, attention_layer, forward_batch, save_kv_cache
     )
-    assert output.shape == ret.shape
-    output.copy_(ret)
+    assert (
+        output.numel() == ret.numel()
+    ), f"Output tensor element mismatch: {output.numel()} != {ret.numel()}"
+
+    output.view(ret.shape).copy_(ret)
     return
 
 
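Note: this hunk relaxes the check from exact shape equality to element-count equality, so a backend may return a result with a different layout (e.g. an extra head dimension) and still write into the preallocated buffer through a view. A small self-contained sketch of that pattern; `copy_into` and the shapes below are illustrative, not the sglang API:

```python
import torch

def copy_into(output: torch.Tensor, ret: torch.Tensor) -> None:
    # As long as the element counts match, a view of the output buffer can
    # receive the backend result in place, regardless of its exact shape.
    assert output.numel() == ret.numel(), (
        f"Output tensor element mismatch: {output.numel()} != {ret.numel()}"
    )
    output.view(ret.shape).copy_(ret)

out = torch.zeros(4, 6)      # flattened buffer, shape (tokens, heads * head_dim)
res = torch.randn(4, 2, 3)   # backend result with a separate head dimension
copy_into(out, res)
assert torch.equal(out.view(res.shape), res)
```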
sglang/srt/layers/rotary_embedding.py
@@ -11,6 +11,7 @@ import triton
 import triton.language as tl
 
 from sglang.srt.custom_op import CustomOp
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -124,18 +125,29 @@ class RotaryEmbedding(CustomOp):
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
 
+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            self._forward_method = self.forward_native
+
     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
         """Compute the inverse frequency."""
         # NOTE(woosuk): To exactly match the HF implementation, we need to
         # use CPU to compute the cache and then move it to GPU. However, we
         # create the cache on GPU for faster initialization. This may cause
         # a slight numerical difference between the HF implementation and ours.
+        init_device = (
+            "cpu" if get_global_server_args().rl_on_policy_target == "fsdp" else None
+        )
         inv_freq = 1.0 / (
             base
             ** (
-                torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
+                torch.arange(
+                    0, self.rotary_dim, 2, dtype=torch.float, device=init_device
+                )
+                / self.rotary_dim
             )
         )
+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            inv_freq = inv_freq.cuda()
         return inv_freq
 
     def _compute_cos_sin_cache(self) -> torch.Tensor:
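Note: when `rl_on_policy_target == "fsdp"`, the hunk above routes rotary embedding through `forward_native` and builds the inverse frequencies on CPU before moving them to the GPU, so that the inference-side cache matches a CPU/HF-style reference bit for bit. A minimal sketch of the device trick; `compute_inv_freq` and `match_cpu` are illustrative names:

```python
import torch

def compute_inv_freq(rotary_dim: int, base: float, match_cpu: bool) -> torch.Tensor:
    # Building the arange on CPU and moving the result afterwards keeps the
    # frequencies identical to a CPU reference, which matters when an RL
    # trainer must reproduce the serving engine's rotary cache exactly.
    device = "cpu" if match_cpu else None
    arange = torch.arange(0, rotary_dim, 2, dtype=torch.float, device=device)
    inv_freq = 1.0 / (base ** (arange / rotary_dim))
    if match_cpu and torch.cuda.is_available():
        inv_freq = inv_freq.cuda()
    return inv_freq

print(compute_inv_freq(8, 10000.0, match_cpu=True))
```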
sglang/srt/layers/sampler.py
@@ -102,6 +102,14 @@ class Sampler(nn.Module):
         if return_logprob and SGLANG_RETURN_ORIGINAL_LOGPROB:
             probs_without_temp_scaling = torch.softmax(logits, dim=-1)
 
+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            logits_div_temperature = (
+                logits.bfloat16().div(sampling_info.temperatures).bfloat16()
+            )
+            logprobs_via_logsoftmax_kernel = torch.log_softmax(
+                logits_div_temperature, dim=-1
+            )
+
         # Post process logits
         logits.div_(sampling_info.temperatures)
         logits[:] = torch.softmax(logits, dim=-1)
@@ -148,8 +156,11 @@ class Sampler(nn.Module):
             )
 
         if return_logprob:
+            if get_global_server_args().rl_on_policy_target == "fsdp":
+                logprobs = logprobs_via_logsoftmax_kernel
+                del logprobs_via_logsoftmax_kernel
             # clamp to avoid -inf
-            if SGLANG_RETURN_ORIGINAL_LOGPROB:
+            elif SGLANG_RETURN_ORIGINAL_LOGPROB:
                 logprobs = torch.log(probs_without_temp_scaling).clamp(
                     min=torch.finfo(probs_without_temp_scaling.dtype).min
                 )
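Note: the `fsdp` branch above computes logprobs with `torch.log_softmax` on the temperature-scaled logits instead of taking `log(softmax(...))` later. A small illustration of why that is numerically different (values are made up):

```python
import torch

# Logits with a large dynamic range, as can happen after temperature scaling.
logits = torch.tensor([[1000.0, 0.0, -1000.0]])

# log(softmax(x)) underflows to 0 for low-probability tokens, and log(0) = -inf,
# which is why the original path needs an explicit clamp.
naive = torch.log(torch.softmax(logits, dim=-1))

# log_softmax computes x - logsumexp(x) directly and stays finite.
fused = torch.log_softmax(logits, dim=-1)

print(naive)  # tensor([[0., -inf, -inf]])
print(fused)  # tensor([[0., -1000., -2000.]])
```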
sglang/srt/managers/io_struct.py
@@ -574,6 +574,7 @@ class GenerateReqInput(BaseReq):
             custom_labels=self.custom_labels,
             return_bytes=self.return_bytes,
             return_entropy=self.return_entropy,
+            http_worker_ipc=self.http_worker_ipc,
         )
 
 
@@ -759,6 +760,7 @@ class EmbeddingReqInput(BaseReq):
                 sampling_params=self.sampling_params[i],
                 rid=self.rid[i],
                 is_cross_encoder_request=True,
+                http_worker_ipc=self.http_worker_ipc,
             )
 
         return EmbeddingReqInput(
@@ -769,6 +771,7 @@ class EmbeddingReqInput(BaseReq):
             video_data=self.video_data[i] if self.video_data is not None else None,
             sampling_params=self.sampling_params[i],
             rid=self.rid[i],
+            http_worker_ipc=self.http_worker_ipc,
         )
 
 
sglang/srt/managers/multi_tokenizer_mixin.py
@@ -13,7 +13,12 @@ from __future__ import annotations
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Mixin class and utils for multi-http-worker mode"""
+
+"""
+Mixin classes and utils for multi-http-worker mode
+This file uses multiple processes to handle requests and tokenization, reducing the overhead of python and http server.
+"""
+
 import asyncio
 import logging
 import multiprocessing as multiprocessing
@@ -566,3 +571,14 @@ def monkey_patch_uvicorn_multiprocessing(timeout: float = 10):
         logger.warning(
             "uvicorn.supervisors.multiprocess not found, skipping monkey patch"
         )
+
+
+class SenderWrapper:
+    def __init__(self, port_args: PortArgs, send_to_scheduler: zmq.Socket):
+        self.port_args = port_args
+        self.send_to_scheduler = send_to_scheduler
+
+    def send_pyobj(self, obj):
+        if isinstance(obj, BaseReq):
+            obj.http_worker_ipc = self.port_args.tokenizer_ipc_name
+        self.send_to_scheduler.send_pyobj(obj)
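Note on the design: `SenderWrapper` is duck-typed to expose the same `send_pyobj` method as a zmq socket, so the tokenizer manager can assign either object to `self.send_to_scheduler` and call it without branching. A self-contained sketch of the pattern, where `FakeSocket` and `Req` are illustrative stand-ins for `zmq.Socket` and `BaseReq`:

```python
from dataclasses import dataclass, field
from typing import Any, List, Optional

@dataclass
class FakeSocket:
    # Stand-in for zmq.Socket; records what would be sent to the scheduler.
    sent: List[Any] = field(default_factory=list)

    def send_pyobj(self, obj: Any) -> None:
        self.sent.append(obj)

@dataclass
class Req:
    http_worker_ipc: Optional[str] = None

class SenderWrapper:
    """Duck-typed like a zmq socket: callers just call send_pyobj()."""

    def __init__(self, ipc_name: str, socket: FakeSocket):
        self.ipc_name = ipc_name
        self.socket = socket

    def send_pyobj(self, obj: Any) -> None:
        if isinstance(obj, Req):
            # Tag the request so the scheduler can route the response back
            # to the HTTP worker that produced it.
            obj.http_worker_ipc = self.ipc_name
        self.socket.send_pyobj(obj)

sock = FakeSocket()
SenderWrapper("ipc:///tmp/tokenizer_0", sock).send_pyobj(Req())
print(sock.sent[0].http_worker_ipc)  # ipc:///tmp/tokenizer_0
```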
sglang/srt/managers/scheduler.py
@@ -494,7 +494,7 @@ class Scheduler(
         )
         self.init_disaggregation()
 
-        if get_bool_env_var("SGLANG_GC_LOG"):
+        if envs.SGLANG_LOG_GC.get():
             configure_gc_logger()
 
         # Init prefill kv split size when deterministic inference is enabled with various attention backends
@@ -2073,15 +2073,18 @@ class Scheduler(
             num_tokens_for_logprob = num_tokens
         else:
             num_tokens = local_batch.extend_num_tokens
-            num_tokens_for_logprob = sum(
-                [
+            if local_batch.return_logprob:
+                num_tokens_for_logprob = sum(
                     # We should have at least 1 token for sample in every case.
                     max(extend_len - logprob_start_len, 1)
                     for logprob_start_len, extend_len in zip(
-                        local_batch.extend_logprob_start_lens, local_batch.extend_lens
+                        local_batch.extend_logprob_start_lens,
+                        local_batch.extend_lens,
                     )
-                ]
-            )
+                )
+            else:
+                # When return_logprob = False, only need last token per request
+                num_tokens_for_logprob = local_batch.batch_size()
 
         if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
             can_cuda_graph = 1
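Note: the change above only sums per-request logprob spans when logprobs were actually requested; otherwise one logit per request suffices. A worked example with made-up numbers for a three-request prefill batch:

```python
# Example numbers (made up): a prefill batch of three requests.
extend_lens = [128, 64, 1]
extend_logprob_start_lens = [100, 64, 0]

# return_logprob=True: logits are needed from each request's logprob start
# position onward, but always for at least the one token that will be sampled.
with_logprob = sum(
    max(extend_len - start, 1)
    for start, extend_len in zip(extend_logprob_start_lens, extend_lens)
)

# return_logprob=False: only the last position of each request is needed.
without_logprob = len(extend_lens)

print(with_logprob, without_logprob)  # 30 3
```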
@@ -2322,10 +2325,10 @@ class Scheduler(
 
             self.num_generated_tokens = 0
             self.forward_ct_decode = 0
-            self.spec_num_total_accepted_tokens = 0
-            self.spec_num_total_forward_ct = 0
-            self.cum_spec_accept_length = 0
-            self.cum_spec_accept_count = 0
+            self.spec_num_accepted_tokens = 0
+            self.spec_num_forward_ct = 0
+            self.spec_total_num_accepted_tokens = 0
+            self.spec_total_num_forward_ct = 0
             torch.cuda.empty_cache()
             logger.info("Cache flushed successfully!")
             if_success = True
@@ -2398,13 +2401,16 @@ class Scheduler(
             self.tp_worker.model_runner.graph_mem_usage, 2
         )
 
-        if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
+        if not self.spec_algorithm.is_none() and self.spec_total_num_forward_ct > 0:
             ret["avg_spec_accept_length"] = (
-                self.cum_spec_accept_length / self.cum_spec_accept_count
+                self.spec_total_num_accepted_tokens / self.spec_total_num_forward_ct
             )
         if RECORD_STEP_TIME:
             ret["step_time_dict"] = self.step_time_dict
 
+        # This field is not serializable.
+        ret.pop("model_config", None)
+
         return GetInternalStateReqOutput(internal_state=ret)
 
     def set_internal_state(self, recv_req: SetInternalStateReq):
@@ -2431,12 +2437,12 @@ class Scheduler(
                 if_success = False
                 break
         if if_success:
-            if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
+            if not self.spec_algorithm.is_none() and self.spec_total_num_forward_ct > 0:
                 avg_spec_accept_length = (
-                    self.cum_spec_accept_length / self.cum_spec_accept_count
+                    self.spec_total_num_accepted_tokens / self.spec_total_num_forward_ct
                 )
                 logger.info(f"{avg_spec_accept_length=}")
-                self.cum_spec_accept_length = self.cum_spec_accept_count = 0
+                self.spec_total_num_accepted_tokens = self.spec_total_num_forward_ct = 0
             for k, v in server_args_dict.items():
                 setattr(get_global_server_args(), k, v)
             logger.info(f"Global server args updated! {get_global_server_args()=}")
sglang/srt/managers/scheduler_metrics_mixin.py
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, List, Optional
 
 from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
 from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.environ import envs
 from sglang.srt.managers.schedule_policy import PrefillAdder
 from sglang.srt.managers.scheduler import Req, ScheduleBatch
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
@@ -18,6 +19,7 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
+LOG_FORWARD_ITERS = envs.SGLANG_LOG_FORWARD_ITERS.get()
 
 
 class KvMetrics:
@@ -39,10 +41,13 @@ class SchedulerMetricsMixin:
         self.last_gen_throughput: float = 0.0
         self.last_input_throughput: float = 0.0
         self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
-        self.spec_num_total_accepted_tokens = 0
-        self.spec_num_total_forward_ct = 0
-        self.cum_spec_accept_length = 0
-        self.cum_spec_accept_count = 0
+
+        # The number of accepted tokens and forward ct for the recent `decode_log_interval` batches (for logging)
+        self.spec_num_accepted_tokens = 0
+        self.spec_num_forward_ct = 0
+        # The total number of accepted tokens and forward ct for the whole server lifetime
+        self.spec_total_num_accepted_tokens = 0
+        self.spec_total_num_forward_ct = 0
         self.kv_transfer_speed_gb_s: float = 0.0
         self.kv_transfer_latency_ms: float = 0.0
 
@@ -67,8 +72,8 @@ class SchedulerMetricsMixin:
         )
 
     def update_spec_metrics(self: Scheduler, bs: int, num_accepted_tokens: int):
-        self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
-        self.spec_num_total_forward_ct += bs
+        self.spec_num_accepted_tokens += num_accepted_tokens + bs
+        self.spec_num_forward_ct += bs
         self.num_generated_tokens += num_accepted_tokens
 
     def log_prefill_stats(
@@ -122,8 +127,10 @@ class SchedulerMetricsMixin:
         num_used, token_usage, _, _ = self._get_token_info()
         token_usage_msg = f"token usage: {token_usage:.2f}, "
 
+        iter_msg = f" [{self.forward_ct + 1}]" if LOG_FORWARD_ITERS else ""
+
         f = (
-            f"Prefill batch [{self.forward_ct + 1}], "
+            f"Prefill batch{iter_msg}, "
             f"#new-seq: {len(can_run_list)}, "
             f"#new-token: {adder.log_input_tokens}, "
             f"#cached-token: {adder.log_hit_tokens}, "
@@ -246,27 +253,28 @@ class SchedulerMetricsMixin:
             gap_latency / self.server_args.decode_log_interval
         )
 
-        msg = f"Decode batch [{self.forward_ct}], #running-req: {num_running_reqs}, {token_usage_msg}"
+        iter_msg = f" [{self.forward_ct}]" if LOG_FORWARD_ITERS else ""
+        msg = f"Decode batch{iter_msg}, #running-req: {num_running_reqs}, {token_usage_msg}"
 
         if self.spec_algorithm.is_none():
             spec_accept_length = 0
             spec_accept_rate = 0
         else:
             spec_accept_length = (
-                self.spec_num_total_accepted_tokens / self.spec_num_total_forward_ct
+                self.spec_num_accepted_tokens / self.spec_num_forward_ct
             )
             # Calculate acceptance rate: accepted tokens / total draft tokens
-            total_draft_tokens = self.spec_num_total_forward_ct * (
+            total_draft_tokens = self.spec_num_forward_ct * (
                 (self.server_args.speculative_num_steps or 0) + 1
            )
            spec_accept_rate = (
-                self.spec_num_total_accepted_tokens / total_draft_tokens
+                self.spec_num_accepted_tokens / total_draft_tokens
                 if total_draft_tokens > 0
                 else 0
            )
-            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
-            self.cum_spec_accept_count += self.spec_num_total_forward_ct
-            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
+            self.spec_total_num_accepted_tokens += self.spec_num_accepted_tokens
+            self.spec_total_num_forward_ct += self.spec_num_forward_ct
+            self.spec_num_accepted_tokens = self.spec_num_forward_ct = 0
            msg += f"accept len: {spec_accept_length:.2f}, accept rate: {spec_accept_rate:.2f}, "
        cache_hit_rate = 0.0
 
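Note: the renaming separates per-window counters (reset every `decode_log_interval` decode batches) from lifetime totals used by `avg_spec_accept_length`. A worked example with made-up numbers, assuming `speculative_num_steps = 3`:

```python
# One decode-log window with speculative decoding.
spec_num_accepted_tokens = 250   # accepted draft tokens plus one sampled token per forward
spec_num_forward_ct = 100        # target forwards in this window

accept_length = spec_num_accepted_tokens / spec_num_forward_ct   # 2.5 tokens per forward
total_draft_tokens = spec_num_forward_ct * (3 + 1)
accept_rate = spec_num_accepted_tokens / total_draft_tokens      # 0.625

# Windowed counters roll into lifetime totals and then reset, so the value
# reported via get_internal_state covers the whole server run.
spec_total_num_accepted_tokens = 0 + spec_num_accepted_tokens
spec_total_num_forward_ct = 0 + spec_num_forward_ct
spec_num_accepted_tokens = spec_num_forward_ct = 0

print(accept_length, accept_rate)  # 2.5 0.625
```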
sglang/srt/managers/scheduler_profiler_mixin.py
@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
 class SchedulerProfilerMixin:
     def init_profiler(self):
         self.torch_profiler = None
-        self.torch_profiler_output_dir: Optional[str] = None
+        self.torch_profiler_output_dir: Optional[Path] = None
         self.profiler_activities: Optional[List[str]] = None
         self.profile_id: Optional[str] = None
         self.profiler_start_forward_ct: Optional[int] = None
@@ -69,7 +69,7 @@ class SchedulerProfilerMixin:
         if activities is None:
             activities = ["CPU", "GPU"]
 
-        self.torch_profiler_output_dir = output_dir
+        self.torch_profiler_output_dir = Path(output_dir).expanduser()
         self.torch_profiler_with_stack = with_stack
         self.torch_profiler_record_shapes = record_shapes
         self.profiler_activities = activities
@@ -213,8 +213,7 @@ class SchedulerProfilerMixin:
                 message="Profiling is not in progress. Call /start_profile first.",
             )
 
-        if not Path(self.torch_profiler_output_dir).exists():
-            Path(self.torch_profiler_output_dir).mkdir(parents=True, exist_ok=True)
+        self.torch_profiler_output_dir.mkdir(parents=True, exist_ok=True)
 
         stage_suffix = f"-{stage.name}" if stage else ""
         logger.info("Stop profiling" + stage_suffix + "...")
sglang/srt/managers/tokenizer_manager.py
@@ -46,7 +46,6 @@ from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchT
 from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
     AbortReq,
-    BaseReq,
     BatchEmbeddingOutput,
     BatchMultimodalOutput,
     BatchStrOutput,
@@ -171,7 +170,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         self.context_len = self.model_config.context_len
         self.image_token_id = self.model_config.image_token_id
         self.max_req_input_len = None  # Will be set later in engine.py
-
         speculative_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
@@ -180,9 +178,8 @@ class TokenizerManager(TokenizerCommunicatorMixin):
             if speculative_algorithm.is_none()
             else server_args.speculative_num_draft_tokens
         )
-        # Initialize delimiter text for multi-item scoring (will be set after tokenizer is loaded)
-        self.multi_item_delimiter_text = None
 
+        # Initialize tokenizer and processor
         if self.model_config.is_multimodal:
             import_processors("sglang.srt.multimodal.processors")
             try:
@@ -237,6 +234,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
                 revision=server_args.revision,
             )
             self._initialize_multi_item_delimiter_text()
+
         # Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal)
         if (
             server_args.enable_dynamic_batch_tokenizer
@@ -255,24 +253,20 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         self.recv_from_detokenizer = get_zmq_socket(
             context, zmq.PULL, port_args.tokenizer_ipc_name, True
         )
-        if self.server_args.tokenizer_worker_num > 1:
+        if self.server_args.tokenizer_worker_num == 1:
+            self.send_to_scheduler = get_zmq_socket(
+                context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
+            )
+        else:
+            from sglang.srt.managers.multi_tokenizer_mixin import SenderWrapper
+
             # Use tokenizer_worker_ipc_name in multi-tokenizer mode
             send_to_scheduler = get_zmq_socket(
                 context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
            )
 
-            class SenderWrapper:
-                def send_pyobj(self, obj):
-                    if isinstance(obj, BaseReq):
-                        obj.http_worker_ipc = port_args.tokenizer_ipc_name
-                    send_to_scheduler.send_pyobj(obj)
-
             # Make sure that each request carries the tokenizer_ipc_name for response routing
-            self.send_to_scheduler = SenderWrapper()
-        else:
-            self.send_to_scheduler = get_zmq_socket(
-                context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
-            )
+            self.send_to_scheduler = SenderWrapper(port_args, send_to_scheduler)
 
         # Request states
         self._chosen_loop = None
@@ -320,6 +314,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         # LoRA updates and inference to overlap.
         self.lora_update_lock = asyncio.Lock()
 
+        # Disaggregation
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
         )
@@ -389,9 +384,6 @@ class TokenizerManager(TokenizerCommunicatorMixin):
         obj.normalize_batch_and_arguments()
 
         if self.server_args.tokenizer_worker_num > 1:
-            from sglang.srt.managers.multi_tokenizer_mixin import TokenizerWorker
-
-            assert isinstance(self, TokenizerWorker)
             self._attach_multi_http_worker_info(obj)
 
         if self.enable_trace:
sglang/srt/mem_cache/hicache_storage.py
@@ -19,7 +19,13 @@ def get_hash_str(token_ids: List[int], prior_hash: str = None) -> str:
         hasher.update(bytes.fromhex(prior_hash))
 
     for t in token_ids:
-        hasher.update(t.to_bytes(4, byteorder="little", signed=False))
+        if isinstance(t, tuple):
+            # EAGLE bigram mode: hash both elements to uniquely identify the bigram
+            for elem in t:
+                hasher.update(elem.to_bytes(4, byteorder="little", signed=False))
+        else:
+            # Regular mode: single integer token
+            hasher.update(t.to_bytes(4, byteorder="little", signed=False))
 
     return hasher.hexdigest()
 
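Note: the change above lets the HiCache page hash accept either plain integer tokens or (token, token) bigram tuples from EAGLE speculative decoding, with both tuple elements contributing to the digest. A simplified, self-contained stand-in for the scheme; `hash_tokens` and the sha256 choice are illustrative assumptions, not the exact sglang implementation:

```python
import hashlib
from typing import List, Optional, Tuple, Union

Token = Union[int, Tuple[int, int]]

def hash_tokens(token_ids: List[Token], prior_hash: Optional[str] = None) -> str:
    # Chain from the previous page's hash, then feed each token as 4
    # little-endian bytes; bigram tuples contribute both elements, so
    # (a, b) hashes differently from (b, a).
    hasher = hashlib.sha256()
    if prior_hash:
        hasher.update(bytes.fromhex(prior_hash))
    for t in token_ids:
        elems = t if isinstance(t, tuple) else (t,)
        for elem in elems:
            hasher.update(elem.to_bytes(4, byteorder="little", signed=False))
    return hasher.hexdigest()

page1 = hash_tokens([1, 2, 3])
page2 = hash_tokens([(4, 5), (6, 7)], prior_hash=page1)
print(page2 != hash_tokens([(5, 4), (6, 7)], prior_hash=page1))  # True
```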