sglang 0.4.3.post4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (131)
  1. sglang/bench_serving.py +1 -1
  2. sglang/lang/chat_template.py +29 -0
  3. sglang/srt/_custom_ops.py +19 -17
  4. sglang/srt/configs/__init__.py +2 -0
  5. sglang/srt/configs/janus_pro.py +629 -0
  6. sglang/srt/configs/model_config.py +24 -14
  7. sglang/srt/conversation.py +80 -2
  8. sglang/srt/custom_op.py +64 -3
  9. sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
  10. sglang/srt/distributed/parallel_state.py +10 -1
  11. sglang/srt/entrypoints/engine.py +5 -3
  12. sglang/srt/entrypoints/http_server.py +1 -1
  13. sglang/srt/function_call_parser.py +33 -2
  14. sglang/srt/hf_transformers_utils.py +16 -1
  15. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
  17. sglang/srt/layers/attention/triton_backend.py +1 -3
  18. sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
  19. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
  20. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  21. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
  22. sglang/srt/layers/attention/vision.py +43 -62
  23. sglang/srt/layers/dp_attention.py +30 -2
  24. sglang/srt/layers/elementwise.py +411 -0
  25. sglang/srt/layers/linear.py +1 -1
  26. sglang/srt/layers/logits_processor.py +1 -0
  27. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  28. sglang/srt/layers/moe/ep_moe/layer.py +25 -9
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
  37. sglang/srt/layers/moe/router.py +342 -0
  38. sglang/srt/layers/parameter.py +10 -0
  39. sglang/srt/layers/quantization/__init__.py +90 -68
  40. sglang/srt/layers/quantization/blockwise_int8.py +1 -2
  41. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  44. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  45. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  46. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  49. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  51. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  57. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  58. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/fp8.py +174 -106
  68. sglang/srt/layers/quantization/fp8_kernel.py +210 -38
  69. sglang/srt/layers/quantization/fp8_utils.py +156 -15
  70. sglang/srt/layers/quantization/modelopt_quant.py +5 -1
  71. sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
  72. sglang/srt/layers/quantization/w8a8_int8.py +152 -3
  73. sglang/srt/layers/rotary_embedding.py +5 -3
  74. sglang/srt/layers/sampler.py +29 -35
  75. sglang/srt/layers/vocab_parallel_embedding.py +0 -1
  76. sglang/srt/lora/backend/__init__.py +9 -12
  77. sglang/srt/managers/cache_controller.py +74 -8
  78. sglang/srt/managers/data_parallel_controller.py +1 -1
  79. sglang/srt/managers/image_processor.py +37 -631
  80. sglang/srt/managers/image_processors/base_image_processor.py +219 -0
  81. sglang/srt/managers/image_processors/janus_pro.py +79 -0
  82. sglang/srt/managers/image_processors/llava.py +152 -0
  83. sglang/srt/managers/image_processors/minicpmv.py +86 -0
  84. sglang/srt/managers/image_processors/mlama.py +60 -0
  85. sglang/srt/managers/image_processors/qwen_vl.py +161 -0
  86. sglang/srt/managers/io_struct.py +32 -15
  87. sglang/srt/managers/multi_modality_padding.py +134 -0
  88. sglang/srt/managers/schedule_batch.py +213 -118
  89. sglang/srt/managers/schedule_policy.py +40 -8
  90. sglang/srt/managers/scheduler.py +176 -683
  91. sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
  92. sglang/srt/managers/tokenizer_manager.py +6 -6
  93. sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
  94. sglang/srt/mem_cache/base_prefix_cache.py +6 -8
  95. sglang/srt/mem_cache/chunk_cache.py +12 -44
  96. sglang/srt/mem_cache/hiradix_cache.py +71 -34
  97. sglang/srt/mem_cache/memory_pool.py +81 -17
  98. sglang/srt/mem_cache/paged_allocator.py +283 -0
  99. sglang/srt/mem_cache/radix_cache.py +117 -36
  100. sglang/srt/model_executor/cuda_graph_runner.py +68 -20
  101. sglang/srt/model_executor/forward_batch_info.py +23 -10
  102. sglang/srt/model_executor/model_runner.py +63 -63
  103. sglang/srt/model_loader/loader.py +2 -1
  104. sglang/srt/model_loader/weight_utils.py +1 -1
  105. sglang/srt/models/deepseek_janus_pro.py +2127 -0
  106. sglang/srt/models/deepseek_nextn.py +23 -3
  107. sglang/srt/models/deepseek_v2.py +200 -191
  108. sglang/srt/models/grok.py +374 -119
  109. sglang/srt/models/minicpmv.py +28 -89
  110. sglang/srt/models/mllama.py +1 -1
  111. sglang/srt/models/qwen2.py +0 -1
  112. sglang/srt/models/qwen2_5_vl.py +25 -50
  113. sglang/srt/models/qwen2_vl.py +33 -49
  114. sglang/srt/openai_api/adapter.py +59 -35
  115. sglang/srt/openai_api/protocol.py +8 -1
  116. sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
  117. sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
  118. sglang/srt/server_args.py +24 -16
  119. sglang/srt/speculative/eagle_worker.py +75 -39
  120. sglang/srt/utils.py +104 -9
  121. sglang/test/runners.py +104 -10
  122. sglang/test/test_block_fp8.py +106 -16
  123. sglang/test/test_custom_ops.py +88 -0
  124. sglang/test/test_utils.py +20 -4
  125. sglang/utils.py +0 -4
  126. sglang/version.py +1 -1
  127. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
  128. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
  129. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
  130. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
  131. {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
@@ -1,12 +1,12 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  from functools import lru_cache
4
- from typing import Optional
4
+ from typing import Optional, Tuple
5
5
 
6
6
  import torch
7
7
  import torch.nn as nn
8
8
  import torch.nn.functional as F
9
- from einops import rearrange, repeat
9
+ from einops import rearrange
10
10
 
11
11
  from sglang.srt.distributed import parallel_state
12
12
  from sglang.srt.distributed import utils as dist_utils
@@ -22,47 +22,29 @@ from sglang.srt.layers.quantization import QuantizationConfig
22
22
  from sglang.srt.utils import add_prefix
23
23
 
24
24
 
25
- def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
26
- if not interleaved:
27
- x1, x2 = x.chunk(2, dim=-1)
28
- return torch.cat((-x2, x1), dim=-1)
29
- else:
30
- x1, x2 = x[..., ::2], x[..., 1::2]
31
- return rearrange(
32
- torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
33
- )
25
+ # Copied from transformers, modeling_qwen2_vl.py
26
+ def rotate_half(x):
27
+ """Rotates half the hidden dims of the input."""
28
+ x1 = x[..., : x.shape[-1] // 2]
29
+ x2 = x[..., x.shape[-1] // 2 :]
30
+ return torch.cat((-x2, x1), dim=-1)
34
31
 
35
32
 
36
- def apply_rotary_emb_torch(
37
- x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
38
- ) -> torch.Tensor:
39
- """
40
- x: (batch_size, seqlen, nheads, headdim)
41
- cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
42
- """
43
- ro_dim = cos.shape[-1] * 2
44
- assert ro_dim <= x.shape[-1]
45
- cos = repeat(
46
- cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
47
- )
48
- sin = repeat(
49
- sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
50
- )
51
- return torch.cat(
52
- [
53
- x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
54
- x[..., ro_dim:],
55
- ],
56
- dim=-1,
57
- )
58
-
59
-
60
- def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
61
- t_ = t.float()
62
- cos = freqs.cos()
63
- sin = freqs.sin()
64
- output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
65
- return output
33
+ def apply_rotary_pos_emb_vision(
34
+ q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
35
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
36
+ orig_q_dtype = q.dtype
37
+ orig_k_dtype = k.dtype
38
+ q, k = q.float(), k.float()
39
+
40
+ cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
41
+ q_embed = (q * cos) + (rotate_half(q) * sin)
42
+ k_embed = (k * cos) + (rotate_half(k) * sin)
43
+
44
+ q_embed = q_embed.to(orig_q_dtype)
45
+ k_embed = k_embed.to(orig_k_dtype)
46
+
47
+ return q_embed, k_embed
66
48
 
67
49
 
68
50
  class VisionAttention(nn.Module):
@@ -75,8 +57,8 @@ class VisionAttention(nn.Module):
75
57
  use_context_forward (bool, default to True):
76
58
  if ``True``, a flash_attn style attention will be applied
77
59
  Otherwise, a full-sequence attention will be applied.
78
- use_full_precision_softmax (bool, default to False):
79
- if ``True``, the softmax will be performed in full-precision
60
+ softmax_in_single_precision (bool, default to False):
61
+ if ``True``, the softmax will be performed in single-precision
80
62
  Otherwise, it will be performed in half-precision
81
63
 
82
64
  """
@@ -90,7 +72,7 @@ class VisionAttention(nn.Module):
90
72
  quant_config: Optional[QuantizationConfig] = None,
91
73
  dropout: float = 0.0,
92
74
  use_context_forward: bool = True,
93
- use_full_precision_softmax: bool = False,
75
+ softmax_in_single_precision: bool = False,
94
76
  flatten_batch: bool = False,
95
77
  prefix: str = "",
96
78
  ):
@@ -113,7 +95,7 @@ class VisionAttention(nn.Module):
113
95
  head_size=self.head_size,
114
96
  dropout=dropout,
115
97
  flatten_batch=flatten_batch,
116
- use_full_precision_softmax=use_full_precision_softmax,
98
+ softmax_in_single_precision=softmax_in_single_precision,
117
99
  )
118
100
 
119
101
  self.use_qkv_parallel = use_qkv_parallel
@@ -143,7 +125,7 @@ class VisionAttention(nn.Module):
143
125
  self,
144
126
  x: torch.Tensor,
145
127
  cu_seqlens: Optional[torch.Tensor] = None,
146
- rotary_pos_emb: torch.Tensor = None,
128
+ position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
147
129
  attention_mask: Optional[torch.Tensor] = None,
148
130
  ) -> torch.Tensor:
149
131
  r"""
@@ -151,21 +133,17 @@ class VisionAttention(nn.Module):
151
133
  x: [b, s, embed_dim]
152
134
  cu_seqlens: [b]
153
135
  Returns:
154
- [s, b, num_heads * head]
136
+ [s, b, head * head_size]
155
137
  """
156
138
  bsz, s, _ = x.shape
139
+ head = self.num_attention_heads_per_partition
157
140
  if self.use_qkv_parallel:
158
141
  # [b, s, embed_dim] --> [b, s, embed_dim]
159
142
  qkv, _ = self.qkv_proj(x)
160
143
  q, k, v = qkv.chunk(3, dim=-1)
161
144
 
162
- # [b, s, embed_dim] --> [b * s, num_heads, head_size]
163
- q, k, v = [
164
- x.reshape(
165
- bsz * s, self.num_attention_heads_per_partition, -1
166
- ).contiguous()
167
- for x in (q, k, v)
168
- ]
145
+ # [b, s, embed_dim] --> [b * s, head, head_size]
146
+ q, k, v = [x.reshape(bsz * s, head, -1).contiguous() for x in (q, k, v)]
169
147
  else:
170
148
  # [b, s, embed_dim] --> [s, b, embed_dim]
171
149
  x = rearrange(x, "b s ... -> s b ...")
@@ -173,7 +151,7 @@ class VisionAttention(nn.Module):
173
151
  qkv, _ = self.qkv_proj(x)
174
152
  # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
175
153
  new_x_shape = qkv.size()[:-1] + (
176
- self.num_attention_heads_per_partition,
154
+ head,
177
155
  3 * self.hidden_size_per_attention_head,
178
156
  )
179
157
  qkv = qkv.view(*new_x_shape)
@@ -186,9 +164,12 @@ class VisionAttention(nn.Module):
186
164
  rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
187
165
  ]
188
166
 
189
- if rotary_pos_emb is not None:
190
- q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
191
- k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
167
+ if position_embeddings is not None:
168
+ cos, sin = position_embeddings
169
+ original_shape = q.shape
170
+ q, k = q.view(s, head, -1), k.view(s, head, -1)
171
+ q, k = apply_rotary_pos_emb_vision(q, k, cos, sin)
172
+ q, k = q.reshape(original_shape), k.reshape(original_shape)
192
173
 
193
174
  if self.use_qkv_parallel:
194
175
  pass
@@ -230,12 +211,12 @@ class VisionSdpaAttention(nn.Module):
230
211
  head_size: int,
231
212
  dropout: float = 0.0,
232
213
  flatten_batch: bool = False,
233
- use_full_precision_softmax: bool = False,
214
+ softmax_in_single_precision: bool = False,
234
215
  ):
235
216
  super().__init__()
236
217
  self.head_size = head_size
237
218
  self.flatten_batch = flatten_batch
238
- self.use_full_precision_softmax = use_full_precision_softmax
219
+ self.softmax_in_single_precision = softmax_in_single_precision
239
220
  self.dropout = dropout
240
221
 
241
222
  @staticmethod
@@ -319,14 +300,14 @@ class VisionSdpaAttention(nn.Module):
319
300
  )
320
301
 
321
302
  if attention_mask is None:
322
- if self.use_full_precision_softmax:
303
+ if self.softmax_in_single_precision:
323
304
  raise RuntimeError("Empty attention mask")
324
305
  else:
325
306
  attention_mask = attention_mask.to(device=q.device)
326
307
 
327
308
  q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]
328
309
 
329
- if self.use_full_precision_softmax:
310
+ if self.softmax_in_single_precision:
330
311
  scale = self.head_size**-0.5
331
312
  k_transposed = rearrange(k, "b h s d -> b h d s")
332
313
  attn_weights = torch.matmul(q, k_transposed) * scale
@@ -1,6 +1,8 @@
1
1
  from __future__ import annotations
2
2
 
3
3
  import functools
4
+ import logging
5
+ from contextlib import contextmanager
4
6
  from typing import TYPE_CHECKING, Union
5
7
 
6
8
  import torch
@@ -14,6 +16,8 @@ from sglang.srt.distributed import (
14
16
  tensor_model_parallel_all_reduce,
15
17
  )
16
18
 
19
+ logger = logging.getLogger(__name__)
20
+
17
21
  if TYPE_CHECKING:
18
22
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
19
23
 
@@ -86,6 +90,27 @@ def get_attention_dp_size():
86
90
  return _DP_SIZE
87
91
 
88
92
 
93
+ @contextmanager
94
+ def disable_dp_size():
95
+ """Patch the tp group temporarily until this function ends.
96
+
97
+ This method is for draft workers of speculative decoding to run draft model
98
+ with different tp degree from that of target model workers.
99
+
100
+ Args:
101
+ tp_group (GroupCoordinator): the tp group coordinator
102
+ """
103
+ global _DP_SIZE
104
+ assert _DP_SIZE is not None, "dp attention not initialized!"
105
+
106
+ old_dp_size = _DP_SIZE
107
+ _DP_SIZE = 1
108
+ try:
109
+ yield
110
+ finally:
111
+ _DP_SIZE = old_dp_size
112
+
113
+
89
114
  def get_dp_local_info(forward_batch: ForwardBatch):
90
115
  dp_rank = get_attention_dp_rank()
91
116
 
@@ -159,7 +184,8 @@ def dp_gather(
159
184
  layer_id != "embedding" or get_attention_tp_rank() == 0
160
185
  ):
161
186
  assert (
162
- global_tokens.storage().data_ptr() != local_tokens.storage().data_ptr()
187
+ global_tokens.untyped_storage().data_ptr()
188
+ != local_tokens.untyped_storage().data_ptr()
163
189
  ), "aliasing between global_tokens and local_tokens not allowed"
164
190
  memcpy_triton(
165
191
  global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
@@ -174,8 +200,9 @@ def dp_gather(
174
200
  torch.ops.sglang.inplace_all_reduce(
175
201
  global_tokens, group_name=get_tp_group().unique_name
176
202
  )
203
+
177
204
  else:
178
- global_tokens = tensor_model_parallel_all_reduce(global_tokens)
205
+ global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)
179
206
 
180
207
 
181
208
  def dp_scatter(
@@ -186,6 +213,7 @@ def dp_scatter(
186
213
  # local_num_tokens is not necessarily the same as local_tokens.shape[0],
187
214
  # since local_tokens may be padded for cuda graph
188
215
  local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
216
+
189
217
  local_tokens.fill_(0)
190
218
  assert local_tokens.is_contiguous()
191
219
  assert global_tokens.is_contiguous()
@@ -0,0 +1,411 @@
1
+ from typing import Tuple
2
+
3
+ import torch
4
+ import triton
5
+ import triton.language as tl
6
+
7
+ fused_softcap_autotune = triton.autotune(
8
+ configs=[
9
+ triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
10
+ triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=8),
11
+ triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=16),
12
+ triton.Config(kwargs={"BLOCK_SIZE": 256}, num_warps=4),
13
+ triton.Config(kwargs={"BLOCK_SIZE": 256}, num_warps=8),
14
+ triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=4),
15
+ triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=8),
16
+ triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=16),
17
+ triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4),
18
+ triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8),
19
+ triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16),
20
+ triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=32),
21
+ triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=32),
22
+ triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=32),
23
+ triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32),
24
+ triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32),
25
+ triton.Config(kwargs={"BLOCK_SIZE": 32768}, num_warps=32),
26
+ ],
27
+ key=["n_ele"],
28
+ )
29
+
30
+
31
+ @triton.jit
32
+ def fused_softcap_kernel(
33
+ output_ptr,
34
+ input_ptr,
35
+ n_ele,
36
+ softcap_const: tl.constexpr,
37
+ BLOCK_SIZE: tl.constexpr,
38
+ ):
39
+ pid = tl.program_id(axis=0)
40
+ block_start = pid * BLOCK_SIZE
41
+ offsets = block_start + tl.arange(0, BLOCK_SIZE)
42
+ mask = offsets < n_ele
43
+ x = tl.load(input_ptr + offsets, mask=mask)
44
+ fx = x.to(tl.float32)
45
+ fxs = fx / softcap_const
46
+ exped = tl.exp(2 * fxs)
47
+ top = exped - 1
48
+ bottom = exped + 1
49
+ output = top / bottom * softcap_const
50
+ tl.store(output_ptr + offsets, output, mask=mask)
51
+
52
+
53
+ fused_softcap_kernel_autotuned = fused_softcap_autotune(fused_softcap_kernel)
54
+
55
+
56
+ def fused_softcap(x, softcap_const, autotune=False):
57
+ output = torch.empty_like(x, dtype=torch.float32)
58
+ n_elements = output.numel()
59
+ if autotune:
60
+ grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
61
+ fused_softcap_kernel_autotuned[grid](output, x, n_elements, softcap_const)
62
+ else:
63
+ fused_softcap_kernel[(triton.cdiv(n_elements, 128),)](
64
+ output, x, n_elements, softcap_const, BLOCK_SIZE=128, num_warps=8
65
+ )
66
+ return output
67
+
68
+
69
+ # cast to float + softcap
70
+ class Softcap:
71
+ def __init__(self, softcap_const: float):
72
+ self.softcap_const = softcap_const
73
+
74
+ def __call__(self, *args, **kwargs):
75
+ return self.forward(*args, **kwargs)
76
+
77
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
78
+ if x.is_cuda:
79
+ return self.forward_cuda(x)
80
+ else:
81
+ return self.forward_native(x)
82
+
83
+ def forward_native(self, x: torch.Tensor) -> torch.Tensor:
84
+ return torch.tanh(x.float() / self.softcap_const) * self.softcap_const
85
+
86
+ def forward_cuda(self, x: torch.Tensor, autotune=False) -> torch.Tensor:
87
+ return fused_softcap(x, self.softcap_const, autotune=autotune)
88
+
89
+
90
+ rmsnorm_autotune = triton.autotune(
91
+ configs=[
92
+ triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4, num_stages=1),
93
+ triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=1),
94
+ triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=1),
95
+ triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4),
96
+ triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8),
97
+ triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16),
98
+ triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4, num_stages=4),
99
+ triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=4),
100
+ triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=4),
101
+ triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=8),
102
+ triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=8),
103
+ triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=8),
104
+ triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=16),
105
+ triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=8, num_stages=4),
106
+ triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=16, num_stages=4),
107
+ triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=8),
108
+ triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=16),
109
+ triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8),
110
+ triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16),
111
+ triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32),
112
+ triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8, num_stages=1),
113
+ triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16, num_stages=1),
114
+ triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32, num_stages=1),
115
+ triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8, num_stages=4),
116
+ triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16, num_stages=4),
117
+ triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32, num_stages=4),
118
+ triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8),
119
+ triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16),
120
+ triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32),
121
+ triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8, num_stages=1),
122
+ triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16, num_stages=1),
123
+ triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32, num_stages=1),
124
+ triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8, num_stages=4),
125
+ triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16, num_stages=4),
126
+ triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32, num_stages=4),
127
+ ],
128
+ key=["hidden_dim"],
129
+ )
130
+
131
+
132
+ @triton.jit
133
+ def fused_dual_residual_rmsnorm_kernel(
134
+ output_ptr,
135
+ mid_ptr,
136
+ activ_ptr,
137
+ residual_ptr,
138
+ weight1_ptr,
139
+ weight2_ptr,
140
+ eps: tl.constexpr,
141
+ hidden_dim: tl.constexpr,
142
+ BLOCK_SIZE: tl.constexpr,
143
+ ):
144
+ pid = tl.program_id(axis=0)
145
+ input_start = pid * hidden_dim
146
+
147
+ offsets = tl.arange(0, BLOCK_SIZE)
148
+ mask = offsets < hidden_dim
149
+
150
+ a_ = tl.load(activ_ptr + input_start + offsets, mask=mask, other=0.0)
151
+ a = a_.to(tl.float32)
152
+ rms = tl.sqrt(tl.sum(a * a, axis=0) / hidden_dim + eps)
153
+
154
+ r = tl.load(residual_ptr + input_start + offsets, mask=mask, other=0.0)
155
+ w1_ = tl.load(weight1_ptr + offsets, mask=mask, other=0.0)
156
+ w1 = w1_.to(tl.float32)
157
+
158
+ a2r = r + (a / rms * w1).to(r.dtype)
159
+ tl.store(
160
+ mid_ptr + input_start + offsets,
161
+ a2r,
162
+ mask=mask,
163
+ )
164
+
165
+ a2r = a2r.to(tl.float32)
166
+ rms2 = tl.sqrt(tl.sum(a2r * a2r, axis=0) / hidden_dim + eps)
167
+
168
+ w2_ = tl.load(weight2_ptr + offsets, mask=mask, other=0.0)
169
+ w2 = w2_.to(tl.float32)
170
+
171
+ tl.store(
172
+ output_ptr + input_start + offsets,
173
+ a2r / rms2 * w2, # implicitly casts to output dtype here
174
+ mask=mask,
175
+ )
176
+
177
+
178
+ fused_dual_residual_rmsnorm_kernel_autotune = rmsnorm_autotune(
179
+ fused_dual_residual_rmsnorm_kernel
180
+ )
181
+
182
+
183
+ def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):
184
+ assert len(x.shape) == 2
185
+ assert x.shape == residual.shape and x.dtype == residual.dtype
186
+ output, mid = torch.empty_like(x), torch.empty_like(x)
187
+ bs, hidden_dim = x.shape
188
+ if autotune:
189
+ fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
190
+ output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
191
+ )
192
+ else:
193
+ config = {
194
+ "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
195
+ "num_warps": max(
196
+ min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
197
+ ),
198
+ }
199
+
200
+ fused_dual_residual_rmsnorm_kernel[(bs,)](
201
+ output,
202
+ mid,
203
+ x,
204
+ residual,
205
+ weight1,
206
+ weight2,
207
+ eps=eps,
208
+ hidden_dim=hidden_dim,
209
+ **config,
210
+ )
211
+
212
+ return output, mid
213
+
214
+
215
+ @triton.jit
216
+ def fused_rmsnorm_kernel(
217
+ output_ptr,
218
+ activ_ptr,
219
+ weight_ptr,
220
+ eps: tl.constexpr,
221
+ hidden_dim: tl.constexpr,
222
+ BLOCK_SIZE: tl.constexpr,
223
+ ):
224
+ pid = tl.program_id(axis=0)
225
+ input_start = pid * hidden_dim
226
+
227
+ offsets = tl.arange(0, BLOCK_SIZE)
228
+ mask = offsets < hidden_dim
229
+
230
+ a_ = tl.load(activ_ptr + input_start + offsets, mask=mask, other=0.0)
231
+ a = a_.to(tl.float32)
232
+ rms = tl.sqrt(tl.sum(a * a, axis=0) / hidden_dim + eps)
233
+
234
+ w1_ = tl.load(weight_ptr + offsets, mask=mask, other=0.0)
235
+ w1 = w1_.to(tl.float32)
236
+
237
+ a_rms = a / rms * w1
238
+
239
+ tl.store(
240
+ output_ptr + input_start + offsets,
241
+ a_rms, # implicitly casts to output dtype here
242
+ mask=mask,
243
+ )
244
+
245
+
246
+ def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
247
+ assert len(x.shape) == 2
248
+ if inplace:
249
+ output = x
250
+ else:
251
+ output = torch.empty_like(x)
252
+ bs, hidden_dim = x.shape
253
+ config = {
254
+ "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
255
+ "num_warps": max(
256
+ min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
257
+ ),
258
+ }
259
+
260
+ fused_rmsnorm_kernel[(bs,)](
261
+ output, x, weight, eps=eps, hidden_dim=hidden_dim, **config
262
+ )
263
+ return output
264
+
265
+
266
+ class FusedDualResidualRMSNorm:
267
+ """
268
+ Fused implementation of
269
+ y = RMSNorm2(RMSNorm1(x) + residual))
270
+ """
271
+
272
+ def __init__(self, rmsnorm1, rmsnorm2) -> None: # the one after rmsnorm1
273
+ self.rmsnorm1 = rmsnorm1
274
+ self.rmsnorm2 = rmsnorm2
275
+ self.variance_epsilon = self.rmsnorm1.variance_epsilon
276
+ assert self.rmsnorm1.variance_epsilon == self.rmsnorm2.variance_epsilon
277
+ assert self.rmsnorm1.weight.shape == self.rmsnorm2.weight.shape
278
+
279
+ def __call__(self, *args, **kwargs):
280
+ return self.forward(*args, **kwargs)
281
+
282
+ def forward(
283
+ self, x: torch.Tensor, residual: torch.Tensor
284
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
285
+ if x.is_cuda:
286
+ return self.forward_cuda(x, residual)
287
+ else:
288
+ return self.forward_flashinfer(x, residual)
289
+
290
+ def forward_cuda(
291
+ self, x: torch.Tensor, residual: torch.Tensor, autotune=False
292
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
293
+ return fused_dual_residual_rmsnorm(
294
+ x,
295
+ residual,
296
+ self.rmsnorm1.weight,
297
+ self.rmsnorm2.weight,
298
+ self.variance_epsilon,
299
+ autotune=autotune,
300
+ )
301
+
302
+ def forward_flashinfer(
303
+ self,
304
+ x: torch.Tensor,
305
+ residual: torch.Tensor,
306
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
307
+ normed1 = self.rmsnorm1(x)
308
+ residual = normed1 + residual
309
+ return self.rmsnorm2(residual), residual
310
+
311
+ def forward_native(
312
+ self,
313
+ x: torch.Tensor,
314
+ residual: torch.Tensor,
315
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
316
+ normed1 = self.rmsnorm1.forward_native(x)
317
+ residual = normed1 + residual
318
+ return self.rmsnorm2.forward_native(residual), residual
319
+
320
+
321
+ # gelu on first half of vector
322
+ @triton.jit
323
+ def gelu_and_mul_kernel(
324
+ out_hidden_states_ptr, # (bs, hidden_dim)
325
+ out_scales_ptr, # (bs,)
326
+ hidden_states_ptr, # (bs, hidden_dim * 2)
327
+ quant_max: tl.constexpr,
328
+ static_scale: tl.constexpr,
329
+ hidden_dim: tl.constexpr, # the output hidden_dim
330
+ BLOCK_SIZE: tl.constexpr,
331
+ ):
332
+ pid = tl.program_id(axis=0)
333
+
334
+ input_start = pid * hidden_dim * 2
335
+ output_start = pid * hidden_dim
336
+
337
+ input1_offs = tl.arange(0, BLOCK_SIZE)
338
+ mask = tl.arange(0, BLOCK_SIZE) < hidden_dim # shared for input1, input3, output
339
+ input3_offs = hidden_dim + tl.arange(0, BLOCK_SIZE)
340
+ output_offs = tl.arange(0, BLOCK_SIZE)
341
+
342
+ x1 = tl.load(
343
+ hidden_states_ptr + input_start + input1_offs, mask=mask, other=0.0
344
+ ).to(tl.float32)
345
+ x3 = tl.load(
346
+ hidden_states_ptr + input_start + input3_offs, mask=mask, other=0.0
347
+ ).to(tl.float32)
348
+
349
+ # gelu
350
+ # cast down before mul to better match training?
351
+ gelu_x1 = 0.5 * (1.0 + tl.erf(x1 * 0.7071067811865475)) * x1
352
+ out = x3 * gelu_x1.to(hidden_states_ptr.dtype.element_ty)
353
+
354
+ if quant_max is not None:
355
+ raise NotImplementedError()
356
+
357
+ tl.store(out_hidden_states_ptr + output_start + output_offs, out, mask=mask)
358
+
359
+
360
+ def gelu_and_mul_triton(
361
+ hidden_states,
362
+ scales=None,
363
+ quantize=None, # dtype to quantize to
364
+ out=None,
365
+ ):
366
+ bs, in_hidden_dim = hidden_states.shape
367
+ hidden_dim = in_hidden_dim // 2
368
+
369
+ if out is None:
370
+ out_hidden_states = torch.empty(
371
+ (bs, hidden_dim),
372
+ dtype=quantize or hidden_states.dtype,
373
+ device=hidden_states.device,
374
+ )
375
+ else:
376
+ assert out.shape == (bs, hidden_dim)
377
+ assert out.dtype == (quantize or hidden_states.dtype)
378
+ out_hidden_states = out
379
+ out_scales = None
380
+ static_scale = False
381
+ if quantize is not None:
382
+ if scales is None:
383
+ out_scales = torch.empty(
384
+ (bs,), dtype=torch.float32, device=hidden_states.device
385
+ )
386
+ else:
387
+ out_scales = scales
388
+ static_scale = True
389
+
390
+ config = {
391
+ # 8 ele per thread (not tuned)
392
+ "num_warps": max(
393
+ min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), 32), 4
394
+ ),
395
+ }
396
+
397
+ gelu_and_mul_kernel[(bs,)](
398
+ out_hidden_states,
399
+ out_scales,
400
+ hidden_states,
401
+ quant_max=torch.finfo(quantize).max if quantize is not None else None,
402
+ static_scale=static_scale,
403
+ hidden_dim=hidden_dim,
404
+ BLOCK_SIZE=triton.next_power_of_2(hidden_dim),
405
+ **config,
406
+ )
407
+
408
+ if quantize is not None:
409
+ return out_hidden_states, out_scales
410
+ else:
411
+ return out_hidden_states, None
@@ -18,6 +18,7 @@ from sglang.srt.distributed import (
18
18
  )
19
19
  from sglang.srt.layers.parameter import (
20
20
  BasevLLMParameter,
21
+ BlockQuantScaleParameter,
21
22
  PackedColumnParameter,
22
23
  PackedvLLMParameter,
23
24
  PerTensorScaleParameter,
@@ -27,7 +28,6 @@ from sglang.srt.layers.quantization.base_config import (
27
28
  QuantizationConfig,
28
29
  QuantizeMethodBase,
29
30
  )
30
- from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter
31
31
  from sglang.srt.utils import set_weight_attrs
32
32
 
33
33
  logger = logging.getLogger(__name__)