sglang 0.5.3__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. sglang/bench_one_batch.py +0 -2
  2. sglang/bench_serving.py +224 -127
  3. sglang/compile_deep_gemm.py +3 -0
  4. sglang/launch_server.py +0 -14
  5. sglang/srt/configs/__init__.py +2 -0
  6. sglang/srt/configs/falcon_h1.py +12 -58
  7. sglang/srt/configs/mamba_utils.py +117 -0
  8. sglang/srt/configs/model_config.py +68 -31
  9. sglang/srt/configs/nemotron_h.py +286 -0
  10. sglang/srt/configs/qwen3_next.py +11 -43
  11. sglang/srt/disaggregation/decode.py +7 -18
  12. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  13. sglang/srt/disaggregation/nixl/conn.py +55 -23
  14. sglang/srt/disaggregation/prefill.py +17 -32
  15. sglang/srt/entrypoints/engine.py +2 -2
  16. sglang/srt/entrypoints/grpc_request_manager.py +10 -23
  17. sglang/srt/entrypoints/grpc_server.py +220 -80
  18. sglang/srt/entrypoints/http_server.py +49 -1
  19. sglang/srt/entrypoints/openai/protocol.py +159 -31
  20. sglang/srt/entrypoints/openai/serving_chat.py +13 -71
  21. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  22. sglang/srt/environ.py +4 -0
  23. sglang/srt/function_call/function_call_parser.py +8 -6
  24. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  25. sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
  26. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
  27. sglang/srt/layers/attention/attention_registry.py +31 -22
  28. sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
  29. sglang/srt/layers/attention/flashattention_backend.py +0 -1
  30. sglang/srt/layers/attention/flashinfer_backend.py +223 -6
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
  32. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
  33. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  34. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
  35. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  36. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  37. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  38. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  39. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  40. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  41. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  42. sglang/srt/layers/attention/triton_backend.py +1 -1
  43. sglang/srt/layers/logits_processor.py +136 -6
  44. sglang/srt/layers/modelopt_utils.py +11 -0
  45. sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
  46. sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
  47. sglang/srt/layers/moe/ep_moe/layer.py +8 -286
  48. sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
  49. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  50. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  51. sglang/srt/layers/moe/utils.py +7 -1
  52. sglang/srt/layers/quantization/__init__.py +1 -1
  53. sglang/srt/layers/quantization/fp8.py +84 -18
  54. sglang/srt/layers/quantization/modelopt_quant.py +1 -1
  55. sglang/srt/layers/quantization/quark/quark.py +3 -1
  56. sglang/srt/layers/quantization/w4afp8.py +2 -16
  57. sglang/srt/lora/lora_manager.py +0 -8
  58. sglang/srt/managers/overlap_utils.py +18 -16
  59. sglang/srt/managers/schedule_batch.py +119 -90
  60. sglang/srt/managers/schedule_policy.py +1 -1
  61. sglang/srt/managers/scheduler.py +213 -126
  62. sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
  63. sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
  64. sglang/srt/managers/tokenizer_manager.py +270 -53
  65. sglang/srt/managers/tp_worker.py +39 -28
  66. sglang/srt/mem_cache/allocator.py +7 -2
  67. sglang/srt/mem_cache/chunk_cache.py +1 -1
  68. sglang/srt/mem_cache/memory_pool.py +162 -68
  69. sglang/srt/mem_cache/radix_cache.py +8 -3
  70. sglang/srt/mem_cache/swa_radix_cache.py +70 -14
  71. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  72. sglang/srt/model_executor/forward_batch_info.py +4 -18
  73. sglang/srt/model_executor/model_runner.py +55 -51
  74. sglang/srt/model_loader/__init__.py +1 -1
  75. sglang/srt/model_loader/loader.py +187 -6
  76. sglang/srt/model_loader/weight_utils.py +3 -0
  77. sglang/srt/models/falcon_h1.py +11 -9
  78. sglang/srt/models/gemma3_mm.py +16 -0
  79. sglang/srt/models/grok.py +5 -13
  80. sglang/srt/models/mixtral.py +1 -3
  81. sglang/srt/models/mllama4.py +11 -1
  82. sglang/srt/models/nemotron_h.py +514 -0
  83. sglang/srt/models/utils.py +5 -1
  84. sglang/srt/sampling/sampling_batch_info.py +11 -9
  85. sglang/srt/server_args.py +100 -33
  86. sglang/srt/speculative/eagle_worker.py +11 -13
  87. sglang/srt/speculative/ngram_worker.py +12 -11
  88. sglang/srt/speculative/spec_utils.py +0 -1
  89. sglang/srt/two_batch_overlap.py +1 -0
  90. sglang/srt/utils/common.py +18 -0
  91. sglang/srt/utils/hf_transformers_utils.py +2 -0
  92. sglang/test/longbench_v2/__init__.py +1 -0
  93. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  94. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  95. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  96. sglang/test/run_eval.py +40 -0
  97. sglang/test/simple_eval_longbench_v2.py +332 -0
  98. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  99. sglang/test/test_deterministic.py +18 -2
  100. sglang/test/test_deterministic_utils.py +81 -0
  101. sglang/test/test_disaggregation_utils.py +63 -0
  102. sglang/test/test_utils.py +32 -11
  103. sglang/version.py +1 -1
  104. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
  105. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
  106. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  107. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  108. sglang/test/test_block_fp8_ep.py +0 -358
  109. /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
  110. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  111. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  112. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler_output_processor_mixin.py

@@ -39,7 +39,6 @@ class SchedulerOutputProcessorMixin:
         self: Scheduler,
         batch: ScheduleBatch,
         result: Union[GenerationBatchResult, EmbeddingBatchResult],
-        launch_done: Optional[threading.Event] = None,
     ):
         skip_stream_req = None
 
@@ -49,29 +48,29 @@ class SchedulerOutputProcessorMixin:
                 next_token_ids,
                 extend_input_len_per_req,
                 extend_logprob_start_len_per_req,
+                copy_done,
             ) = (
                 result.logits_output,
                 result.next_token_ids,
                 result.extend_input_len_per_req,
                 result.extend_logprob_start_len_per_req,
+                result.copy_done,
             )
 
-            if self.enable_overlap:
-                logits_output, next_token_ids, _ = (
-                    self.tp_worker.resolve_last_batch_result(launch_done)
-                )
-            else:
-                # Move next_token_ids and logprobs to cpu
-                next_token_ids = next_token_ids.tolist()
-                if batch.return_logprob:
-                    if logits_output.next_token_logprobs is not None:
-                        logits_output.next_token_logprobs = (
-                            logits_output.next_token_logprobs.tolist()
-                        )
-                    if logits_output.input_token_logprobs is not None:
-                        logits_output.input_token_logprobs = tuple(
-                            logits_output.input_token_logprobs.tolist()
-                        )
+            if copy_done is not None:
+                copy_done.synchronize()
+
+            # Move next_token_ids and logprobs to cpu
+            next_token_ids = next_token_ids.tolist()
+            if batch.return_logprob:
+                if logits_output.next_token_logprobs is not None:
+                    logits_output.next_token_logprobs = (
+                        logits_output.next_token_logprobs.tolist()
+                    )
+                if logits_output.input_token_logprobs is not None:
+                    logits_output.input_token_logprobs = tuple(
+                        logits_output.input_token_logprobs.tolist()
+                    )
 
             hidden_state_offset = 0
 
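Note: this hunk drops the `launch_done` threading event and the `tp_worker.resolve_last_batch_result` round-trip in favor of a `copy_done` CUDA event carried on the batch result, so the scheduler blocks only until the device-to-host copy of the sampled tokens has finished. A minimal sketch of that pattern, assuming standard `torch.cuda.Event` semantics (the helper names here are illustrative, not sglang's):

```python
import torch

def launch_sampling_copy(next_token_ids_gpu: torch.Tensor):
    """Start an async D2H copy and record an event marking its completion."""
    # Pinned host memory is assumed so the copy can overlap GPU compute.
    next_token_ids_cpu = torch.empty_like(next_token_ids_gpu, device="cpu").pin_memory()
    next_token_ids_cpu.copy_(next_token_ids_gpu, non_blocking=True)
    copy_done = torch.cuda.Event()
    copy_done.record()  # enqueued on the current stream, after the copy
    return next_token_ids_cpu, copy_done

def consume(next_token_ids_cpu, copy_done):
    if copy_done is not None:
        copy_done.synchronize()  # wait for the copy alone, not the whole stream
    return next_token_ids_cpu.tolist()  # host buffer is now valid
```

The same `copy_done is not None` guard appears in the decode hunk below, so a path that records no event stays fully synchronous.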
@@ -105,7 +104,10 @@ class SchedulerOutputProcessorMixin:
                 assert extend_input_len_per_req is not None
                 extend_logprob_start_len = extend_logprob_start_len_per_req[i]
                 extend_input_len = extend_input_len_per_req[i]
-                num_input_logprobs = extend_input_len - extend_logprob_start_len
+
+                num_input_logprobs = self._calculate_num_input_logprobs(
+                    req, extend_input_len, extend_logprob_start_len
+                )
 
                 if req.return_logprob:
                     self.add_logprob_return_values(
@@ -160,8 +162,8 @@ class SchedulerOutputProcessorMixin:
                     extend_input_len = extend_input_len_per_req[i]
                     if extend_logprob_start_len < extend_input_len:
                         # Update input logprobs.
-                        num_input_logprobs = (
-                            extend_input_len - extend_logprob_start_len
+                        num_input_logprobs = self._calculate_num_input_logprobs(
+                            req, extend_input_len, extend_logprob_start_len
                         )
                         if req.return_logprob:
                             self.add_input_logprob_return_values(
@@ -174,8 +176,6 @@ class SchedulerOutputProcessorMixin:
                             )
                         logprob_pt += num_input_logprobs
 
-            self.set_next_batch_sampling_info_done(batch)
-
         else:  # embedding or reward model
             embeddings = result.embeddings.tolist()
 
@@ -204,22 +204,19 @@ class SchedulerOutputProcessorMixin:
         self: Scheduler,
         batch: ScheduleBatch,
         result: GenerationBatchResult,
-        launch_done: Optional[threading.Event] = None,
     ):
-        logits_output, next_token_ids, can_run_cuda_graph = (
+        logits_output, next_token_ids, can_run_cuda_graph, copy_done = (
             result.logits_output,
             result.next_token_ids,
             result.can_run_cuda_graph,
+            result.copy_done,
         )
         self.num_generated_tokens += len(batch.reqs)
 
-        if self.enable_overlap:
-            logits_output, next_token_ids, can_run_cuda_graph = (
-                self.tp_worker.resolve_last_batch_result(launch_done)
-            )
-            next_token_logprobs = logits_output.next_token_logprobs
-        elif batch.spec_algorithm.is_none():
-            # spec decoding handles output logprobs inside verify process.
+        if copy_done is not None:
+            copy_done.synchronize()
+
+        if batch.spec_algorithm.is_none():
             next_token_ids = next_token_ids.tolist()
             if batch.return_logprob:
                 next_token_logprobs = logits_output.next_token_logprobs.tolist()
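Note: the decode path unpacks the same `copy_done` field from the batch result. The real `GenerationBatchResult` is defined in sglang's scheduler code and carries more fields; the sketch below is an assumed, stripped-down shape showing only what this hunk relies on:

```python
from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class GenerationBatchResult:  # illustrative subset of the real class
    logits_output: object                    # logits/logprob container
    next_token_ids: torch.Tensor             # sampled ids, still on device
    can_run_cuda_graph: bool
    copy_done: Optional["torch.cuda.Event"]  # None if no async copy was launched
```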
@@ -299,7 +296,6 @@ class SchedulerOutputProcessorMixin:
                     self.abort_request(AbortReq(rid=req.rid))
                 req.grammar.finished = req.finished()
 
-        self.set_next_batch_sampling_info_done(batch)
         self.stream_output(batch.reqs, batch.return_logprob)
         self.token_to_kv_pool_allocator.free_group_end()
 
@@ -310,6 +306,153 @@ class SchedulerOutputProcessorMixin:
         ):
             self.log_decode_stats(can_run_cuda_graph, running_batch=batch)
 
+    def _process_input_token_logprobs(
+        self, req: Req, input_token_logprobs: List
+    ) -> None:
+        """Process input token logprobs values and indices."""
+        is_multi_item_scoring = self._is_multi_item_scoring(req)
+
+        # Process logprob values - handle multi-item scoring vs regular requests
+        if is_multi_item_scoring:
+            # Multi-item scoring: use all logprobs as-is
+            req.input_token_logprobs_val = input_token_logprobs
+        else:
+            # Regular request: add None at start, remove last (sampling token)
+            req.input_token_logprobs_val = [None] + input_token_logprobs[:-1]
+
+        # Process logprob indices based on scoring type
+        if is_multi_item_scoring:
+            # Multi-item scoring: only include delimiter token positions
+            relevant_tokens = req.origin_input_ids[req.logprob_start_len :]
+            input_token_logprobs_idx = [
+                token_id
+                for token_id in relevant_tokens
+                if token_id == self.server_args.multi_item_scoring_delimiter
+            ]
+        else:
+            # Regular request: include all tokens from logprob_start_len onwards
+            input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
+
+        # Clip padded hash values from image tokens to prevent detokenization errors
+        req.input_token_logprobs_idx = [
+            x if x < self.model_config.vocab_size - 1 else 0
+            for x in input_token_logprobs_idx
+        ]
+
+    def _process_input_top_logprobs(self, req: Req) -> None:
+        """Process input top logprobs."""
+        if req.top_logprobs_num <= 0:
+            return
+
+        is_multi_item_scoring = self._is_multi_item_scoring(req)
+
+        # Initialize arrays - multi-item scoring starts empty, others start with None
+        req.input_top_logprobs_val = [] if is_multi_item_scoring else [None]
+        req.input_top_logprobs_idx = [] if is_multi_item_scoring else [None]
+
+        # Extend arrays with temp values
+        for val, idx in zip(
+            req.temp_input_top_logprobs_val,
+            req.temp_input_top_logprobs_idx,
+            strict=True,
+        ):
+            req.input_top_logprobs_val.extend(val)
+            req.input_top_logprobs_idx.extend(idx)
+
+        # Remove last token (sampling token) for non multi-item scoring requests
+        if not is_multi_item_scoring:
+            req.input_top_logprobs_val.pop()
+            req.input_top_logprobs_idx.pop()
+
+        # Clean up temp storage
+        req.temp_input_top_logprobs_idx = None
+        req.temp_input_top_logprobs_val = None
+
+    def _process_input_token_ids_logprobs(self, req: Req) -> None:
+        """Process input token IDs logprobs."""
+        if req.token_ids_logprob is None:
+            return
+
+        is_multi_item_scoring = self._is_multi_item_scoring(req)
+
+        # Initialize arrays - multi-item scoring starts empty, others start with None
+        req.input_token_ids_logprobs_val = [] if is_multi_item_scoring else [None]
+        req.input_token_ids_logprobs_idx = [] if is_multi_item_scoring else [None]
+
+        # Process temp values - convert tensors to lists and extend arrays
+        for val, idx in zip(
+            req.temp_input_token_ids_logprobs_val,
+            req.temp_input_token_ids_logprobs_idx,
+            strict=True,
+        ):
+            val_list = val.tolist() if isinstance(val, torch.Tensor) else val
+            req.input_token_ids_logprobs_val.extend(
+                val_list if isinstance(val_list, list) else [val_list]
+            )
+            req.input_token_ids_logprobs_idx.extend(idx)
+
+        # Remove last token (sampling token) for non multi-item scoring requests
+        if not is_multi_item_scoring:
+            req.input_token_ids_logprobs_val.pop()
+            req.input_token_ids_logprobs_idx.pop()
+
+        # Clean up temp storage
+        req.temp_input_token_ids_logprobs_idx = None
+        req.temp_input_token_ids_logprobs_val = None
+
+    def _calculate_relevant_tokens_len(self, req: Req) -> int:
+        """Calculate the expected length of logprob arrays based on whether multi-item scoring is enabled.
+
+        For multi-item scoring, only delimiter positions have logprobs.
+        For regular requests, all positions from logprob_start_len onwards have logprobs.
+        """
+        is_multi_item_scoring = self._is_multi_item_scoring(req)
+
+        if is_multi_item_scoring:
+            # Multi-item scoring: count delimiter tokens from logprob_start_len onwards
+            relevant_tokens = req.origin_input_ids[req.logprob_start_len :]
+            return sum(
+                1
+                for token_id in relevant_tokens
+                if token_id == self.server_args.multi_item_scoring_delimiter
+            )
+        else:
+            # Regular request: all tokens from logprob_start_len onwards
+            return len(req.origin_input_ids) - req.logprob_start_len
+
+    def _calculate_num_input_logprobs(
+        self, req: Req, extend_input_len: int, extend_logprob_start_len: int
+    ) -> int:
+        """Calculate the number of input logprobs based on whether multi-item scoring is enabled.
+
+        For multi-item scoring, only delimiter positions have logprobs.
+        For regular requests, all positions in the range have logprobs.
+        """
+        is_multi_item_scoring = self._is_multi_item_scoring(req)
+
+        if is_multi_item_scoring:
+            # Multi-item scoring: count delimiter tokens in the relevant portion
+            relevant_tokens = req.origin_input_ids[
+                extend_logprob_start_len:extend_input_len
+            ]
+            return sum(
+                1
+                for token_id in relevant_tokens
+                if token_id == self.server_args.multi_item_scoring_delimiter
+            )
+        else:
+            # Regular request: all tokens in the range
+            return extend_input_len - extend_logprob_start_len
+
+    def _is_multi_item_scoring(self, req: Req) -> bool:
+        """Check if request uses multi-item scoring.
+
+        Multi-item scoring applies to prefill-only requests when a delimiter
+        token is configured. In this mode, only positions containing the
+        delimiter token receive logprobs.
+        """
+        return req.is_prefill_only and self.server_args.multi_item_scoring_delimiter
+
     def add_input_logprob_return_values(
         self: Scheduler,
         i: int,
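Note: every helper above branches on `_is_multi_item_scoring`, which holds for prefill-only requests when `multi_item_scoring_delimiter` is configured; in that mode only delimiter positions carry logprobs. A self-contained illustration of the two counting rules, with made-up token ids:

```python
DELIM = 32000  # stand-in for server_args.multi_item_scoring_delimiter
origin_input_ids = [101, 5, 6, DELIM, 7, 8, DELIM, 9]
logprob_start_len = 1

relevant = origin_input_ids[logprob_start_len:]

# Multi-item scoring: one logprob per delimiter occurrence.
assert sum(1 for t in relevant if t == DELIM) == 2

# Regular request: one logprob per position after logprob_start_len.
assert len(origin_input_ids) - logprob_start_len == 7
```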
@@ -378,63 +521,14 @@ class SchedulerOutputProcessorMixin:
            assert req.input_top_logprobs_val is None
            assert req.input_top_logprobs_idx is None
 
-        # Compute input_token_logprobs_val
-        # Always pad the first one with None.
-        req.input_token_logprobs_val = [None]
-        req.input_token_logprobs_val.extend(input_token_logprobs)
-        # The last input logprob is for sampling, so just pop it out.
-        req.input_token_logprobs_val.pop()
+        # Process all input logprob types using helper functions
+        self._process_input_token_logprobs(req, input_token_logprobs)
+        self._process_input_top_logprobs(req)
 
-        # Compute input_token_logprobs_idx
-        input_token_logprobs_idx = req.origin_input_ids[req.logprob_start_len :]
-        # Clip the padded hash values from image tokens.
-        # Otherwise, it will lead to detokenization errors.
-        input_token_logprobs_idx = [
-            x if x < self.model_config.vocab_size - 1 else 0
-            for x in input_token_logprobs_idx
-        ]
-        req.input_token_logprobs_idx = input_token_logprobs_idx
-
-        if req.top_logprobs_num > 0:
-            req.input_top_logprobs_val = [None]
-            req.input_top_logprobs_idx = [None]
-            assert len(req.temp_input_token_ids_logprobs_val) == len(
-                req.temp_input_token_ids_logprobs_idx
-            )
-            for val, idx in zip(
-                req.temp_input_top_logprobs_val,
-                req.temp_input_top_logprobs_idx,
-                strict=True,
-            ):
-                req.input_top_logprobs_val.extend(val)
-                req.input_top_logprobs_idx.extend(idx)
-
-            # Last token is a sample token.
-            req.input_top_logprobs_val.pop()
-            req.input_top_logprobs_idx.pop()
-            req.temp_input_top_logprobs_idx = None
-            req.temp_input_top_logprobs_val = None
-
-        if req.token_ids_logprob is not None:
-            req.input_token_ids_logprobs_val = [None]
-            req.input_token_ids_logprobs_idx = [None]
-
-            for val, idx in zip(
-                req.temp_input_token_ids_logprobs_val,
-                req.temp_input_token_ids_logprobs_idx,
-                strict=True,
-            ):
-                req.input_token_ids_logprobs_val.extend(val)
-                req.input_token_ids_logprobs_idx.extend(idx)
-
-            # Last token is a sample token.
-            req.input_token_ids_logprobs_val.pop()
-            req.input_token_ids_logprobs_idx.pop()
-            req.temp_input_token_ids_logprobs_idx = None
-            req.temp_input_token_ids_logprobs_val = None
+        self._process_input_token_ids_logprobs(req)
 
         if req.return_logprob:
-            relevant_tokens_len = len(req.origin_input_ids) - req.logprob_start_len
+            relevant_tokens_len = self._calculate_relevant_tokens_len(req)
             assert len(req.input_token_logprobs_val) == relevant_tokens_len
             assert len(req.input_token_logprobs_idx) == relevant_tokens_len
             if req.top_logprobs_num > 0: