sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (176)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +164 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +62 -23
  49. sglang/srt/layers/elementwise.py +411 -0
  50. sglang/srt/layers/layernorm.py +24 -2
  51. sglang/srt/layers/linear.py +17 -5
  52. sglang/srt/layers/logits_processor.py +26 -7
  53. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  54. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  55. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  56. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  63. sglang/srt/layers/moe/router.py +342 -0
  64. sglang/srt/layers/moe/topk.py +31 -18
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +184 -126
  67. sglang/srt/layers/quantization/base_config.py +5 -0
  68. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  69. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  70. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  75. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  76. sglang/srt/layers/quantization/fp8.py +76 -34
  77. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  78. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  79. sglang/srt/layers/quantization/gptq.py +36 -9
  80. sglang/srt/layers/quantization/kv_cache.py +98 -0
  81. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  82. sglang/srt/layers/quantization/utils.py +153 -0
  83. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  84. sglang/srt/layers/rotary_embedding.py +66 -87
  85. sglang/srt/layers/sampler.py +1 -1
  86. sglang/srt/lora/layers.py +68 -0
  87. sglang/srt/lora/lora.py +2 -22
  88. sglang/srt/lora/lora_manager.py +47 -23
  89. sglang/srt/lora/mem_pool.py +110 -51
  90. sglang/srt/lora/utils.py +12 -1
  91. sglang/srt/managers/cache_controller.py +4 -5
  92. sglang/srt/managers/data_parallel_controller.py +31 -9
  93. sglang/srt/managers/expert_distribution.py +81 -0
  94. sglang/srt/managers/io_struct.py +39 -3
  95. sglang/srt/managers/mm_utils.py +373 -0
  96. sglang/srt/managers/multimodal_processor.py +68 -0
  97. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  98. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  99. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  100. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  101. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  102. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  103. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  104. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  105. sglang/srt/managers/schedule_batch.py +134 -31
  106. sglang/srt/managers/scheduler.py +325 -38
  107. sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
  108. sglang/srt/managers/session_controller.py +1 -1
  109. sglang/srt/managers/tokenizer_manager.py +59 -23
  110. sglang/srt/managers/tp_worker.py +1 -1
  111. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  112. sglang/srt/managers/utils.py +6 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +27 -8
  114. sglang/srt/mem_cache/memory_pool.py +258 -98
  115. sglang/srt/mem_cache/paged_allocator.py +2 -2
  116. sglang/srt/mem_cache/radix_cache.py +4 -4
  117. sglang/srt/model_executor/cuda_graph_runner.py +85 -28
  118. sglang/srt/model_executor/forward_batch_info.py +81 -15
  119. sglang/srt/model_executor/model_runner.py +70 -6
  120. sglang/srt/model_loader/loader.py +160 -2
  121. sglang/srt/model_loader/weight_utils.py +45 -0
  122. sglang/srt/models/deepseek_janus_pro.py +29 -86
  123. sglang/srt/models/deepseek_nextn.py +22 -10
  124. sglang/srt/models/deepseek_v2.py +326 -192
  125. sglang/srt/models/deepseek_vl2.py +358 -0
  126. sglang/srt/models/gemma3_causal.py +684 -0
  127. sglang/srt/models/gemma3_mm.py +462 -0
  128. sglang/srt/models/grok.py +374 -119
  129. sglang/srt/models/llama.py +47 -7
  130. sglang/srt/models/llama_eagle.py +1 -0
  131. sglang/srt/models/llama_eagle3.py +196 -0
  132. sglang/srt/models/llava.py +3 -3
  133. sglang/srt/models/llavavid.py +3 -3
  134. sglang/srt/models/minicpmo.py +1995 -0
  135. sglang/srt/models/minicpmv.py +62 -137
  136. sglang/srt/models/mllama.py +4 -4
  137. sglang/srt/models/phi3_small.py +1 -1
  138. sglang/srt/models/qwen2.py +3 -0
  139. sglang/srt/models/qwen2_5_vl.py +68 -146
  140. sglang/srt/models/qwen2_classification.py +75 -0
  141. sglang/srt/models/qwen2_moe.py +9 -1
  142. sglang/srt/models/qwen2_vl.py +25 -63
  143. sglang/srt/openai_api/adapter.py +145 -47
  144. sglang/srt/openai_api/protocol.py +23 -2
  145. sglang/srt/sampling/sampling_batch_info.py +1 -1
  146. sglang/srt/sampling/sampling_params.py +6 -6
  147. sglang/srt/server_args.py +104 -14
  148. sglang/srt/speculative/build_eagle_tree.py +7 -347
  149. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  150. sglang/srt/speculative/eagle_utils.py +208 -252
  151. sglang/srt/speculative/eagle_worker.py +139 -53
  152. sglang/srt/speculative/spec_info.py +6 -1
  153. sglang/srt/torch_memory_saver_adapter.py +22 -0
  154. sglang/srt/utils.py +182 -21
  155. sglang/test/__init__.py +0 -0
  156. sglang/test/attention/__init__.py +0 -0
  157. sglang/test/attention/test_flashattn_backend.py +312 -0
  158. sglang/test/runners.py +2 -0
  159. sglang/test/test_activation.py +2 -1
  160. sglang/test/test_block_fp8.py +5 -4
  161. sglang/test/test_block_fp8_ep.py +2 -1
  162. sglang/test/test_dynamic_grad_mode.py +58 -0
  163. sglang/test/test_layernorm.py +3 -2
  164. sglang/test/test_utils.py +55 -4
  165. sglang/utils.py +31 -0
  166. sglang/version.py +1 -1
  167. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  168. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
  169. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  170. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  171. sglang/srt/managers/image_processor.py +0 -55
  172. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  173. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  174. sglang/srt/managers/multi_modality_padding.py +0 -134
  175. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  176. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/dp_attention.py
@@ -1,6 +1,8 @@
 from __future__ import annotations
 
 import functools
+import logging
+from contextlib import contextmanager
 from typing import TYPE_CHECKING, Union
 
 import torch
@@ -14,6 +16,8 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 
+logger = logging.getLogger(__name__)
+
 if TYPE_CHECKING:
     from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
@@ -34,7 +38,12 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_si
     return attn_tp_rank, attn_tp_size, dp_rank
 
 
-def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
+def initialize_dp_attention(
+    enable_dp_attention: bool,
+    tp_rank: int,
+    tp_size: int,
+    dp_size: int,
+):
     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
 
     from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
@@ -42,7 +51,11 @@ def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
     _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
         enable_dp_attention, tp_rank, tp_size, dp_size
     )
-    _DP_SIZE = dp_size
+
+    if enable_dp_attention:
+        _DP_SIZE = dp_size
+    else:
+        _DP_SIZE = 1
 
     tp_group = get_tp_group()
     _ATTN_TP_GROUP = GroupCoordinator(
@@ -50,7 +63,7 @@ def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
             list(range(head, head + _ATTN_TP_SIZE))
             for head in range(0, tp_size, _ATTN_TP_SIZE)
         ],
-        tp_rank,
+        tp_group.local_rank,
         torch.distributed.get_backend(tp_group.device_group),
         SYNC_TOKEN_IDS_ACROSS_TP,
         False,
@@ -86,6 +99,27 @@ def get_attention_dp_size():
     return _DP_SIZE
 
 
+@contextmanager
+def disable_dp_size():
+    """Patch the tp group temporarily until this function ends.
+
+    This method is for draft workers of speculative decoding to run draft model
+    with different tp degree from that of target model workers.
+
+    Args:
+        tp_group (GroupCoordinator): the tp group coordinator
+    """
+    global _DP_SIZE
+    assert _DP_SIZE is not None, "dp attention not initialized!"
+
+    old_dp_size = _DP_SIZE
+    _DP_SIZE = 1
+    try:
+        yield
+    finally:
+        _DP_SIZE = old_dp_size
+
+
 def get_dp_local_info(forward_batch: ForwardBatch):
     dp_rank = get_attention_dp_rank()
 
@@ -144,22 +178,22 @@ def memcpy_triton(dst, src, dim, offset, sz, offset_src):
     memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)
 
 
-def dp_gather(
+def _dp_gather(
     global_tokens: torch.Tensor,
     local_tokens: torch.Tensor,
     forward_batch: ForwardBatch,
-    layer_id: Union[str, int],
+    is_partial: bool,
 ):
     local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
 
     global_tokens.fill_(0)
     assert local_tokens.is_contiguous()
     assert global_tokens.is_contiguous()
-    if local_tokens.shape[0] > 0 and (
-        layer_id != "embedding" or get_attention_tp_rank() == 0
-    ):
+
+    if local_tokens.shape[0] > 0 and (is_partial or get_attention_tp_rank() == 0):
         assert (
-            global_tokens.storage().data_ptr() != local_tokens.storage().data_ptr()
+            global_tokens.untyped_storage().data_ptr()
+            != local_tokens.untyped_storage().data_ptr()
         ), "aliasing between global_tokens and local_tokens not allowed"
         memcpy_triton(
             global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
@@ -174,8 +208,25 @@ def dp_gather(
         torch.ops.sglang.inplace_all_reduce(
             global_tokens, group_name=get_tp_group().unique_name
         )
+
     else:
-        global_tokens = tensor_model_parallel_all_reduce(global_tokens)
+        global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)
+
+
+def dp_gather_partial(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+):
+    _dp_gather(global_tokens, local_tokens, forward_batch, is_partial=True)
+
+
+def dp_gather_replicate(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+):
+    _dp_gather(global_tokens, local_tokens, forward_batch, is_partial=False)
 
 
 def dp_scatter(
@@ -186,6 +237,7 @@ def dp_scatter(
     # local_num_tokens is not necessarily the same as local_tokens.shape[0],
     # since local_tokens may be padded for cuda graph
     local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
+
     local_tokens.fill_(0)
     assert local_tokens.is_contiguous()
     assert global_tokens.is_contiguous()
@@ -197,16 +249,3 @@ def dp_scatter(
         memcpy_triton(
             local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
         )
-
-
-def get_do_logits_dp_scatter(forward_batch: ForwardBatch):
-    def do_logits_dp_scatter(logits: torch.Tensor):
-        local_logits = torch.empty(
-            (forward_batch.input_ids.shape[0], *logits.shape[1:]),
-            dtype=logits.dtype,
-            device=logits.device,
-        )
-        dp_scatter(local_logits, logits, forward_batch)
-        return local_logits
-
-    return do_logits_dp_scatter
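
For orientation (not part of the diff): a minimal sketch of how the renamed gather helpers and the new disable_dp_size() context manager in dp_attention.py fit together. The spec-decode scenario and the pre-allocated buffers are assumptions for illustration; only names that appear in the diff are used.

# Illustrative sketch only; buffer allocation and the ForwardBatch come from the caller.
import torch

from sglang.srt.layers.dp_attention import (
    disable_dp_size,
    dp_gather_partial,
    dp_scatter,
    get_attention_dp_size,
)


def gather_run_scatter(global_buf: torch.Tensor, local_buf: torch.Tensor, forward_batch):
    # Partial gather: every attention-TP rank contributes its slice; the result is
    # reduced across the TP group inside _dp_gather.
    dp_gather_partial(global_buf, local_buf, forward_batch)
    # ... run a layer on the gathered tokens here ...
    # Scatter this DP rank's slice of the result back into the local buffer.
    dp_scatter(local_buf, global_buf, forward_batch)


def run_draft_model(draft_step_fn):
    # Draft workers of speculative decoding temporarily behave as if dp_size == 1.
    with disable_dp_size():
        assert get_attention_dp_size() == 1
        return draft_step_fn()

dp_gather_replicate covers the other case, where only attention-TP rank 0 contributes already-replicated tokens, mirroring the removed layer_id == "embedding" branch.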
sglang/srt/layers/elementwise.py (new file)
@@ -0,0 +1,411 @@
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+fused_softcap_autotune = triton.autotune(
+    configs=[
+        triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 256}, num_warps=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 256}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 32768}, num_warps=32),
+    ],
+    key=["n_ele"],
+)
+
+
+@triton.jit
+def fused_softcap_kernel(
+    output_ptr,
+    input_ptr,
+    n_ele,
+    softcap_const: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_ele
+    x = tl.load(input_ptr + offsets, mask=mask)
+    fx = x.to(tl.float32)
+    fxs = fx / softcap_const
+    exped = tl.exp(2 * fxs)
+    top = exped - 1
+    bottom = exped + 1
+    output = top / bottom * softcap_const
+    tl.store(output_ptr + offsets, output, mask=mask)
+
+
+fused_softcap_kernel_autotuned = fused_softcap_autotune(fused_softcap_kernel)
+
+
+def fused_softcap(x, softcap_const, autotune=False):
+    output = torch.empty_like(x, dtype=torch.float32)
+    n_elements = output.numel()
+    if autotune:
+        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+        fused_softcap_kernel_autotuned[grid](output, x, n_elements, softcap_const)
+    else:
+        fused_softcap_kernel[(triton.cdiv(n_elements, 128),)](
+            output, x, n_elements, softcap_const, BLOCK_SIZE=128, num_warps=8
+        )
+    return output
+
+
+# cast to float + softcap
+class Softcap:
+    def __init__(self, softcap_const: float):
+        self.softcap_const = softcap_const
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if x.is_cuda:
+            return self.forward_cuda(x)
+        else:
+            return self.forward_native(x)
+
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.tanh(x.float() / self.softcap_const) * self.softcap_const
+
+    def forward_cuda(self, x: torch.Tensor, autotune=False) -> torch.Tensor:
+        return fused_softcap(x, self.softcap_const, autotune=autotune)
+
+
+rmsnorm_autotune = triton.autotune(
+    configs=[
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=8, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=16, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32, num_stages=4),
+    ],
+    key=["hidden_dim"],
+)
+
+
+@triton.jit
+def fused_dual_residual_rmsnorm_kernel(
+    output_ptr,
+    mid_ptr,
+    activ_ptr,
+    residual_ptr,
+    weight1_ptr,
+    weight2_ptr,
+    eps: tl.constexpr,
+    hidden_dim: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    input_start = pid * hidden_dim
+
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < hidden_dim
+
+    a_ = tl.load(activ_ptr + input_start + offsets, mask=mask, other=0.0)
+    a = a_.to(tl.float32)
+    rms = tl.sqrt(tl.sum(a * a, axis=0) / hidden_dim + eps)
+
+    r = tl.load(residual_ptr + input_start + offsets, mask=mask, other=0.0)
+    w1_ = tl.load(weight1_ptr + offsets, mask=mask, other=0.0)
+    w1 = w1_.to(tl.float32)
+
+    a2r = r + (a / rms * w1).to(r.dtype)
+    tl.store(
+        mid_ptr + input_start + offsets,
+        a2r,
+        mask=mask,
+    )
+
+    a2r = a2r.to(tl.float32)
+    rms2 = tl.sqrt(tl.sum(a2r * a2r, axis=0) / hidden_dim + eps)
+
+    w2_ = tl.load(weight2_ptr + offsets, mask=mask, other=0.0)
+    w2 = w2_.to(tl.float32)
+
+    tl.store(
+        output_ptr + input_start + offsets,
+        a2r / rms2 * w2,  # implicitly casts to output dtype here
+        mask=mask,
+    )
+
+
+fused_dual_residual_rmsnorm_kernel_autotune = rmsnorm_autotune(
+    fused_dual_residual_rmsnorm_kernel
+)
+
+
+def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):
+    assert len(x.shape) == 2
+    assert x.shape == residual.shape and x.dtype == residual.dtype
+    output, mid = torch.empty_like(x), torch.empty_like(x)
+    bs, hidden_dim = x.shape
+    if autotune:
+        fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
+            output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
+        )
+    else:
+        config = {
+            "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
+            "num_warps": max(
+                min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+            ),
+        }
+
+        fused_dual_residual_rmsnorm_kernel[(bs,)](
+            output,
+            mid,
+            x,
+            residual,
+            weight1,
+            weight2,
+            eps=eps,
+            hidden_dim=hidden_dim,
+            **config,
+        )
+
+    return output, mid
+
+
+@triton.jit
+def fused_rmsnorm_kernel(
+    output_ptr,
+    activ_ptr,
+    weight_ptr,
+    eps: tl.constexpr,
+    hidden_dim: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    input_start = pid * hidden_dim
+
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < hidden_dim
+
+    a_ = tl.load(activ_ptr + input_start + offsets, mask=mask, other=0.0)
+    a = a_.to(tl.float32)
+    rms = tl.sqrt(tl.sum(a * a, axis=0) / hidden_dim + eps)
+
+    w1_ = tl.load(weight_ptr + offsets, mask=mask, other=0.0)
+    w1 = w1_.to(tl.float32)
+
+    a_rms = a / rms * w1
+
+    tl.store(
+        output_ptr + input_start + offsets,
+        a_rms,  # implicitly casts to output dtype here
+        mask=mask,
+    )
+
+
+def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
+    assert len(x.shape) == 2
+    if inplace:
+        output = x
+    else:
+        output = torch.empty_like(x)
+    bs, hidden_dim = x.shape
+    config = {
+        "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+        ),
+    }
+
+    fused_rmsnorm_kernel[(bs,)](
+        output, x, weight, eps=eps, hidden_dim=hidden_dim, **config
+    )
+    return output
+
+
+class FusedDualResidualRMSNorm:
+    """
+    Fused implementation of
+        y = RMSNorm2(RMSNorm1(x) + residual))
+    """
+
+    def __init__(self, rmsnorm1, rmsnorm2) -> None:  # the one after rmsnorm1
+        self.rmsnorm1 = rmsnorm1
+        self.rmsnorm2 = rmsnorm2
+        self.variance_epsilon = self.rmsnorm1.variance_epsilon
+        assert self.rmsnorm1.variance_epsilon == self.rmsnorm2.variance_epsilon
+        assert self.rmsnorm1.weight.shape == self.rmsnorm2.weight.shape
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
+
+    def forward(
+        self, x: torch.Tensor, residual: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if x.is_cuda:
+            return self.forward_cuda(x, residual)
+        else:
+            return self.forward_flashinfer(x, residual)
+
+    def forward_cuda(
+        self, x: torch.Tensor, residual: torch.Tensor, autotune=False
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        return fused_dual_residual_rmsnorm(
+            x,
+            residual,
+            self.rmsnorm1.weight,
+            self.rmsnorm2.weight,
+            self.variance_epsilon,
+            autotune=autotune,
+        )
+
+    def forward_flashinfer(
+        self,
+        x: torch.Tensor,
+        residual: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        normed1 = self.rmsnorm1(x)
+        residual = normed1 + residual
+        return self.rmsnorm2(residual), residual
+
+    def forward_native(
+        self,
+        x: torch.Tensor,
+        residual: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        normed1 = self.rmsnorm1.forward_native(x)
+        residual = normed1 + residual
+        return self.rmsnorm2.forward_native(residual), residual
+
+
+# gelu on first half of vector
+@triton.jit
+def gelu_and_mul_kernel(
+    out_hidden_states_ptr,  # (bs, hidden_dim)
+    out_scales_ptr,  # (bs,)
+    hidden_states_ptr,  # (bs, hidden_dim * 2)
+    quant_max: tl.constexpr,
+    static_scale: tl.constexpr,
+    hidden_dim: tl.constexpr,  # the output hidden_dim
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+
+    input_start = pid * hidden_dim * 2
+    output_start = pid * hidden_dim
+
+    input1_offs = tl.arange(0, BLOCK_SIZE)
+    mask = tl.arange(0, BLOCK_SIZE) < hidden_dim  # shared for input1, input3, output
+    input3_offs = hidden_dim + tl.arange(0, BLOCK_SIZE)
+    output_offs = tl.arange(0, BLOCK_SIZE)
+
+    x1 = tl.load(
+        hidden_states_ptr + input_start + input1_offs, mask=mask, other=0.0
+    ).to(tl.float32)
+    x3 = tl.load(
+        hidden_states_ptr + input_start + input3_offs, mask=mask, other=0.0
+    ).to(tl.float32)
+
+    # gelu
+    # cast down before mul to better match training?
+    gelu_x1 = 0.5 * (1.0 + tl.erf(x1 * 0.7071067811865475)) * x1
+    out = x3 * gelu_x1.to(hidden_states_ptr.dtype.element_ty)
+
+    if quant_max is not None:
+        raise NotImplementedError()
+
+    tl.store(out_hidden_states_ptr + output_start + output_offs, out, mask=mask)
+
+
+def gelu_and_mul_triton(
+    hidden_states,
+    scales=None,
+    quantize=None,  # dtype to quantize to
+    out=None,
+):
+    bs, in_hidden_dim = hidden_states.shape
+    hidden_dim = in_hidden_dim // 2
+
+    if out is None:
+        out_hidden_states = torch.empty(
+            (bs, hidden_dim),
+            dtype=quantize or hidden_states.dtype,
+            device=hidden_states.device,
+        )
+    else:
+        assert out.shape == (bs, hidden_dim)
+        assert out.dtype == (quantize or hidden_states.dtype)
+        out_hidden_states = out
+    out_scales = None
+    static_scale = False
+    if quantize is not None:
+        if scales is None:
+            out_scales = torch.empty(
+                (bs,), dtype=torch.float32, device=hidden_states.device
+            )
+        else:
+            out_scales = scales
+            static_scale = True
+
+    config = {
+        # 8 ele per thread (not tuned)
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), 32), 4
+        ),
+    }
+
+    gelu_and_mul_kernel[(bs,)](
+        out_hidden_states,
+        out_scales,
+        hidden_states,
+        quant_max=torch.finfo(quantize).max if quantize is not None else None,
+        static_scale=static_scale,
+        hidden_dim=hidden_dim,
+        BLOCK_SIZE=triton.next_power_of_2(hidden_dim),
+        **config,
+    )
+
+    if quantize is not None:
+        return out_hidden_states, out_scales
+    else:
+        return out_hidden_states, None
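
For orientation (not part of the diff): a small self-check of the new Triton kernels against plain PyTorch references. The module path follows the file list above; the shapes, the softcap constant, and the tolerances are assumptions for fp16 inputs on a CUDA device.

# Illustrative check only; requires a CUDA device and triton.
import torch
import torch.nn.functional as F

from sglang.srt.layers.elementwise import fused_softcap, gelu_and_mul_triton

x = torch.randn(4, 4096, dtype=torch.float16, device="cuda")
# fused_softcap computes softcap_const * tanh(x / softcap_const) in float32.
ref = torch.tanh(x.float() / 30.0) * 30.0
out = fused_softcap(x, 30.0)
assert torch.allclose(out, ref, atol=1e-3, rtol=1e-3)

# gelu_and_mul_triton applies exact (erf) GELU to the first half of the last
# dimension and multiplies it with the second half.
h = torch.randn(4, 2 * 4096, dtype=torch.float16, device="cuda")
x1, x3 = h.chunk(2, dim=-1)
ref2 = F.gelu(x1.float(), approximate="none").to(h.dtype).float() * x3.float()
out2, _ = gelu_and_mul_triton(h)
assert torch.allclose(out2.float(), ref2, atol=2e-2, rtol=2e-2)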
sglang/srt/layers/layernorm.py
@@ -21,7 +21,9 @@ import torch.nn as nn
 
 from sglang.srt.utils import is_cuda_available
 
-if is_cuda_available():
+_is_cuda = is_cuda_available()
+
+if _is_cuda:
     from sgl_kernel import (
         fused_add_rmsnorm,
         gemma_fused_add_rmsnorm,
@@ -117,7 +119,27 @@ class GemmaRMSNorm(CustomOp):
         return out
 
 
-if not is_cuda_available():
+class Gemma3RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.zeros(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        output = self._norm(x.float())
+        # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
+        # See https://github.com/huggingface/transformers/pull/29402
+        output = output * (1.0 + self.weight.float())
+        return output.type_as(x)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.eps}"
+
+
+if not _is_cuda:
     logger.info(
         "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
     )
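
For orientation (not part of the diff): the new Gemma3RMSNorm normalizes in float32 and scales by (1 + weight) before casting back, in contrast to the Llama-style cast-then-scale noted in the inline comment. A minimal sketch of that behaviour, with an assumed toy size and random weight:

# Illustrative check only; dim, dtype, and the random weight are assumptions.
import torch

from sglang.srt.layers.layernorm import Gemma3RMSNorm

norm = Gemma3RMSNorm(dim=8, eps=1e-6)
with torch.no_grad():
    norm.weight.copy_(torch.randn(8))

x = torch.randn(2, 8, dtype=torch.bfloat16)
xf = x.float()
ref = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + 1e-6)
ref = (ref * (1.0 + norm.weight.float())).to(x.dtype)  # scale in fp32, cast last
assert torch.allclose(norm(x).float(), ref.float())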
sglang/srt/layers/linear.py
@@ -23,6 +23,7 @@ from sglang.srt.layers.parameter import (
     PackedvLLMParameter,
     PerTensorScaleParameter,
     RowvLLMParameter,
+    _ColumnvLLMParameter,
 )
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -423,8 +424,6 @@ class ColumnParallelLinear(LinearBase):
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)
 
-        from sglang.srt.layers.parameter import _ColumnvLLMParameter
-
         if isinstance(param, _ColumnvLLMParameter):
             param.load_column_parallel_weight(
                 loaded_weight,
@@ -687,10 +686,19 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
    ):
        if loaded_shard_id is None:
            if isinstance(param, PerTensorScaleParameter):
-                param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
+                param.load_merged_column_weight(
+                    loaded_weight=loaded_weight,
+                    shard_id=0,
+                    tp_rank=self.tp_rank,
+                    tp_size=self.tp_size,
+                )
                return
            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
-                param.load_merged_column_weight(loaded_weight=loaded_weight)
+                param.load_merged_column_weight(
+                    loaded_weight=loaded_weight,
+                    tp_rank=self.tp_rank,
+                    tp_size=self.tp_size,
+                )
                return
            # TODO: @dsikka - move to parameter.py
            self._load_fused_module_from_checkpoint(param, loaded_weight)
@@ -719,6 +727,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
            shard_offset=shard_offset,
            shard_size=shard_size,
            use_presharded_weights=self.use_presharded_weights,
+            tp_rank=self.tp_rank,
+            tp_size=self.tp_size,
        )
 
 
@@ -782,6 +792,8 @@ class QKVParallelLinear(ColumnParallelLinear):
        else:
            self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
            self.num_kv_head_replicas = 1
+        self.q_proj_shard_size = self.num_heads * self.head_size
+        self.kv_proj_shard_size = self.num_kv_heads * self.head_size
        input_size = self.hidden_size
        output_size = (
            (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
@@ -1234,7 +1246,7 @@ class RowParallelLinear(LinearBase):
            assert loaded_weight.numel() == 1
            loaded_weight = loaded_weight.reshape(1)
 
-        if isinstance(param, BasevLLMParameter):
+        if isinstance(param, RowvLLMParameter):
            # This `BasevLLMParameter` is defined in sglang/srt/layers/parameter.py,
            # It supports additional parameters like tp_rank and use_presharded_weights.
            param.load_row_parallel_weight(
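
For orientation (not part of the diff): the new q_proj_shard_size and kv_proj_shard_size attributes on QKVParallelLinear describe each tensor-parallel rank's slice of the fused QKV output, since self.num_heads and self.num_kv_heads are already per-rank counts. A worked example with assumed head counts:

# Illustrative arithmetic only; head counts, head size, and tp degree are assumptions.
total_num_heads, total_num_kv_heads, head_size, tp_size = 32, 8, 128, 4

num_heads = total_num_heads // tp_size        # 8 query heads per rank
num_kv_heads = total_num_kv_heads // tp_size  # 2 KV heads per rank

q_proj_shard_size = num_heads * head_size      # 1024 output columns for Q per rank
kv_proj_shard_size = num_kv_heads * head_size  # 256 output columns each for K and V

per_rank_output_size = q_proj_shard_size + 2 * kv_proj_shard_size  # 1536
print(q_proj_shard_size, kv_proj_shard_size, per_rank_output_size)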