sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. sglang/bench_one_batch.py +1 -11
  2. sglang/bench_serving.py +149 -1
  3. sglang/lang/chat_template.py +44 -0
  4. sglang/srt/configs/deepseekvl2.py +3 -0
  5. sglang/srt/configs/device_config.py +1 -1
  6. sglang/srt/configs/internvl.py +696 -0
  7. sglang/srt/configs/janus_pro.py +3 -0
  8. sglang/srt/configs/model_config.py +17 -0
  9. sglang/srt/constrained/xgrammar_backend.py +11 -19
  10. sglang/srt/conversation.py +30 -3
  11. sglang/srt/disaggregation/decode.py +4 -1
  12. sglang/srt/disaggregation/mini_lb.py +74 -23
  13. sglang/srt/disaggregation/mooncake/conn.py +9 -18
  14. sglang/srt/disaggregation/nixl/conn.py +241 -71
  15. sglang/srt/disaggregation/utils.py +44 -1
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  17. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  18. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  19. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  20. sglang/srt/distributed/parallel_state.py +22 -1
  21. sglang/srt/entrypoints/engine.py +14 -2
  22. sglang/srt/entrypoints/http_server.py +28 -1
  23. sglang/srt/entrypoints/verl_engine.py +3 -2
  24. sglang/srt/hf_transformers_utils.py +20 -1
  25. sglang/srt/layers/attention/flashattention_backend.py +146 -50
  26. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  27. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  28. sglang/srt/layers/attention/merge_state.py +46 -0
  29. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  30. sglang/srt/layers/attention/vision.py +290 -163
  31. sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
  32. sglang/srt/layers/moe/ep_moe/layer.py +120 -1
  33. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +4 -1
  37. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  38. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  39. sglang/srt/layers/quantization/deep_gemm.py +5 -0
  40. sglang/srt/layers/quantization/fp8.py +108 -95
  41. sglang/srt/layers/quantization/fp8_kernel.py +79 -60
  42. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  43. sglang/srt/layers/quantization/kv_cache.py +3 -10
  44. sglang/srt/layers/quantization/utils.py +0 -5
  45. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  46. sglang/srt/lora/lora_manager.py +10 -13
  47. sglang/srt/managers/cache_controller.py +115 -119
  48. sglang/srt/managers/io_struct.py +10 -0
  49. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  50. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  51. sglang/srt/managers/schedule_batch.py +19 -1
  52. sglang/srt/managers/schedule_policy.py +11 -5
  53. sglang/srt/managers/scheduler.py +28 -13
  54. sglang/srt/managers/tokenizer_manager.py +24 -13
  55. sglang/srt/managers/tp_worker.py +9 -12
  56. sglang/srt/mem_cache/chunk_cache.py +2 -0
  57. sglang/srt/mem_cache/memory_pool.py +2 -2
  58. sglang/srt/model_executor/model_runner.py +44 -33
  59. sglang/srt/model_loader/loader.py +18 -11
  60. sglang/srt/models/clip.py +4 -4
  61. sglang/srt/models/deepseek_janus_pro.py +1 -1
  62. sglang/srt/models/deepseek_nextn.py +1 -20
  63. sglang/srt/models/deepseek_v2.py +55 -20
  64. sglang/srt/models/gemma3_mm.py +1 -1
  65. sglang/srt/models/internlm2.py +3 -0
  66. sglang/srt/models/internvl.py +670 -0
  67. sglang/srt/models/llama.py +1 -1
  68. sglang/srt/models/llama4.py +53 -7
  69. sglang/srt/models/minicpmv.py +1 -1
  70. sglang/srt/models/mllama.py +1 -1
  71. sglang/srt/models/phi3_small.py +16 -2
  72. sglang/srt/models/qwen2_5_vl.py +8 -4
  73. sglang/srt/models/qwen2_vl.py +4 -4
  74. sglang/srt/models/xiaomi_mimo.py +171 -0
  75. sglang/srt/openai_api/adapter.py +24 -40
  76. sglang/srt/openai_api/protocol.py +28 -16
  77. sglang/srt/reasoning_parser.py +2 -2
  78. sglang/srt/sampling/sampling_batch_info.py +54 -2
  79. sglang/srt/sampling/sampling_params.py +2 -0
  80. sglang/srt/server_args.py +30 -6
  81. sglang/srt/utils.py +35 -1
  82. sglang/test/test_block_fp8.py +2 -2
  83. sglang/test/test_deepep_utils.py +219 -0
  84. sglang/test/test_utils.py +3 -1
  85. sglang/version.py +1 -1
  86. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +14 -6
  87. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +90 -80
  88. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
  89. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
  90. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
@@ -18,8 +18,9 @@ import torch
 import triton
 
 if os.environ["SGLANG_ENABLE_TORCH_COMPILE"] == "1":
-    import torch._dynamo
+    import logging
 
+    torch._logging.set_logs(dynamo=logging.ERROR)
     torch._dynamo.config.suppress_errors = True
 
 from sglang.global_config import global_config
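Note on the hunk above: the old code imported torch._dynamo explicitly before flipping suppress_errors; the new code imports logging and routes dynamo output through torch._logging so graph-break noise is capped at ERROR. A minimal sketch of the resulting startup behavior (the os.environ.get default and the comments are illustrative additions; the hunk itself uses a direct os.environ[...] lookup and assumes torch._dynamo is reachable after import torch):

import logging
import os

import torch

if os.environ.get("SGLANG_ENABLE_TORCH_COMPILE", "0") == "1":
    # Cap torch.compile / dynamo logging at ERROR to reduce console noise
    torch._logging.set_logs(dynamo=logging.ERROR)
    # Fall back to eager execution instead of raising on compilation errors
    torch._dynamo.config.suppress_errors = True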
@@ -338,23 +339,39 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
 
         cache_loc = forward_batch.out_cache_loc
         logits_soft_cap = layer.logit_cap
         prefill_wrapper_paged = self.forward_metadata.prefill_wrapper
-        qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
         k_buf = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
 
         # Save kv cache
         if save_kv_cache and k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                if k_rope is not None:
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                        layer, cache_loc, k, k_rope
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+        if q_rope is not None:
+            q = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
 
         if self.forward_metadata.use_ragged:
             # ragged prefill
-            o, _ = self.prefill_wrapper_ragged.forward_return_lse(
+            if q_rope is not None:
+                q = torch.cat([q, q_rope], dim=-1)
+            qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            if k_rope is not None:
+                k = torch.cat([k, k_rope], dim=-1)
+            o = self.prefill_wrapper_ragged.forward(
                 qall,
                 k.view(-1, layer.tp_k_head_num, layer.head_dim),
                 v.view(-1, layer.tp_k_head_num, layer.v_head_dim),
@@ -364,11 +381,19 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             )
         else:
             # mla paged prefill
+            if q_rope is None:
+                qall = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+                q, q_rope = (
+                    qall[:, :, : layer.v_head_dim],
+                    qall[:, :, layer.v_head_dim :],
+                )
+            o = q.new_empty(q.shape)
             o = prefill_wrapper_paged.run(
-                qall[:, :, : layer.v_head_dim],
-                qall[:, :, layer.v_head_dim :],
+                q,
+                q_rope,
                 k_buf[:, :, : layer.v_head_dim],
                 k_buf[:, :, layer.v_head_dim :],
+                out=o,
             )
 
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
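For readers unfamiliar with the MLA layout used above: when the caller does not pass q_rope separately, the packed query of width head_dim is sliced into a "nope" part of width v_head_dim and a rotary part of width head_dim - v_head_dim, and those two slices are handed to the paged wrapper. A minimal standalone sketch of that split (the dimensions are illustrative; 512 + 64 is the DeepSeek-style MLA layout, not something read from this diff):

import torch

num_tokens, num_heads = 4, 16
v_head_dim, head_dim = 512, 576  # head_dim = v_head_dim + rope_dim

q = torch.randn(num_tokens, num_heads * head_dim)

qall = q.view(-1, num_heads, head_dim)
q_nope = qall[:, :, :v_head_dim]   # latent ("no position") component
q_rope = qall[:, :, v_head_dim:]   # rotary-embedded component

assert q_nope.shape == (num_tokens, num_heads, v_head_dim)
assert q_rope.shape == (num_tokens, num_heads, head_dim - v_head_dim)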
@@ -381,6 +406,9 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
         decode_wrapper = self.forward_metadata.decode_wrapper
         cache_loc = forward_batch.out_cache_loc
@@ -388,23 +416,42 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer,
-                    cache_loc,
-                    k,
-                    v,
-                )
+                if k_rope is not None:
+                    forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        k_rope,
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
 
         # Reshape inputs
-        reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+        if q_rope is not None:
+            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
+        else:
+            reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = reshaped_q[:, :, : layer.v_head_dim]
+            q_rope = reshaped_q[:, :, layer.v_head_dim :]
+
         k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
 
+        o = q_nope.new_empty(q_nope.shape)
         # Direct call to run without the wrapper
         o = decode_wrapper.run(
-            reshaped_q[:, :, : layer.v_head_dim],
-            reshaped_q[:, :, layer.v_head_dim :],
+            q_nope,
+            q_rope,
             k_buffer[:, :, : layer.v_head_dim],
             k_buffer[:, :, layer.v_head_dim :],
+            out=o,
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
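Both the extend and decode paths above now preallocate the output with new_empty and pass it via out=, so the FlashInfer wrapper writes into an existing buffer instead of allocating a fresh tensor on every call. A minimal sketch of the same out= pattern, using a stock PyTorch op purely as a stand-in for the wrapper:

import torch

a = torch.randn(8, 64)
b = torch.randn(64, 32)

# Allocate the result buffer once (same dtype/device as `a`), then let the
# op write into it in place via out= rather than returning a new tensor.
o = a.new_empty(a.shape[0], b.shape[1])
torch.matmul(a, b, out=o)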
@@ -0,0 +1,46 @@
+from typing import Optional, Tuple
+
+import torch
+from sgl_kernel import merge_state_v2
+
+from sglang.srt.layers.attention.triton_ops.merge_state import merge_state_triton
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+
+
+# Automatically fallback to the Triton kernel in some cases
+# (e.g., for AMD GPUs, when the head dimension is not a multiple
+# of 4 or 8, and in FP8 precision)
+def _supported_dtypes(o: torch.Tensor) -> bool:
+    return o.dtype in [torch.float32, torch.half, torch.bfloat16]
+
+
+def _supported_headdim(o: torch.Tensor) -> bool:
+    headdim = o.shape[2]  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
+    if o.dtype == torch.float32:
+        return headdim % 4 == 0
+    return headdim % 8 == 0
+
+
+def merge_state(
+    prefix_output: torch.Tensor,
+    prefix_lse: torch.Tensor,
+    suffix_output: torch.Tensor,
+    suffix_lse: torch.Tensor,
+    output: Optional[torch.Tensor] = None,
+    output_lse: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    if (
+        _is_cuda
+        and _supported_dtypes(prefix_output)
+        and _supported_headdim(prefix_output)
+    ):
+        return merge_state_v2(
+            prefix_output, prefix_lse, suffix_output, suffix_lse, output, output_lse
+        )
+    else:
+        # Fallback to Triton kernel
+        return merge_state_triton(
+            prefix_output, prefix_lse, suffix_output, suffix_lse, output, output_lse
+        )
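The file added above appears to be the new sglang/srt/layers/attention/merge_state.py, a thin dispatcher that prefers the CUDA merge_state_v2 kernel from sgl_kernel and otherwise falls back to the Triton version. A hedged usage sketch follows, with shapes taken from the [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] / [NUM_TOKENS, NUM_HEADS] comments in these files; the device, dtypes, and sizes here are assumptions for illustration, and the fast path requires a CUDA build of sgl_kernel:

import torch

from sglang.srt.layers.attention.merge_state import merge_state

num_tokens, num_heads, head_size = 8, 16, 128
prefix_out = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16, device="cuda")
suffix_out = torch.randn(num_tokens, num_heads, head_size, dtype=torch.bfloat16, device="cuda")
prefix_lse = torch.randn(num_tokens, num_heads, dtype=torch.float32, device="cuda")
suffix_lse = torch.randn(num_tokens, num_heads, dtype=torch.float32, device="cuda")

# Merge two partial attention results (e.g., a ragged prefix and a paged suffix)
merged_out, merged_lse = merge_state(prefix_out, prefix_lse, suffix_out, suffix_lse)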
@@ -0,0 +1,96 @@
+from typing import Optional, Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def merge_state_kernel(
+    output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_merged
+    output_lse,  # [NUM_TOKENS, NUM_HEADS] s_merged
+    prefix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_a
+    prefix_lse,  # [NUM_TOKENS, NUM_HEADS] s_a
+    suffix_output,  # [NUM_TOKENS, NUM_HEADS, HEAD_SIZE] v_b
+    suffix_lse,  # [NUM_TOKENS, NUM_HEADS] s_b
+    HEAD_SIZE: tl.constexpr,
+    PADDED_HEAD_SIZE: tl.constexpr,
+    OUTPUT_LSE: tl.constexpr,
+):
+    token_idx = tl.program_id(0)
+    num_tokens = tl.num_programs(0)
+    head_idx = tl.program_id(1)
+    num_heads = tl.num_programs(1)
+
+    p_lse = tl.load(prefix_lse + token_idx * num_heads + head_idx)
+    s_lse = tl.load(suffix_lse + token_idx * num_heads + head_idx)
+    p_lse = float("-inf") if p_lse == float("inf") else p_lse
+    s_lse = float("-inf") if s_lse == float("inf") else s_lse
+
+    max_lse = tl.maximum(p_lse, s_lse)
+    p_lse = p_lse - max_lse
+    s_lse = s_lse - max_lse
+    out_se = tl.exp(p_lse) + tl.exp(s_lse)
+
+    if OUTPUT_LSE:
+        out_lse = tl.log(out_se) + max_lse
+        tl.store(output_lse + token_idx * num_heads + head_idx, out_lse)
+
+    head_arange = tl.arange(0, PADDED_HEAD_SIZE)
+    head_mask = head_arange < HEAD_SIZE
+    p_out = tl.load(
+        prefix_output
+        + token_idx * num_heads * HEAD_SIZE
+        + head_idx * HEAD_SIZE
+        + head_arange,
+        mask=head_mask,
+    )
+    s_out = tl.load(
+        suffix_output
+        + token_idx * num_heads * HEAD_SIZE
+        + head_idx * HEAD_SIZE
+        + head_arange,
+        mask=head_mask,
+    )
+
+    p_scale = tl.exp(p_lse) / out_se
+    s_scale = tl.exp(s_lse) / out_se
+    out = p_out * p_scale + s_out * s_scale
+    tl.store(
+        output + token_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE + head_arange,
+        out,
+        mask=head_mask,
+    )
+
+
+def merge_state_triton(
+    prefix_output: torch.Tensor,
+    prefix_lse: torch.Tensor,
+    suffix_output: torch.Tensor,
+    suffix_lse: torch.Tensor,
+    output: Optional[torch.Tensor] = None,
+    output_lse: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    # Avoid creating new tensors if they are already provided
+    if output is None:
+        output = torch.empty_like(prefix_output)
+    if output_lse is None:
+        output_lse = torch.empty_like(prefix_lse)
+
+    num_tokens = output.shape[0]
+    num_query_heads = output.shape[1]
+    head_size = output.shape[2]
+    padded_head_size = triton.next_power_of_2(head_size)
+
+    merge_state_kernel[(num_tokens, num_query_heads)](
+        output,
+        output_lse,
+        prefix_output,
+        prefix_lse,
+        suffix_output,
+        suffix_lse,
+        head_size,
+        padded_head_size,
+        output_lse is not None,
+    )
+    return output, output_lse
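For reference, the math the kernel above implements is a numerically stable log-sum-exp merge of two partial attention results. A plain-PyTorch restatement, added here only to clarify what the kernel computes (it is not part of the diff):

import torch

def merge_state_reference(prefix_output, prefix_lse, suffix_output, suffix_lse):
    # prefix_output / suffix_output: [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
    # prefix_lse / suffix_lse:       [NUM_TOKENS, NUM_HEADS]
    p_lse = prefix_lse.masked_fill(prefix_lse == float("inf"), float("-inf"))
    s_lse = suffix_lse.masked_fill(suffix_lse == float("inf"), float("-inf"))
    max_lse = torch.maximum(p_lse, s_lse)
    p = torch.exp(p_lse - max_lse)
    s = torch.exp(s_lse - max_lse)
    out_lse = torch.log(p + s) + max_lse
    p_scale = (p / (p + s)).unsqueeze(-1)  # broadcast over HEAD_SIZE
    s_scale = (s / (p + s)).unsqueeze(-1)
    output = prefix_output * p_scale + suffix_output * s_scale
    return output, out_lse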