sglang 0.3.6.post3__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. sglang/__init__.py +1 -1
  2. sglang/bench_one_batch.py +4 -0
  3. sglang/bench_serving.py +13 -0
  4. sglang/check_env.py +1 -1
  5. sglang/srt/_custom_ops.py +118 -0
  6. sglang/srt/configs/device_config.py +17 -0
  7. sglang/srt/configs/load_config.py +84 -0
  8. sglang/srt/configs/model_config.py +161 -4
  9. sglang/srt/configs/qwen2vl.py +5 -8
  10. sglang/srt/constrained/outlines_backend.py +11 -1
  11. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  12. sglang/srt/constrained/xgrammar_backend.py +5 -5
  13. sglang/srt/distributed/__init__.py +3 -0
  14. sglang/srt/distributed/communication_op.py +34 -0
  15. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  16. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  19. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  20. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  21. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  22. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  23. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  24. sglang/srt/distributed/parallel_state.py +1275 -0
  25. sglang/srt/distributed/utils.py +223 -0
  26. sglang/srt/hf_transformers_utils.py +37 -1
  27. sglang/srt/layers/attention/__init__.py +5 -2
  28. sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
  29. sglang/srt/layers/attention/flashinfer_backend.py +33 -20
  30. sglang/srt/layers/attention/torch_native_backend.py +299 -0
  31. sglang/srt/layers/attention/triton_backend.py +22 -8
  32. sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
  33. sglang/srt/layers/ep_moe/__init__.py +0 -0
  34. sglang/srt/layers/ep_moe/kernels.py +349 -0
  35. sglang/srt/layers/ep_moe/layer.py +661 -0
  36. sglang/srt/layers/fused_moe_patch.py +20 -11
  37. sglang/srt/layers/linear.py +1 -0
  38. sglang/srt/layers/logits_processor.py +17 -3
  39. sglang/srt/layers/quantization/__init__.py +36 -2
  40. sglang/srt/layers/quantization/fp8.py +559 -0
  41. sglang/srt/layers/quantization/fp8_utils.py +27 -0
  42. sglang/srt/layers/radix_attention.py +4 -2
  43. sglang/srt/layers/sampler.py +2 -0
  44. sglang/srt/layers/torchao_utils.py +23 -45
  45. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  46. sglang/srt/lora/lora.py +1 -1
  47. sglang/srt/managers/io_struct.py +48 -2
  48. sglang/srt/managers/schedule_batch.py +19 -14
  49. sglang/srt/managers/schedule_policy.py +7 -4
  50. sglang/srt/managers/scheduler.py +145 -85
  51. sglang/srt/managers/tokenizer_manager.py +166 -68
  52. sglang/srt/managers/tp_worker.py +36 -3
  53. sglang/srt/managers/tp_worker_overlap_thread.py +28 -8
  54. sglang/srt/mem_cache/memory_pool.py +5 -1
  55. sglang/srt/model_executor/cuda_graph_runner.py +30 -7
  56. sglang/srt/model_executor/forward_batch_info.py +9 -4
  57. sglang/srt/model_executor/model_runner.py +146 -153
  58. sglang/srt/model_loader/__init__.py +34 -0
  59. sglang/srt/model_loader/loader.py +1139 -0
  60. sglang/srt/model_loader/utils.py +41 -0
  61. sglang/srt/model_loader/weight_utils.py +640 -0
  62. sglang/srt/model_parallel.py +1 -5
  63. sglang/srt/models/baichuan.py +9 -10
  64. sglang/srt/models/chatglm.py +6 -15
  65. sglang/srt/models/commandr.py +4 -5
  66. sglang/srt/models/dbrx.py +2 -3
  67. sglang/srt/models/deepseek.py +4 -11
  68. sglang/srt/models/deepseek_v2.py +90 -18
  69. sglang/srt/models/exaone.py +2 -3
  70. sglang/srt/models/gemma.py +2 -6
  71. sglang/srt/models/gemma2.py +3 -14
  72. sglang/srt/models/gemma2_reward.py +0 -1
  73. sglang/srt/models/gpt2.py +5 -12
  74. sglang/srt/models/gpt_bigcode.py +6 -22
  75. sglang/srt/models/grok.py +3 -8
  76. sglang/srt/models/internlm2.py +2 -3
  77. sglang/srt/models/internlm2_reward.py +0 -1
  78. sglang/srt/models/llama.py +96 -31
  79. sglang/srt/models/llama_classification.py +1 -2
  80. sglang/srt/models/llama_embedding.py +1 -2
  81. sglang/srt/models/llama_reward.py +2 -3
  82. sglang/srt/models/llava.py +1 -4
  83. sglang/srt/models/llavavid.py +1 -2
  84. sglang/srt/models/minicpm.py +4 -7
  85. sglang/srt/models/minicpm3.py +6 -19
  86. sglang/srt/models/mixtral.py +24 -14
  87. sglang/srt/models/mixtral_quant.py +2 -3
  88. sglang/srt/models/mllama.py +3 -7
  89. sglang/srt/models/olmo.py +2 -8
  90. sglang/srt/models/olmo2.py +0 -1
  91. sglang/srt/models/olmoe.py +3 -5
  92. sglang/srt/models/phi3_small.py +8 -13
  93. sglang/srt/models/qwen.py +2 -3
  94. sglang/srt/models/qwen2.py +10 -9
  95. sglang/srt/models/qwen2_moe.py +4 -16
  96. sglang/srt/models/qwen2_vl.py +2 -6
  97. sglang/srt/models/registry.py +99 -0
  98. sglang/srt/models/stablelm.py +2 -3
  99. sglang/srt/models/torch_native_llama.py +6 -17
  100. sglang/srt/models/xverse.py +2 -4
  101. sglang/srt/models/xverse_moe.py +4 -11
  102. sglang/srt/models/yivl.py +2 -3
  103. sglang/srt/openai_api/adapter.py +9 -5
  104. sglang/srt/openai_api/protocol.py +1 -0
  105. sglang/srt/sampling/sampling_batch_info.py +9 -8
  106. sglang/srt/server.py +270 -173
  107. sglang/srt/server_args.py +102 -29
  108. sglang/srt/utils.py +295 -28
  109. sglang/test/test_utils.py +7 -0
  110. sglang/version.py +1 -1
  111. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
  112. sglang-0.4.0.post1.dist-info/RECORD +189 -0
  113. sglang-0.3.6.post3.dist-info/RECORD +0 -162
  114. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
  115. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
  116. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0

sglang/srt/layers/attention/torch_native_backend.py (new file)
@@ -0,0 +1,299 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Optional
+
+import torch
+from torch.nn.functional import scaled_dot_product_attention
+
+from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+
+class TorchNativeAttnBackend(AttentionBackend):
+    def __init__(self, model_runner: ModelRunner):
+        super().__init__()
+        self.forward_metadata = None
+        self.device = model_runner.device
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Init the metadata for a forward pass."""
+        pass
+
+    def init_cuda_graph_state(self, max_bs: int):
+        # TODO: Support CUDA graph
+        raise ValueError(
+            "Torch native attention does not support CUDA graph for now. Please use --disable-cuda-graph"
+        )
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor] = None,
+    ):
+        # TODO: Support CUDA graph
+        raise ValueError(
+            "Torch native attention does not support CUDA graph for now. Please use --disable-cuda-graph"
+        )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor] = None,
+    ):
+        # TODO: Support CUDA graph
+        raise ValueError(
+            "Torch native attention does not support CUDA graph for now. Please use --disable-cuda-graph"
+        )
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        # TODO: Support CUDA graph
+        raise ValueError(
+            "Torch native attention does not support CUDA graph for now. Please use --disable-cuda-graph"
+        )
+
+    def _run_sdpa_forward_extend(
+        self,
+        query: torch.Tensor,
+        output: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        req_to_token: torch.Tensor,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        extend_prefix_lens: torch.Tensor,
+        extend_seq_lens: torch.Tensor,
+        scaling=None,
+        enable_gqa=False,
+        causal=False,
+    ):
+        """Run the extend forward pass using the torch-native SDPA op.
+
+        Args:
+            query: [num_tokens, num_heads, head_size]
+            output: [num_tokens, num_heads, head_size]
+            k_cache: [max_total_num_tokens, num_heads, head_size]
+            v_cache: [max_total_num_tokens, num_heads, head_size]
+            req_to_token: [max_num_reqs, max_context_len]
+            req_pool_indices: [num_seqs]
+            seq_lens: [num_seqs]
+            extend_prefix_lens: [num_seqs]
+            extend_seq_lens: [num_seqs]
+            scaling: float or None
+            enable_gqa: bool
+            causal: bool
+
+        Returns:
+            output: [num_tokens, num_heads, head_size]
+        """
+
+        assert seq_lens.shape[0] == extend_prefix_lens.shape[0]
+        assert seq_lens.shape[0] == extend_seq_lens.shape[0]
+
+        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
+        query = query.movedim(0, query.dim() - 2)
+
+        start_q, start_kv = 0, 0
+        for seq_idx in range(seq_lens.shape[0]):
+            # TODO: this loop processes one sequence per iteration, which is
+            # inefficient. Optimize the performance later.
+
+            extend_seq_len_q = extend_seq_lens[seq_idx]
+            prefill_seq_len_q = extend_prefix_lens[seq_idx]
+
+            seq_len_kv = seq_lens[seq_idx]
+            end_q = start_q + extend_seq_len_q
+            end_kv = start_kv + seq_len_kv
+
+            per_req_query = query[:, start_q:end_q, :]
+            per_req_query_redundant = torch.empty(
+                (per_req_query.shape[0], seq_len_kv, per_req_query.shape[2]),
+                dtype=per_req_query.dtype,
+                device=per_req_query.device,
+            )
+
+            per_req_query_redundant[:, prefill_seq_len_q:, :] = per_req_query
+
+            # Get key and value from the cache. per_req_tokens contains the
+            # kv cache index for each token in the sequence.
+            req_pool_idx = req_pool_indices[seq_idx]
+            per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
+            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
+            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
+
+            per_req_out_redundant = (
+                scaled_dot_product_attention(
+                    per_req_query_redundant.unsqueeze(0),
+                    per_req_key.unsqueeze(0),
+                    per_req_value.unsqueeze(0),
+                    enable_gqa=enable_gqa,
+                    scale=scaling,
+                    is_causal=causal,
+                )
+                .squeeze(0)
+                .movedim(query.dim() - 2, 0)
+            )
+            output[start_q:end_q, :, :] = per_req_out_redundant[prefill_seq_len_q:, :, :]
+            start_q, start_kv = end_q, end_kv
+        return output
+
+    def _run_sdpa_forward_decode(
+        self,
+        query: torch.Tensor,
+        output: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        req_to_token: torch.Tensor,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        scaling=None,
+        enable_gqa=False,
+        causal=False,
+    ):
+        """Run the decode forward pass using the torch-native SDPA op.
+
+        Args:
+            query: [num_tokens, num_heads, head_size]
+            output: [num_tokens, num_heads, head_size]
+            k_cache: [max_total_num_tokens, num_heads, head_size]
+            v_cache: [max_total_num_tokens, num_heads, head_size]
+            req_to_token: [max_num_reqs, max_context_len]
+            req_pool_indices: [num_seqs]
+            seq_lens: [num_seqs]
+            scaling: float or None
+            enable_gqa: bool
+            causal: bool
+
+        Returns:
+            output: [num_tokens, num_heads, head_size]
+        """
+
+        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
+        query = query.movedim(0, query.dim() - 2)
+
+        start_q, start_kv = 0, 0
+        for seq_idx in range(seq_lens.shape[0]):
+            # TODO: this loop processes one sequence per iteration, which is
+            # inefficient. Optimize the performance later.
+
+            seq_len_q = 1
+            seq_len_kv = seq_lens[seq_idx]
+            end_q = start_q + seq_len_q
+            end_kv = start_kv + seq_len_kv
+
+            per_req_query = query[:, start_q:end_q, :]
+
+            # Get key and value from the cache. per_req_tokens contains the
+            # kv cache index for each token in the sequence.
+            req_pool_idx = req_pool_indices[seq_idx]
+            per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
+            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
+            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
+
+            per_req_out = (
+                scaled_dot_product_attention(
+                    per_req_query.unsqueeze(0),
+                    per_req_key.unsqueeze(0),
+                    per_req_value.unsqueeze(0),
+                    enable_gqa=enable_gqa,
+                    scale=scaling,
+                    is_causal=causal,
+                )
+                .squeeze(0)
+                .movedim(query.dim() - 2, 0)
+            )
+            output[start_q:end_q, :, :] = per_req_out
+            start_q, start_kv = end_q, end_kv
+
+        return output
+
+    def forward_extend(
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
+
+        use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
+
+        q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+        o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+
+        self._run_sdpa_forward_extend(
+            q_,
+            o_,
+            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.req_pool_indices,
+            forward_batch.seq_lens,
+            forward_batch.extend_prefix_lens,
+            forward_batch.extend_seq_lens,
+            scaling=layer.scaling,
+            enable_gqa=use_gqa,
+            causal=not layer.is_cross_attention,
+        )
+        return o
+
+    def forward_decode(
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        # During torch.compile, there is a bug in rotary_emb that causes the
+        # output value to have a 3D tensor shape. This reshapes the output correctly.
+        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
+
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )
+
+        use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
+
+        q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+        o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+
+        self._run_sdpa_forward_decode(
+            q_,
+            o_,
+            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.req_pool_indices,
+            forward_batch.seq_lens,
+            scaling=layer.scaling,
+            enable_gqa=use_gqa,
+            causal=False,
+        )
+
+        return o
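
The heart of the new backend is the per-request gather-then-SDPA loop above: for each sequence, the rows of k_cache and v_cache referenced by req_to_token are gathered and handed to torch.nn.functional.scaled_dot_product_attention. The standalone sketch below distills the decode path with toy shapes; the tensor names mirror the backend's arguments, but the sizes and random data are made up for illustration and this is not code from the package.

import torch
from torch.nn.functional import scaled_dot_product_attention

# Toy sizes (hypothetical; chosen only for illustration).
num_heads, head_size, max_total_tokens, max_context_len = 4, 64, 256, 32
num_seqs = 2

# Paged KV cache plus the request -> cache-slot mapping, mirroring the backend.
k_cache = torch.randn(max_total_tokens, num_heads, head_size)
v_cache = torch.randn(max_total_tokens, num_heads, head_size)
req_to_token = torch.randint(0, max_total_tokens, (num_seqs, max_context_len))
seq_lens = torch.tensor([5, 9])

# One new query token per sequence (a decode step).
query = torch.randn(num_seqs, num_heads, head_size)
output = torch.empty_like(query)

q = query.movedim(0, 1)  # [num_heads, num_tokens, head_size]
start_q = 0
for seq_idx in range(num_seqs):
    seq_len_kv = int(seq_lens[seq_idx])
    end_q = start_q + 1
    # Gather this request's K/V rows from the cache.
    tokens = req_to_token[seq_idx, :seq_len_kv]
    k = k_cache[tokens].movedim(0, 1)  # [num_heads, seq_len_kv, head_size]
    v = v_cache[tokens].movedim(0, 1)
    out = scaled_dot_product_attention(
        q[:, start_q:end_q, :].unsqueeze(0),  # [1, num_heads, 1, head_size]
        k.unsqueeze(0),
        v.unsqueeze(0),
    ).squeeze(0)
    output[start_q:end_q] = out.movedim(1, 0)  # back to [1, num_heads, head_size]
    start_q = end_q

print(output.shape)  # torch.Size([2, 4, 64])

One simplification: the real backend looks up req_to_token by req_pool_idx from req_pool_indices; the sketch indexes by sequence position directly because there is no request pool here.
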

sglang/srt/layers/attention/triton_backend.py
@@ -114,7 +114,13 @@ class TritonAttnBackend(AttentionBackend):
         return 1

     def forward_extend(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
@@ -122,9 +128,10 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)

-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )

         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata
         self.extend_attention_fwd(
@@ -146,7 +153,13 @@ class TritonAttnBackend(AttentionBackend):
         return o

     def forward_decode(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -160,9 +173,10 @@ class TritonAttnBackend(AttentionBackend):

         start_loc, attn_logits, max_seq_len, max_extend_len = self.forward_metadata

-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v
+            )

         self.decode_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
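
Both Triton entry points now accept a save_kv_cache flag that defaults to True, so existing call sites keep writing K/V into token_to_kv_pool unchanged, while a caller can skip the write entirely (for example, when the K/V for that layer is already stored). The stub below is not the real TritonAttnBackend; it only illustrates how the default preserves the old behavior while new callers opt out.

# Minimal stub showing the backward-compatible save_kv_cache pattern.
class StubAttnBackend:
    def __init__(self):
        self.kv_writes = 0

    def forward_decode(self, q, k, v, layer=None, forward_batch=None, save_kv_cache=True):
        if save_kv_cache:
            # Stands in for forward_batch.token_to_kv_pool.set_kv_buffer(...).
            self.kv_writes += 1
        return q  # attention math omitted

backend = StubAttnBackend()
backend.forward_decode(q=[1.0], k=[2.0], v=[3.0])                       # old-style call: KV written
backend.forward_decode(q=[1.0], k=[2.0], v=[3.0], save_kv_cache=False)  # new-style call: write skipped
assert backend.kv_writes == 1
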

sglang/srt/layers/attention/triton_ops/extend_attention.py
@@ -284,6 +284,9 @@ def extend_attention_fwd(
     elif Lq == 288:
         BLOCK_DMODEL = 256
         BLOCK_DPE = 32
+    elif Lq == 192:
+        BLOCK_DMODEL = 128
+        BLOCK_DPE = 64
     else:
         BLOCK_DMODEL = triton.next_power_of_2(Lq)
         BLOCK_DPE = 0
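
The new Lq == 192 branch follows the same convention as the Lq == 288 case above: instead of letting triton.next_power_of_2 pad the 192-wide query head up to 256, it splits the head into a 128-wide main tile (BLOCK_DMODEL) plus a 64-wide positional-encoding tile (BLOCK_DPE), so both tiles stay powers of two. This plausibly supports a 128 + 64 (non-rotary + rotary) head layout such as the DeepSeek-V2 changes elsewhere in this diff, though the hunk itself does not say so. The snippet below is a simplified restatement of the selection logic visible in the hunk, not the actual kernel code.

def next_power_of_2(n: int) -> int:
    # Same result as triton.next_power_of_2 for positive n.
    return 1 << (n - 1).bit_length()

def pick_blocks(Lq: int) -> tuple[int, int]:
    """Return (BLOCK_DMODEL, BLOCK_DPE) for a query head dim Lq."""
    if Lq == 288:        # 256 + 32
        return 256, 32
    elif Lq == 192:      # added in this release: 128 + 64 instead of padding to 256
        return 128, 64
    else:
        return next_power_of_2(Lq), 0

assert pick_blocks(192) == (128, 64)
assert sum(pick_blocks(192)) == 192   # the two tiles cover the head exactly
assert next_power_of_2(192) == 256    # what the fallback would have used
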