sglang 0.5.4__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. sglang/bench_serving.py +56 -12
  2. sglang/launch_server.py +2 -0
  3. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +101 -4
  4. sglang/srt/compilation/backend.py +1 -1
  5. sglang/srt/configs/model_config.py +5 -5
  6. sglang/srt/distributed/parallel_state.py +0 -7
  7. sglang/srt/entrypoints/engine.py +18 -15
  8. sglang/srt/entrypoints/grpc_server.py +0 -1
  9. sglang/srt/entrypoints/http_server.py +75 -94
  10. sglang/srt/environ.py +16 -2
  11. sglang/srt/eplb/expert_distribution.py +30 -0
  12. sglang/srt/function_call/function_call_parser.py +2 -0
  13. sglang/srt/function_call/minimax_m2.py +367 -0
  14. sglang/srt/layers/activation.py +6 -0
  15. sglang/srt/layers/attention/flashattention_backend.py +12 -2
  16. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  17. sglang/srt/layers/attention/flashinfer_mla_backend.py +18 -10
  18. sglang/srt/layers/attention/trtllm_mla_backend.py +1 -13
  19. sglang/srt/layers/attention/utils.py +78 -0
  20. sglang/srt/layers/communicator.py +1 -0
  21. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  22. sglang/srt/layers/layernorm.py +19 -4
  23. sglang/srt/layers/logits_processor.py +5 -0
  24. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  25. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  26. sglang/srt/layers/moe/ep_moe/layer.py +79 -272
  27. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  28. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  29. sglang/srt/layers/moe/moe_runner/deep_gemm.py +287 -22
  30. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  31. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  32. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  33. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  34. sglang/srt/layers/moe/token_dispatcher/deepep.py +18 -14
  35. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  36. sglang/srt/layers/moe/topk.py +4 -4
  37. sglang/srt/layers/moe/utils.py +3 -4
  38. sglang/srt/layers/quantization/__init__.py +3 -5
  39. sglang/srt/layers/quantization/awq.py +0 -3
  40. sglang/srt/layers/quantization/base_config.py +7 -0
  41. sglang/srt/layers/quantization/fp8.py +68 -63
  42. sglang/srt/layers/quantization/gguf.py +566 -0
  43. sglang/srt/layers/quantization/mxfp4.py +30 -38
  44. sglang/srt/layers/quantization/unquant.py +23 -45
  45. sglang/srt/layers/quantization/w4afp8.py +38 -2
  46. sglang/srt/layers/radix_attention.py +5 -2
  47. sglang/srt/layers/rotary_embedding.py +13 -1
  48. sglang/srt/layers/sampler.py +12 -1
  49. sglang/srt/managers/io_struct.py +3 -0
  50. sglang/srt/managers/multi_tokenizer_mixin.py +17 -1
  51. sglang/srt/managers/scheduler.py +21 -15
  52. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  53. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  54. sglang/srt/managers/tokenizer_manager.py +11 -19
  55. sglang/srt/mem_cache/hicache_storage.py +7 -1
  56. sglang/srt/mem_cache/memory_pool.py +82 -0
  57. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  58. sglang/srt/model_executor/forward_batch_info.py +44 -3
  59. sglang/srt/model_executor/model_runner.py +1 -149
  60. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  61. sglang/srt/models/deepseek_v2.py +147 -44
  62. sglang/srt/models/glm4_moe.py +322 -354
  63. sglang/srt/models/glm4_moe_nextn.py +4 -14
  64. sglang/srt/models/glm4v_moe.py +29 -196
  65. sglang/srt/models/minimax_m2.py +922 -0
  66. sglang/srt/models/nvila.py +355 -0
  67. sglang/srt/models/nvila_lite.py +184 -0
  68. sglang/srt/models/qwen2.py +22 -1
  69. sglang/srt/models/qwen3.py +34 -4
  70. sglang/srt/models/qwen3_moe.py +2 -4
  71. sglang/srt/multimodal/processors/base_processor.py +1 -0
  72. sglang/srt/multimodal/processors/glm4v.py +1 -1
  73. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  74. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  75. sglang/srt/parser/reasoning_parser.py +28 -1
  76. sglang/srt/server_args.py +365 -186
  77. sglang/srt/single_batch_overlap.py +2 -7
  78. sglang/srt/utils/common.py +87 -42
  79. sglang/srt/utils/hf_transformers_utils.py +7 -3
  80. sglang/test/test_deterministic.py +235 -12
  81. sglang/test/test_deterministic_utils.py +2 -1
  82. sglang/version.py +1 -1
  83. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +7 -6
  84. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +87 -82
  85. sglang/srt/models/vila.py +0 -306
  86. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  87. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  88. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
@@ -57,6 +57,7 @@ from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
     is_mla_preprocess_enabled,
 )
 from sglang.srt.layers.attention.nsa.nsa_indexer import Indexer
+from sglang.srt.layers.attention.utils import concat_and_cast_mha_k_triton
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
     LayerScatterModes,
@@ -241,6 +242,10 @@ class AttnForwardMethod(IntEnum):
     # This method can avoid OOM when prefix lengths are long.
     MHA_CHUNKED_KV = auto()
 
+    # Use multi-head attention, execute the MHA for prefix and extended kv in one shot
+    # when the sequence lengths are below the threshold.
+    MHA_ONE_SHOT = auto()
+
     # Use MLA but with fused RoPE
     MLA_FUSED_ROPE = auto()
 
@@ -306,6 +311,14 @@ def _is_extend_without_speculative(forward_batch):
 )
 
 
+def _support_mha_one_shot(attn: DeepseekV2AttentionMLA, forward_batch, backend_name):
+    attn_supported = backend_name in ["fa3", "flashinfer", "flashmla"]
+    sum_seq_lens = (
+        sum(forward_batch.seq_lens_cpu) if forward_batch.seq_lens_cpu is not None else 0
+    )
+    return attn_supported and sum_seq_lens <= forward_batch.get_max_chunk_capacity()
+
+
 def _handle_attention_backend(
     attn: DeepseekV2AttentionMLA, forward_batch, backend_name
 ):
@@ -325,6 +338,8 @@ def _handle_attention_backend(
             or sum_extend_prefix_lens == 0
         )
     ):
+        if _support_mha_one_shot(attn, forward_batch, backend_name):
+            return AttnForwardMethod.MHA_ONE_SHOT
         return AttnForwardMethod.MHA_CHUNKED_KV
     else:
         return _dispatch_mla_subtype(attn, forward_batch)
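Taken together, the two hunks above reduce the new one-shot path to a capacity check performed before the existing chunked-KV fallback. The following standalone sketch restates that decision rule for clarity; it is illustrative only (the real check lives in _support_mha_one_shot and uses forward_batch.get_max_chunk_capacity()):

from typing import Optional, Sequence


def choose_mha_method(
    backend_name: str,
    seq_lens_cpu: Optional[Sequence[int]],
    max_chunk_capacity: int,
) -> str:
    # One-shot MHA is only attempted on the backends listed in the diff.
    if backend_name in ("fa3", "flashinfer", "flashmla"):
        total_tokens = sum(seq_lens_cpu) if seq_lens_cpu is not None else 0
        # If all prefix + extend tokens fit within one chunk's capacity,
        # a single MHA call can attend to everything at once.
        if total_tokens <= max_chunk_capacity:
            return "MHA_ONE_SHOT"
    # Otherwise fall back to processing the cached prefix in chunks.
    return "MHA_CHUNKED_KV"


# Example: with a 128k-token chunk capacity, a batch totaling 6k tokens takes
# the one-shot path on fa3, while a longer batch falls back to chunked KV.
assert choose_mha_method("fa3", [2048, 4096], 131072) == "MHA_ONE_SHOT"
assert choose_mha_method("fa3", [200000], 131072) == "MHA_CHUNKED_KV"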
@@ -335,7 +350,11 @@ def handle_attention_flashinfer(attn, forward_batch):
 
 
 def handle_attention_fa3(attn, forward_batch):
-    return _handle_attention_backend(attn, forward_batch, "fa3")
+    # when deterministic inference is enabled, use MLA
+    if get_global_server_args().enable_deterministic_inference:
+        return _dispatch_mla_subtype(attn, forward_batch)
+    else:
+        return _handle_attention_backend(attn, forward_batch, "fa3")
 
 
 def handle_attention_flashmla(attn, forward_batch):
@@ -379,6 +398,10 @@ def handle_attention_nsa(attn, forward_batch):
 
 
 def handle_attention_triton(attn, forward_batch):
+    # when deterministic inference is enabled, use MLA
+    if get_global_server_args().enable_deterministic_inference:
+        return _dispatch_mla_subtype(attn, forward_batch)
+
     if (
         _is_extend_without_speculative(forward_batch)
         and sum(forward_batch.extend_prefix_lens_cpu) == 0
@@ -982,16 +1005,14 @@ class DeepseekV2MoE(nn.Module):
         )
 
     def op_experts(self, state):
-        state.hidden_states_experts_output = self.experts.run_moe_core(
+        state.combine_input = self.experts.run_moe_core(
             dispatch_output=state.dispatch_output,
         )
 
     def op_combine_a(self, state):
         if self.ep_size > 1:
             self.experts.dispatcher.combine_a(
-                hidden_states=state.pop("hidden_states_experts_output"),
-                topk_ids=state.dispatch_output.topk_ids,
-                topk_weights=state.dispatch_output.topk_weights,
+                combine_input=state.pop("combine_input"),
                 tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
             state.pop("dispatch_output")
@@ -1062,6 +1083,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.scaling = self.qk_head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
+        self.kv_cache_dtype = get_global_server_args().kv_cache_dtype
 
         # NOTE modification to rope_scaling must be done early enough, b/c e.g. Indexer needs it
         if rope_scaling:
@@ -1359,6 +1381,10 @@ class DeepseekV2AttentionMLA(nn.Module):
             inner_state = self.forward_normal_chunked_kv_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
             )
+        elif attn_forward_method == AttnForwardMethod.MHA_ONE_SHOT:
+            inner_state = self.forward_normal_one_shot_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
         elif attn_forward_method == AttnForwardMethod.MLA:
             if not self.is_mla_preprocess_enabled:
                 inner_state = self.forward_absorb_prepare(
@@ -1410,6 +1436,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             return self.forward_normal_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
             return self.forward_normal_chunked_kv_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MHA_ONE_SHOT:
+            return self.forward_normal_one_shot_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MLA:
             return self.forward_absorb_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
@@ -1444,41 +1472,24 @@ class DeepseekV2AttentionMLA(nn.Module):
         kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
         latent_cache = latent_cache.unsqueeze(1)
         kv_a = self.kv_a_layernorm(kv_a)
-        kv = self.kv_b_proj(kv_a)[0]
-        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
-        k_nope = kv[..., : self.qk_nope_head_dim]
-        v = kv[..., self.qk_nope_head_dim :]
         k_pe = latent_cache[:, :, self.kv_lora_rank :]
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
         q[..., self.qk_nope_head_dim :] = q_pe
-        k = torch.empty_like(q)
 
-        # Temporary for DeepSeek V3/R1 only, but can generalize if needed
+        self._set_mla_kv_buffer(latent_cache, kv_a, k_pe, forward_batch)
         if (
-            _is_cuda
-            and (self.num_local_heads == 128)
-            and (self.qk_nope_head_dim == 128)
-            and (self.qk_rope_head_dim == 64)
+            forward_batch.mha_one_shot
+            and sum(forward_batch.extend_prefix_lens_cpu) != 0
         ):
-            concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
-        else:
-            k[..., : self.qk_nope_head_dim] = k_nope
-            k[..., self.qk_nope_head_dim :] = k_pe
-
-        if not _is_npu:
-            latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
-            latent_cache[:, :, self.kv_lora_rank :] = k_pe
-
-            # Save latent cache
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
-            )
-        else:
-            # To reduce a time-costing split operation
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+            kv_a, k_pe = self._get_mla_kv_buffer(
+                forward_batch.fetch_mha_one_shot_kv_indices(), q.dtype, forward_batch
             )
+        kv = self.kv_b_proj(kv_a)[0]
+        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
+        k_nope = kv[..., : self.qk_nope_head_dim]
+        v = kv[..., self.qk_nope_head_dim :]
 
+        k = self._concat_and_cast_mha_k(k_nope, k_pe, forward_batch)
         return q, k, v, forward_batch
 
     def forward_normal_core(self, q, k, v, forward_batch):
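Read linearly, the refactored prepare path above becomes: write this step's latent KV, optionally re-read the full prefix-plus-extend KV when running one shot over a cached prefix, then up-project and assemble k. A condensed outline for orientation (helper names mirror the hunk; this is not a drop-in replacement for the real method):

def _prepare_mha_qkv_outline(self, q, kv_a, k_pe, latent_cache, forward_batch):
    # 1) Persist the latent KV (kv_a plus the RoPE part) for the current tokens.
    self._set_mla_kv_buffer(latent_cache, kv_a, k_pe, forward_batch)
    # 2) One-shot over a cached prefix: re-read kv_a/k_pe for prefix + extend
    #    tokens so a single MHA call covers everything.
    if forward_batch.mha_one_shot and sum(forward_batch.extend_prefix_lens_cpu) != 0:
        kv_indices = forward_batch.fetch_mha_one_shot_kv_indices()
        kv_a, k_pe = self._get_mla_kv_buffer(kv_indices, q.dtype, forward_batch)
    # 3) Up-project kv_a, then split into the no-RoPE key part and the value.
    kv = self.kv_b_proj(kv_a)[0].view(
        -1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim
    )
    k_nope = kv[..., : self.qk_nope_head_dim]
    v = kv[..., self.qk_nope_head_dim :]
    # 4) Concatenate with the RoPE part, casting to the attention dtype.
    k = self._concat_and_cast_mha_k(k_nope, k_pe, forward_batch)
    return q, k, v, forward_batch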
@@ -2288,20 +2299,11 @@ class DeepseekV2AttentionMLA(nn.Module):
         for i in range(forward_batch.num_prefix_chunks):
             forward_batch.set_prefix_chunk_idx(i)
 
+            kv_indices = forward_batch.prefix_chunk_kv_indices[i]
             # Fetch latent cache from memory pool with precomputed chunked kv indices
-            latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
-                self.attn_mha.layer_id
-            )
-            latent_cache = (
-                latent_cache_buf[forward_batch.prefix_chunk_kv_indices[i]]
-                .contiguous()
-                .to(q.dtype)
+            kv_a_normed, k_pe = self._get_mla_kv_buffer(
+                kv_indices, q.dtype, forward_batch
             )
-
-            kv_a_normed, k_pe = latent_cache.split(
-                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
-            )
-            kv_a_normed = kv_a_normed.squeeze(1).contiguous()
             kv = self.kv_b_proj(kv_a_normed)[0]
             kv = kv.view(
                 -1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim
@@ -2376,6 +2378,107 @@ class DeepseekV2AttentionMLA(nn.Module):
         output, _ = self.o_proj(attn_output)
         return output
 
+    def forward_normal_one_shot_prepare(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
+    ):
+        forward_batch.mha_one_shot = True
+        return self.forward_normal_prepare(
+            positions, hidden_states, forward_batch, zero_allocator
+        )
+
+    def forward_normal_one_shot_core(self, q, k, v, forward_batch):
+        has_extend_prefix = any(forward_batch.extend_prefix_lens_cpu)
+        # Only initialize the info once
+        if has_extend_prefix and forward_batch.num_prefix_chunks is None:
+            forward_batch.num_prefix_chunks = 0
+            if hasattr(forward_batch.attn_backend, "init_mha_chunk_metadata"):
+                forward_batch.attn_backend.init_mha_chunk_metadata(forward_batch)
+        forward_batch.mha_return_lse = False
+        # Do mha for extended part without prefix
+        forward_batch.set_attn_attend_prefix_cache(False)
+        return self.forward_normal_core(q, k, v, forward_batch)
+
+    def _set_mla_kv_buffer(
+        self,
+        latent_cache: torch.Tensor,
+        kv_a: torch.Tensor,
+        k_pe: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ):
+        if _is_cuda:
+            # Save latent cache
+            forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+            )
+        elif _is_npu:
+            # To reduce a time-costing split operation
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+            )
+        else:
+            latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+            latent_cache[:, :, self.kv_lora_rank :] = k_pe
+
+            # Save latent cache
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+            )
+
+    def _get_mla_kv_buffer(
+        self,
+        kv_indices: torch.Tensor,
+        dst_dtype: torch.dtype,
+        forward_batch: ForwardBatch,
+    ):
+        if _is_cuda:
+            kv_a, k_pe = forward_batch.token_to_kv_pool.get_mla_kv_buffer(
+                self.attn_mha, kv_indices, dst_dtype
+            )
+            kv_a = kv_a.squeeze(1)
+        else:
+            latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
+                self.attn_mha.layer_id
+            )
+            latent_cache = latent_cache_buf[kv_indices].contiguous().to(dst_dtype)
+
+            kv_a, k_pe = latent_cache.split(
+                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
+            )
+            kv_a = kv_a.squeeze(1).contiguous()
+        return kv_a, k_pe
+
+    def _concat_and_cast_mha_k(self, k_nope, k_pe, forward_batch):
+        # Temporary for DeepSeek V3/R1 only, but can generalize if needed
+        k_shape = (k_nope.shape[0], self.num_local_heads, self.qk_head_dim)
+        if (
+            _is_cuda
+            and (self.num_local_heads == 128)
+            and (self.qk_nope_head_dim == 128)
+            and (self.qk_rope_head_dim == 64)
+        ):
+            k = k_nope.new_empty(*k_shape)
+            concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
+        elif _is_cuda:
+            # fa3 mha support fp8 inputs
+            if (
+                self.current_attention_backend == "fa3"
+                and self.kv_cache_dtype != "auto"
+            ):
+                attn_dtype = forward_batch.token_to_kv_pool.dtype
+            else:
+                attn_dtype = k_nope.dtype
+            k = k_nope.new_empty(*k_shape, dtype=attn_dtype)
+            concat_and_cast_mha_k_triton(k, k_nope, k_pe)
+        else:
+            k = k_nope.new_empty(*k_shape)
+            k[..., : self.qk_nope_head_dim] = k_nope
+            k[..., self.qk_nope_head_dim :] = k_pe
+        return k
+
     @staticmethod
     def _get_q_b_proj_quant_config(quant_config):
         if get_bool_env_var("SGLANG_NVFP4_CKPT_FP8_GEMM_IN_ATTN"):
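For reference, the non-specialized branches of _concat_and_cast_mha_k above are equivalent to a plain concat-and-cast in PyTorch. The sketch below is illustrative only (it is not the Triton kernel, and the function name is invented here); it assumes k_pe carries a single shared RoPE head that broadcasts across the local heads, as implied by the fallback branch in the diff:

import torch


def concat_and_cast_mha_k_reference(
    k_nope: torch.Tensor,  # [num_tokens, num_heads, qk_nope_head_dim]
    k_pe: torch.Tensor,    # [num_tokens, 1, qk_rope_head_dim], shared across heads
    out_dtype: torch.dtype,
) -> torch.Tensor:
    num_tokens, num_heads, nope_dim = k_nope.shape
    rope_dim = k_pe.shape[-1]
    # Allocate the fused key tensor in the attention dtype, then fill both halves.
    k = k_nope.new_empty(num_tokens, num_heads, nope_dim + rope_dim, dtype=out_dtype)
    k[..., :nope_dim] = k_nope  # copy (and cast) the no-RoPE part
    k[..., nope_dim:] = k_pe    # broadcast the shared RoPE part across heads
    return k


# Example usage: bf16 inputs assembled into an fp16 key tensor.
k_nope = torch.randn(4, 128, 128, dtype=torch.bfloat16)
k_pe = torch.randn(4, 1, 64, dtype=torch.bfloat16)
k = concat_and_cast_mha_k_reference(k_nope, k_pe, torch.float16)
assert k.shape == (4, 128, 192) and k.dtype == torch.float16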