sglang 0.5.1.post1__py3-none-any.whl → 0.5.1.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. sglang/bench_one_batch_server.py +79 -53
  2. sglang/bench_serving.py +186 -14
  3. sglang/profiler.py +0 -1
  4. sglang/srt/conversation.py +38 -5
  5. sglang/srt/disaggregation/decode.py +4 -0
  6. sglang/srt/disaggregation/prefill.py +4 -0
  7. sglang/srt/entrypoints/engine.py +2 -2
  8. sglang/srt/entrypoints/openai/protocol.py +27 -24
  9. sglang/srt/entrypoints/openai/serving_chat.py +50 -9
  10. sglang/srt/entrypoints/openai/serving_completions.py +15 -0
  11. sglang/srt/entrypoints/tool.py +7 -7
  12. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  13. sglang/srt/function_call/function_call_parser.py +2 -0
  14. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  15. sglang/srt/harmony_parser.py +588 -0
  16. sglang/srt/hf_transformers_utils.py +16 -7
  17. sglang/srt/layers/attention/ascend_backend.py +218 -111
  18. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  19. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  20. sglang/srt/layers/attention/flashinfer_mla_backend.py +76 -91
  21. sglang/srt/layers/attention/utils.py +15 -94
  22. sglang/srt/layers/communicator.py +1 -2
  23. sglang/srt/layers/moe/cutlass_moe.py +0 -15
  24. sglang/srt/layers/moe/ep_moe/layer.py +1 -7
  25. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  27. sglang/srt/layers/moe/topk.py +1 -1
  28. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +133 -235
  29. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -7
  30. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +5 -23
  31. sglang/srt/layers/quantization/fp8.py +2 -1
  32. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  33. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  34. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  35. sglang/srt/layers/quantization/mxfp4.py +16 -23
  36. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  37. sglang/srt/layers/utils.py +0 -14
  38. sglang/srt/lora/lora_manager.py +29 -12
  39. sglang/srt/managers/cache_controller.py +223 -156
  40. sglang/srt/managers/detokenizer_manager.py +5 -0
  41. sglang/srt/managers/io_struct.py +30 -0
  42. sglang/srt/managers/scheduler.py +58 -7
  43. sglang/srt/managers/scheduler_metrics_mixin.py +15 -0
  44. sglang/srt/managers/tokenizer_manager.py +36 -3
  45. sglang/srt/mem_cache/hicache_storage.py +31 -20
  46. sglang/srt/mem_cache/hiradix_cache.py +12 -3
  47. sglang/srt/mem_cache/memory_pool.py +73 -14
  48. sglang/srt/mem_cache/memory_pool_host.py +3 -2
  49. sglang/srt/mem_cache/radix_cache.py +1 -0
  50. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +5 -13
  51. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +85 -81
  52. sglang/srt/metrics/collector.py +5 -5
  53. sglang/srt/model_executor/cuda_graph_runner.py +2 -2
  54. sglang/srt/model_executor/model_runner.py +1 -1
  55. sglang/srt/models/deepseek_v2.py +12 -3
  56. sglang/srt/models/gpt_oss.py +2 -1
  57. sglang/srt/models/qwen2_5_vl.py +1 -0
  58. sglang/srt/offloader.py +115 -0
  59. sglang/srt/reasoning_parser.py +56 -300
  60. sglang/srt/server_args.py +10 -5
  61. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  62. sglang/srt/utils.py +59 -12
  63. sglang/test/test_cutlass_moe.py +33 -28
  64. sglang/version.py +1 -1
  65. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/METADATA +6 -5
  66. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/RECORD +69 -65
  67. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/WHEEL +0 -0
  68. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/licenses/LICENSE +0 -0
  69. {sglang-0.5.1.post1.dist-info → sglang-0.5.1.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/ascend_backend.py
@@ -12,11 +12,16 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import get_bool_env_var

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner

+import os
+
+import numpy as np
+

 @dataclass
 class ForwardMetadata:
@@ -54,7 +59,6 @@ class AscendAttnBackend(AttentionBackend):
         super().__init__()
         self.forward_metadata = None
         self.device = model_runner.device
-        self.gen_attention_mask(128, model_runner.dtype)
         self.page_size = model_runner.page_size
         self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
         if self.use_mla:
@@ -65,6 +69,17 @@ class AscendAttnBackend(AttentionBackend):
         self.max_context_len = model_runner.model_config.context_len
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
         self.graph_mode = False
+        self.use_fia = get_bool_env_var("ASCEND_USE_FIA", "False")
+        if not self.use_fia:
+            self.gen_attention_mask(128, model_runner.dtype)
+        mask_length = 2048
+        self.fia_mask = ~torch.tril(
+            torch.ones(
+                (mask_length, mask_length),
+                dtype=torch.bool,
+                device=model_runner.device,
+            )
+        )

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init the metadata for a forward pass."""
@@ -81,6 +96,9 @@
                 forward_batch.extend_seq_lens.cpu().int()
             )
         self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
+        self.forward_metadata.seq_lens_list_cumsum = np.cumsum(
+            forward_batch.extend_seq_lens_cpu
+        )

         self.graph_mode = False

@@ -151,71 +169,89 @@ class AscendAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        if save_kv_cache:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, forward_batch.out_cache_loc, k, v
-            )
+        if not self.use_mla:
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, v
+                )

-        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
-        v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+            k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            v_cache = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)

-        if not self.use_mla:
-            query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
-            output = torch.empty(
-                (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
-                dtype=query.dtype,
-                device=query.device,
-            )
+            if self.use_fia:
+                """FIA will support multi-bs in the later version of CANN"""
+                q = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                attn_output = torch.empty(
+                    (q.size(0), layer.tp_q_head_num, layer.v_head_dim),
+                    device=q.device,
+                    dtype=q.dtype,
+                )
+                q_len_offset = 0
+                for q_len in forward_batch.extend_seq_lens_cpu:
+                    attn_output[q_len_offset : q_len_offset + q_len] = (
+                        torch.ops.npu.npu_fused_infer_attention_score(
+                            q[None, q_len_offset : q_len_offset + q_len],
+                            k[None, q_len_offset : q_len_offset + q_len],
+                            v[None, q_len_offset : q_len_offset + q_len],
+                            num_heads=layer.tp_q_head_num,
+                            num_key_value_heads=layer.tp_k_head_num,
+                            input_layout="BSND",  # todo, TND not supports q_heads!=k_heads
+                            atten_mask=self.fia_mask.unsqueeze(0),
+                            sparse_mode=3,
+                            scale=layer.scaling,
+                            next_tokens=0,
+                        )[0]
+                    )
+                    q_len_offset += q_len
+                attn_output = attn_output.view(
+                    -1, layer.tp_q_head_num * layer.v_head_dim
+                )

-            torch_npu._npu_flash_attention_qlens(
-                query=query,
-                key_cache=k_cache,
-                value_cache=v_cache,
-                mask=self.mask,
-                block_table=self.forward_metadata.block_tables,
-                seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
-                context_lens=self.forward_metadata.seq_lens_cpu_int,
-                scale_value=layer.scaling,
-                num_heads=layer.tp_q_head_num,
-                num_kv_heads=layer.tp_k_head_num,
-                out=output,
-            )
-            return output
-        else:
-            if layer.qk_head_dim != layer.v_head_dim:
-                o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
             else:
-                o = torch.empty_like(q)
-
-            use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
-
-            q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
-            o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
-
-            causal = True
-            if (
-                layer.is_cross_attention
-                or layer.attn_type == AttentionType.ENCODER_ONLY
-            ):
-                causal = False
-
-            self.native_attn._run_sdpa_forward_extend(
-                q_,
-                o_,
-                k_cache.view(
-                    -1, layer.tp_k_head_num, (self.kv_lora_rank + self.qk_rope_head_dim)
-                ),
-                v_cache.view(-1, layer.tp_v_head_num, self.kv_lora_rank),
-                forward_batch.req_to_token_pool.req_to_token,
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                forward_batch.extend_prefix_lens,
-                forward_batch.extend_seq_lens,
-                scaling=layer.scaling,
-                enable_gqa=use_gqa,
-                causal=causal,
+                query = q.view(-1, layer.tp_q_head_num * layer.qk_head_dim)
+                attn_output = torch.empty(
+                    (query.shape[0], layer.tp_q_head_num * layer.v_head_dim),
+                    dtype=query.dtype,
+                    device=query.device,
+                )
+
+                torch_npu._npu_flash_attention_qlens(
+                    query=query,
+                    key_cache=k_cache,
+                    value_cache=v_cache,
+                    mask=self.mask,
+                    block_table=self.forward_metadata.block_tables,
+                    seq_len=self.forward_metadata.extend_seq_lens_cpu_int,
+                    context_lens=self.forward_metadata.seq_lens_cpu_int,
+                    scale_value=layer.scaling,
+                    num_heads=layer.tp_q_head_num,
+                    num_kv_heads=layer.tp_k_head_num,
+                    out=attn_output,
+                )
+        else:
+            assert (
+                layer.qk_head_dim != layer.v_head_dim
+            ), "FIA only supports qk_head_dim != v_head_dim"
+            q_nope, q_rope = q.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)
+            k_nope, k_rope = k.split([layer.v_head_dim, self.qk_rope_head_dim], dim=-1)
+
+            attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
+                q_nope,
+                k_nope,
+                v,
+                query_rope=q_rope,
+                key_rope=k_rope,
+                num_heads=layer.tp_q_head_num,
+                input_layout="TND",
+                atten_mask=self.fia_mask,
+                sparse_mode=3,
+                actual_seq_lengths=self.forward_metadata.seq_lens_list_cumsum,
+                actual_seq_lengths_kv=self.forward_metadata.seq_lens_list_cumsum,
+                scale=layer.scaling,
+                next_tokens=0,
             )
-            return o
+
+        return attn_output

     def forward_decode(
         self,
@@ -224,13 +260,17 @@ class AscendAttnBackend(AttentionBackend):
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
-        save_kv_cache=True,
+        save_kv_cache: bool = False,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
     ):
-        if save_kv_cache:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, forward_batch.out_cache_loc, k, v
-            )
         if not self.use_mla:
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, v
+                )
+            num_tokens = q.shape[0]
             if self.graph_mode:
                 k_cache = forward_batch.token_to_kv_pool.get_key_buffer(
                     layer.layer_id
@@ -239,7 +279,6 @@ class AscendAttnBackend(AttentionBackend):
                     layer.layer_id
                 ).view(-1, self.page_size, layer.tp_v_head_num * layer.v_head_dim)
                 query = q.view(-1, 1, layer.tp_q_head_num * layer.qk_head_dim)
-                num_tokens = query.shape[0]
                 workspace = (
                     torch_npu._npu_fused_infer_attention_score_get_max_workspace(
                         query,
@@ -254,7 +293,7 @@ class AscendAttnBackend(AttentionBackend):
                         actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
                     )
                 )
-                output = torch.empty(
+                attn_output = torch.empty(
                     (num_tokens, 1, layer.tp_q_head_num * layer.v_head_dim),
                     dtype=q.dtype,
                     device=q.device,
@@ -272,61 +311,129 @@ class AscendAttnBackend(AttentionBackend):
                     scale=layer.scaling,
                     actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_list,
                     workspace=workspace,
-                    out=[output, softmax_lse],
+                    out=[attn_output, softmax_lse],
                 )
             else:
                 k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
                 v_cache = forward_batch.token_to_kv_pool.get_value_buffer(
                     layer.layer_id
                 )
+                if self.use_fia:
+                    attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
+                        q.view(
+                            forward_batch.batch_size,
+                            -1,
+                            layer.tp_q_head_num,
+                            layer.qk_head_dim,
+                        ),
+                        k_cache.view(
+                            -1, self.page_size, layer.tp_k_head_num * layer.qk_head_dim
+                        ),
+                        v_cache.view(
+                            -1, self.page_size, layer.tp_v_head_num * layer.qk_head_dim
+                        ),
+                        num_heads=layer.tp_q_head_num,
+                        num_key_value_heads=layer.tp_k_head_num,
+                        input_layout="BSND",
+                        atten_mask=None,
+                        block_size=self.page_size,
+                        block_table=self.forward_metadata.block_tables,
+                        actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
+                        scale=layer.scaling,
+                    )
+                else:
+                    query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+                    attn_output = torch.empty(
+                        (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
+                        dtype=query.dtype,
+                        device=query.device,
+                    )

-                query = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
-                num_tokens = query.shape[0]
-                output = torch.empty(
-                    (num_tokens, layer.tp_q_head_num, layer.v_head_dim),
-                    dtype=query.dtype,
-                    device=query.device,
+                    torch_npu._npu_paged_attention(
+                        query=query,
+                        key_cache=k_cache,
+                        value_cache=v_cache,
+                        num_heads=layer.tp_q_head_num,
+                        num_kv_heads=layer.tp_k_head_num,
+                        scale_value=layer.scaling,
+                        block_table=self.forward_metadata.block_tables,
+                        context_lens=self.forward_metadata.seq_lens_cpu_int,
+                        out=attn_output,
+                    )
+            return attn_output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
+        else:
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, forward_batch.out_cache_loc, k, k_rope
                 )
-
-                torch_npu._npu_paged_attention(
-                    query=query,
-                    key_cache=k_cache,
-                    value_cache=v_cache,
+            num_tokens = q.shape[0]
+            kv_c = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            k_pe = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+
+            if (self.graph_mode or self.use_fia) and (
+                layer.tp_q_head_num // layer.tp_k_head_num
+            ) >= 8:
+                """layer.tp_q_head_num // layer.tp_k_head_num < 8 will support in the later version of CANN"""
+                kv_c = kv_c.view(
+                    -1, self.page_size, layer.tp_k_head_num * self.kv_lora_rank
+                )
+                k_pe = k_pe.view(
+                    -1, self.page_size, layer.tp_k_head_num * self.qk_rope_head_dim
+                )
+                q = q.view(
+                    forward_batch.batch_size, -1, layer.tp_q_head_num, self.kv_lora_rank
+                )
+                q_rope = q_rope.view(
+                    forward_batch.batch_size,
+                    -1,
+                    layer.tp_q_head_num,
+                    self.qk_rope_head_dim,
+                )
+                attn_output, _ = torch.ops.npu.npu_fused_infer_attention_score(
+                    q,
+                    kv_c,
+                    kv_c,
+                    query_rope=q_rope,
+                    key_rope=k_pe,
                     num_heads=layer.tp_q_head_num,
+                    num_key_value_heads=layer.tp_k_head_num,
+                    input_layout="BSND",
+                    atten_mask=None,
+                    sparse_mode=0,
+                    scale=layer.scaling,
+                    antiquant_mode=0,
+                    antiquant_scale=None,
+                    block_table=self.forward_metadata.block_tables,
+                    block_size=self.page_size,
+                    actual_seq_lengths_kv=self.forward_metadata.seq_lens_cpu_int,
+                )
+            else:
+                assert (
+                    self.graph_mode == False
+                )  # _npu_paged_attention_mla not support graph mode
+                q = torch.cat([q, q_rope], dim=-1)
+                query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+                kv_c_and_k_pe_cache = torch.cat([kv_c, k_pe], dim=-1)
+                kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
+                    -1,
+                    self.page_size,
+                    layer.tp_k_head_num,
+                    self.kv_lora_rank + self.qk_rope_head_dim,
+                )
+                attn_output = torch.empty(
+                    [num_tokens, layer.tp_q_head_num, self.kv_lora_rank],
+                    dtype=q.dtype,
+                    device=q.device,
+                )
+                torch_npu._npu_paged_attention_mla(
+                    query=query,
+                    key_cache=kv_c_and_k_pe_cache,
                     num_kv_heads=layer.tp_k_head_num,
+                    num_heads=layer.tp_q_head_num,
                     scale_value=layer.scaling,
                     block_table=self.forward_metadata.block_tables,
                     context_lens=self.forward_metadata.seq_lens_cpu_int,
-                    out=output,
+                    mla_vheadsize=self.kv_lora_rank,
+                    out=attn_output,
                 )
-                return output.view(num_tokens, layer.tp_q_head_num * layer.v_head_dim)
-            else:
-                query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
-                num_tokens = query.shape[0]
-                kv_c_and_k_pe_cache = forward_batch.token_to_kv_pool.get_key_buffer(
-                    layer.layer_id
-                )
-                kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view(
-                    -1,
-                    self.page_size,
-                    layer.tp_k_head_num,
-                    self.kv_lora_rank + self.qk_rope_head_dim,
-                )
-
-                attn_output = torch.empty(
-                    [num_tokens, layer.tp_q_head_num, self.kv_lora_rank],
-                    dtype=q.dtype,
-                    device=q.device,
-                )
-                torch_npu._npu_paged_attention_mla(
-                    query=query,
-                    key_cache=kv_c_and_k_pe_cache,
-                    num_kv_heads=layer.tp_k_head_num,
-                    num_heads=layer.tp_q_head_num,
-                    scale_value=layer.scaling,
-                    block_table=self.forward_metadata.block_tables,
-                    context_lens=self.forward_metadata.seq_lens_cpu_int,
-                    mla_vheadsize=self.kv_lora_rank,
-                    out=attn_output,
-                )
         return attn_output.view(num_tokens, layer.tp_q_head_num * self.kv_lora_rank)
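
Reading note for the ascend_backend.py hunks above: the new Ascend FIA (fused-infer-attention) path is opt-in via the ASCEND_USE_FIA environment variable, read in AscendAttnBackend.__init__ through sglang's get_bool_env_var helper, and the new __init__ also builds a 2048x2048 boolean causal mask (self.fia_mask) for the fused kernel. The sketch below restates just that toggle-plus-mask setup in plain os/torch so it can be tried outside the backend; the function name and the accepted truthy spellings ("true"/"1") are illustrative assumptions, not part of the package.

import os

import torch


def build_fia_toggle(device: str = "cpu"):
    # Assumption: mirrors get_bool_env_var("ASCEND_USE_FIA", "False") from the diff;
    # the exact set of truthy strings accepted by that helper is an illustrative guess.
    use_fia = os.environ.get("ASCEND_USE_FIA", "False").lower() in ("true", "1")

    # Same construction as the diff: True strictly above the diagonal marks the
    # positions to be masked out, passed as atten_mask (sparse_mode=3) to
    # npu_fused_infer_attention_score in the extend paths.
    mask_length = 2048
    fia_mask = ~torch.tril(
        torch.ones((mask_length, mask_length), dtype=torch.bool, device=device)
    )
    return use_fia, fia_mask

Under that assumption, enabling the path comes down to exporting ASCEND_USE_FIA=true (or 1) before launching; whether FIA is actually taken for a given request still depends on the branch conditions shown in forward_extend/forward_decode, e.g. the (tp_q_head_num // tp_k_head_num) >= 8 requirement on the MLA decode path.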