sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (119)
  1. sglang/bench_one_batch.py +3 -11
  2. sglang/bench_serving.py +149 -1
  3. sglang/check_env.py +3 -3
  4. sglang/lang/chat_template.py +44 -0
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/deepseekvl2.py +3 -0
  7. sglang/srt/configs/device_config.py +1 -1
  8. sglang/srt/configs/internvl.py +696 -0
  9. sglang/srt/configs/janus_pro.py +3 -0
  10. sglang/srt/configs/kimi_vl.py +38 -0
  11. sglang/srt/configs/kimi_vl_moonvit.py +32 -0
  12. sglang/srt/configs/model_config.py +32 -0
  13. sglang/srt/constrained/xgrammar_backend.py +11 -19
  14. sglang/srt/conversation.py +151 -3
  15. sglang/srt/disaggregation/decode.py +4 -1
  16. sglang/srt/disaggregation/mini_lb.py +74 -23
  17. sglang/srt/disaggregation/mooncake/conn.py +9 -18
  18. sglang/srt/disaggregation/nixl/conn.py +241 -71
  19. sglang/srt/disaggregation/utils.py +44 -1
  20. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  21. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  22. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  23. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  24. sglang/srt/distributed/parallel_state.py +22 -1
  25. sglang/srt/entrypoints/engine.py +58 -24
  26. sglang/srt/entrypoints/http_server.py +28 -1
  27. sglang/srt/entrypoints/verl_engine.py +3 -2
  28. sglang/srt/function_call_parser.py +97 -0
  29. sglang/srt/hf_transformers_utils.py +22 -1
  30. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
  31. sglang/srt/layers/attention/flashattention_backend.py +146 -50
  32. sglang/srt/layers/attention/flashinfer_backend.py +129 -94
  33. sglang/srt/layers/attention/flashinfer_mla_backend.py +88 -30
  34. sglang/srt/layers/attention/flashmla_backend.py +3 -0
  35. sglang/srt/layers/attention/merge_state.py +46 -0
  36. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  37. sglang/srt/layers/attention/vision.py +290 -163
  38. sglang/srt/layers/dp_attention.py +5 -2
  39. sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
  40. sglang/srt/layers/moe/ep_moe/layer.py +120 -1
  41. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +98 -57
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -5
  49. sglang/srt/layers/quantization/__init__.py +2 -2
  50. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  52. sglang/srt/layers/quantization/deep_gemm.py +6 -1
  53. sglang/srt/layers/quantization/fp8.py +108 -95
  54. sglang/srt/layers/quantization/fp8_kernel.py +79 -60
  55. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  56. sglang/srt/layers/quantization/kv_cache.py +3 -10
  57. sglang/srt/layers/quantization/utils.py +0 -5
  58. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  59. sglang/srt/layers/utils.py +35 -0
  60. sglang/srt/lora/layers.py +35 -9
  61. sglang/srt/lora/lora_manager.py +81 -35
  62. sglang/srt/managers/cache_controller.py +115 -119
  63. sglang/srt/managers/data_parallel_controller.py +52 -34
  64. sglang/srt/managers/io_struct.py +10 -0
  65. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  66. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  67. sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
  68. sglang/srt/managers/schedule_batch.py +44 -16
  69. sglang/srt/managers/schedule_policy.py +11 -5
  70. sglang/srt/managers/scheduler.py +291 -72
  71. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
  72. sglang/srt/managers/tokenizer_manager.py +24 -13
  73. sglang/srt/managers/tp_worker.py +60 -28
  74. sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
  75. sglang/srt/mem_cache/chunk_cache.py +2 -0
  76. sglang/srt/mem_cache/memory_pool.py +70 -36
  77. sglang/srt/model_executor/cuda_graph_runner.py +82 -19
  78. sglang/srt/model_executor/forward_batch_info.py +31 -1
  79. sglang/srt/model_executor/model_runner.py +159 -90
  80. sglang/srt/model_loader/loader.py +18 -11
  81. sglang/srt/models/clip.py +4 -4
  82. sglang/srt/models/deepseek_janus_pro.py +1 -1
  83. sglang/srt/models/deepseek_nextn.py +2 -277
  84. sglang/srt/models/deepseek_v2.py +132 -37
  85. sglang/srt/models/gemma3_mm.py +1 -1
  86. sglang/srt/models/internlm2.py +3 -0
  87. sglang/srt/models/internvl.py +670 -0
  88. sglang/srt/models/kimi_vl.py +308 -0
  89. sglang/srt/models/kimi_vl_moonvit.py +639 -0
  90. sglang/srt/models/llama.py +93 -31
  91. sglang/srt/models/llama4.py +54 -7
  92. sglang/srt/models/llama_eagle.py +4 -1
  93. sglang/srt/models/llama_eagle3.py +4 -1
  94. sglang/srt/models/minicpmv.py +1 -1
  95. sglang/srt/models/mllama.py +1 -1
  96. sglang/srt/models/phi3_small.py +16 -2
  97. sglang/srt/models/qwen2_5_vl.py +8 -4
  98. sglang/srt/models/qwen2_moe.py +8 -3
  99. sglang/srt/models/qwen2_vl.py +4 -16
  100. sglang/srt/models/qwen3_moe.py +8 -3
  101. sglang/srt/models/xiaomi_mimo.py +171 -0
  102. sglang/srt/openai_api/adapter.py +58 -62
  103. sglang/srt/openai_api/protocol.py +38 -16
  104. sglang/srt/reasoning_parser.py +2 -2
  105. sglang/srt/sampling/sampling_batch_info.py +54 -2
  106. sglang/srt/sampling/sampling_params.py +2 -0
  107. sglang/srt/server_args.py +93 -24
  108. sglang/srt/speculative/eagle_worker.py +3 -2
  109. sglang/srt/utils.py +123 -10
  110. sglang/test/runners.py +4 -0
  111. sglang/test/test_block_fp8.py +2 -2
  112. sglang/test/test_deepep_utils.py +219 -0
  113. sglang/test/test_utils.py +32 -1
  114. sglang/version.py +1 -1
  115. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +18 -9
  116. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +119 -99
  117. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
  118. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
  119. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0

sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -9,6 +9,7 @@ and uses BatchMLAPaged wrapper for decoding.
 More details can be found in https://docs.flashinfer.ai/api/mla.html
 """
 
+import os
 from dataclasses import dataclass
 from functools import partial
 from typing import TYPE_CHECKING, Callable, Optional, Union
@@ -16,6 +17,12 @@ from typing import TYPE_CHECKING, Callable, Optional, Union
 import torch
 import triton
 
+if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
+    import logging
+
+    torch._logging.set_logs(dynamo=logging.ERROR)
+    torch._dynamo.config.suppress_errors = True
+
 from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.flashinfer_backend import (
@@ -332,23 +339,39 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
 
         cache_loc = forward_batch.out_cache_loc
         logits_soft_cap = layer.logit_cap
         prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
-        qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
         k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
 
         # Save kv cache
         if save_kv_cache and k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                if k_rope is not None:
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                        layer, cache_loc, k, k_rope
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+        if q_rope is not None:
+            q = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
 
         if self.forward_metadata.use_ragged:
             # ragged prefill
-            o, _ = self.prefill_wrapper_ragged.forward_return_lse(
+            if q_rope is not None:
+                q = torch.cat([q, q_rope], dim=-1)
+            qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            if k_rope is not None:
+                k = torch.cat([k, k_rope], dim=-1)
+            o = self.prefill_wrapper_ragged.forward(
                 qall,
                 k.view(-1, layer.tp_k_head_num, layer.head_dim),
                 v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
@@ -358,11 +381,19 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             )
         else:
             # mla paged prefill
+            if q_rope is None:
+                qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+                q, q_rope = (
+                    qall[:, :, : layer.v_head_dim],
+                    qall[:, :, layer.v_head_dim :],
+                )
+            o = q.new_empty(q.shape)
             o = prefill_wrapper_paged.run(
-                qall[:, :, : layer.v_head_dim],
-                qall[:, :, layer.v_head_dim :],
+                q,
+                q_rope,
                 k_buf[:, :, : layer.v_head_dim],
                 k_buf[:, :, layer.v_head_dim :],
+                out=o,
             )
 
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
@@ -375,6 +406,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
         decode_wrapper = self.forward_metadata.decode_wrapper
         cache_loc = forward_batch.out_cache_loc
@@ -382,20 +416,42 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer,
-                    cache_loc,
-                    k,
-                    v,
-                )
-        reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+                if k_rope is not None:
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        k_rope,
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
+
+        # Reshape inputs
+        if q_rope is not None:
+            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
+        else:
+            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = reshaped_q[:, :, : layer.v_head_dim]
+            q_rope = reshaped_q[:, :, layer.v_head_dim :]
+
         k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-        reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
+
+        o = q_nope.new_empty(q_nope.shape)
+        # Direct call to run without the wrapper
         o = decode_wrapper.run(
-            reshaped_q[:, :, : layer.v_head_dim],
-            reshaped_q[:, :, layer.v_head_dim :],
-            reshaped_k[:, :, : layer.v_head_dim],
-            reshaped_k[:, :, layer.v_head_dim :],
+            q_nope,
+            q_rope,
+            k_buffer[:, :, : layer.v_head_dim],
+            k_buffer[:, :, layer.v_head_dim :],
+            out=o,
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
@@ -825,16 +881,18 @@ def fast_mla_decode_plan(
     self._sm_scale = sm_scale
 
     with self.device as device:
-        stream = torch.cuda.current_stream(device).cuda_stream
-        self._cached_module.plan(
-            self._float_workspace_buffer,
-            self._int_workspace_buffer,
-            self._pin_memory_int_workspace_buffer,
-            qo_indptr_cpu,
-            kv_indptr_cpu,
-            kv_len_arr_cpu,
-            num_heads,
-            head_dim_ckv,
-            causal,
-            stream,
-        )
+        try:
+            # Standard version with just the required arguments (no use_profiler)
+            self._cached_module.plan.default(
+                self._float_workspace_buffer,
+                self._int_workspace_buffer,
+                self._pin_memory_int_workspace_buffer,
+                qo_indptr_cpu,
+                kv_indptr_cpu,
+                kv_len_arr_cpu,
+                num_heads,
+                head_dim_ckv,
+                causal,
+            )
+        except Exception as e:
+            raise RuntimeError(f"Error in alternate MLA plan: {e}")
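
Note on the forward_extend/forward_decode hunks above: both paths now optionally accept the rotary ("rope") slice of the query and key as separate q_rope/k_rope tensors; when q_rope is None the backend splits the fused head dimension itself, as before. The following sketch is not sglang code; it only illustrates how the two calling conventions relate, using made-up MLA-style sizes (v_head_dim = 512, rope dim = 64 are assumptions for illustration):

import torch

num_tokens, num_heads = 4, 16
v_head_dim, rope_dim = 512, 64        # illustrative sizes, not read from the package
head_dim = v_head_dim + rope_dim      # corresponds to layer.head_dim in the diff

# Fused convention: one tensor of head_dim per head.
q_fused = torch.randn(num_tokens, num_heads, head_dim)

# What the backend does internally when q_rope is None:
q_nope, q_rope = q_fused[:, :, :v_head_dim], q_fused[:, :, v_head_dim:]

# New convention: a caller may pass the two pieces directly, which is
# equivalent to the fused layout above.
assert torch.equal(torch.cat([q_nope, q_rope], dim=-1), q_fused)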

sglang/srt/layers/attention/flashmla_backend.py
@@ -241,6 +241,9 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
            seq_lens_cpu,
        )
 
+    def get_cuda_graph_seq_len_fill_value(self):
+        return 1024
+
     def forward_decode(
         self,
         q: torch.Tensor,

sglang/srt/layers/attention/merge_state.py (new file)
@@ -0,0 +1,46 @@
+from typing import Optional, Tuple
+
+import torch
+from sgl_kernel import merge_state_v2
+
+from sglang.srt.layers.attention.triton_ops.merge_state import merge_state_triton
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+
+
+# Automatically fallback to the Triton kernel in some cases
+# (e.g., for AMD GPUs, when the head dimension is not a multiple
+# of 4 or 8, and in FP8 precision)
+def _supported_dtypes(o: torch.Tensor) -> bool:
+    return o.dtype in [torch.float32, torch.half, torch.bfloat16]
+
+
+def _supported_headdim(o: torch.Tensor) -> bool:
+    headdim = o.shape[2]  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
+    if o.dtype == torch.float32:
+        return headdim % 4 == 0
+    return headdim % 8 == 0
+
+
+def merge_state(
+    prefix_output: torch.Tensor,
+    prefix_lse: torch.Tensor,
+    suffix_output: torch.Tensor,
+    suffix_lse: torch.Tensor,
+    output: Optional[torch.Tensor] = None,
+    output_lse: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    if (
+        _is_cuda
+        and _supported_dtypes(prefix_output)
+        and _supported_headdim(prefix_output)
+    ):
+        return merge_state_v2(
+            prefix_output, prefix_lse, suffix_output, suffix_lse, output, output_lse
+        )
+    else:
+        # Fallback to Triton kernel
+        return merge_state_triton(
+            prefix_output, prefix_lse, suffix_output, suffix_lse, output, output_lse
+        )
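
The merge_state wrapper above dispatches between the CUDA kernel from sgl_kernel (merge_state_v2) and the Triton fallback based on device, dtype, and head size. A hedged usage sketch; the shapes, dtypes, and CUDA device below are assumptions for illustration, not taken from the package:

import torch
from sglang.srt.layers.attention.merge_state import merge_state

num_tokens, num_heads, head_size = 8, 16, 128
prefix_out = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16, device="cuda")
suffix_out = torch.randn_like(prefix_out)
prefix_lse = torch.randn(num_tokens, num_heads, dtype=torch.float32, device="cuda")
suffix_lse = torch.randn_like(prefix_lse)

# bfloat16 with head_size % 8 == 0 on a CUDA build takes the merge_state_v2 path;
# other dtypes/head sizes (or non-CUDA builds) fall back to merge_state_triton.
merged_out, merged_lse = merge_state(prefix_out, prefix_lse, suffix_out, suffix_lse)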

sglang/srt/layers/attention/triton_ops/merge_state.py (new file)
@@ -0,0 +1,96 @@
+from typing import Optional, Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def merge_state_kernel(
+    output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_merged
+    output_lse,  # [NUM_TOKENS, NUM_HEADS] s_merged
+    prefix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_a
+    prefix_lse,  # [NUM_TOKENS, NUM_HEADS] s_a
+    suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_b
+    suffix_lse,  # [NUM_TOKENS, NUM_HEADS] s_b
+    HEAD_SIZE: tl.constexpr,
+    PADDED_HEAD_SIZE: tl.constexpr,
+    OUTPUT_LSE: tl.constexpr,
+):
+    token_idx = tl.program_id(0)
+    num_tokens = tl.num_programs(0)
+    head_idx = tl.program_id(1)
+    num_heads = tl.num_programs(1)
+
+    p_lse = tl.load(prefix_lse + token_idx * num_heads + head_idx)
+    s_lse = tl.load(suffix_lse + token_idx * num_heads + head_idx)
+    p_lse = float("-inf") if p_lse == float("inf") else p_lse
+    s_lse = float("-inf") if s_lse == float("inf") else s_lse
+
+    max_lse = tl.maximum(p_lse, s_lse)
+    p_lse = p_lse - max_lse
+    s_lse = s_lse - max_lse
+    out_se = tl.exp(p_lse) + tl.exp(s_lse)
+
+    if OUTPUT_LSE:
+        out_lse = tl.log(out_se) + max_lse
+        tl.store(output_lse + token_idx * num_heads + head_idx, out_lse)
+
+    head_arange = tl.arange(0, PADDED_HEAD_SIZE)
+    head_mask = head_arange < HEAD_SIZE
+    p_out = tl.load(
+        prefix_output
+        + token_idx * num_heads * HEAD_SIZE
+        + head_idx * HEAD_SIZE
+        + head_arange,
+        mask=head_mask,
+    )
+    s_out = tl.load(
+        suffix_output
+        + token_idx * num_heads * HEAD_SIZE
+        + head_idx * HEAD_SIZE
+        + head_arange,
+        mask=head_mask,
+    )
+
+    p_scale = tl.exp(p_lse) / out_se
+    s_scale = tl.exp(s_lse) / out_se
+    out = p_out * p_scale + s_out * s_scale
+    tl.store(
+        output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange,
+        out,
+        mask=head_mask,
+    )
+
+
+def merge_state_triton(
+    prefix_output: torch.Tensor,
+    prefix_lse: torch.Tensor,
+    suffix_output: torch.Tensor,
+    suffix_lse: torch.Tensor,
+    output: Optional[torch.Tensor] = None,
+    output_lse: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    # Avoid creating new tensors if they are already provided
+    if output is None:
+        output = torch.empty_like(prefix_output)
+    if output_lse is None:
+        output_lse = torch.empty_like(prefix_lse)
+
+    num_tokens = output.shape[0]
+    num_query_heads = output.shape[1]
+    head_size = output.shape[2]
+    padded_head_size = triton.next_power_of_2(head_size)
+
+    merge_state_kernel[(num_tokens, num_query_heads)](
+        output,
+        output_lse,
+        prefix_output,
+        prefix_lse,
+        suffix_output,
+        suffix_lse,
+        head_size,
+        padded_head_size,
+        output_lse is not None,
+    )
+    return output, output_lse
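
For reference, merge_state_kernel above computes the standard log-sum-exp merge of two partial softmax-attention results. A plain-PyTorch restatement of that math (illustrative only, not part of the package; the Triton kernel additionally remaps +inf LSE values to -inf before merging):

import torch

def merge_state_reference(p_out, p_lse, s_out, s_lse):
    # p_out, s_out: [num_tokens, num_heads, head_size]; p_lse, s_lse: [num_tokens, num_heads]
    max_lse = torch.maximum(p_lse, s_lse)
    p_w = torch.exp(p_lse - max_lse)          # unnormalized weight of the prefix pass
    s_w = torch.exp(s_lse - max_lse)          # unnormalized weight of the suffix pass
    out_lse = torch.log(p_w + s_w) + max_lse  # merged log-sum-exp
    scale = (p_w / (p_w + s_w)).unsqueeze(-1)
    out = p_out * scale + s_out * (1.0 - scale)
    return out, out_lse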