sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. sglang/bench_offline_throughput.py +6 -6
  2. sglang/bench_one_batch.py +5 -4
  3. sglang/bench_one_batch_server.py +23 -15
  4. sglang/bench_serving.py +133 -57
  5. sglang/compile_deep_gemm.py +4 -4
  6. sglang/srt/configs/model_config.py +39 -28
  7. sglang/srt/conversation.py +1 -1
  8. sglang/srt/disaggregation/decode.py +122 -133
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  10. sglang/srt/disaggregation/fake/conn.py +3 -13
  11. sglang/srt/disaggregation/kv_events.py +357 -0
  12. sglang/srt/disaggregation/mini_lb.py +57 -24
  13. sglang/srt/disaggregation/mooncake/conn.py +11 -2
  14. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  15. sglang/srt/disaggregation/nixl/conn.py +9 -19
  16. sglang/srt/disaggregation/prefill.py +126 -44
  17. sglang/srt/disaggregation/utils.py +116 -5
  18. sglang/srt/distributed/utils.py +3 -3
  19. sglang/srt/entrypoints/EngineBase.py +5 -0
  20. sglang/srt/entrypoints/engine.py +28 -8
  21. sglang/srt/entrypoints/http_server.py +6 -4
  22. sglang/srt/entrypoints/http_server_engine.py +5 -2
  23. sglang/srt/function_call/base_format_detector.py +250 -0
  24. sglang/srt/function_call/core_types.py +34 -0
  25. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  26. sglang/srt/function_call/ebnf_composer.py +234 -0
  27. sglang/srt/function_call/function_call_parser.py +175 -0
  28. sglang/srt/function_call/llama32_detector.py +74 -0
  29. sglang/srt/function_call/mistral_detector.py +84 -0
  30. sglang/srt/function_call/pythonic_detector.py +163 -0
  31. sglang/srt/function_call/qwen25_detector.py +67 -0
  32. sglang/srt/function_call/utils.py +35 -0
  33. sglang/srt/hf_transformers_utils.py +46 -7
  34. sglang/srt/layers/attention/aiter_backend.py +513 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +63 -17
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  37. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  38. sglang/srt/layers/attention/triton_backend.py +3 -0
  39. sglang/srt/layers/attention/utils.py +2 -2
  40. sglang/srt/layers/attention/vision.py +1 -1
  41. sglang/srt/layers/communicator.py +451 -0
  42. sglang/srt/layers/dp_attention.py +0 -10
  43. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  44. sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
  45. sglang/srt/layers/moe/ep_moe/layer.py +104 -50
  46. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  47. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  48. sglang/srt/layers/moe/topk.py +66 -9
  49. sglang/srt/layers/multimodal.py +70 -0
  50. sglang/srt/layers/quantization/__init__.py +7 -2
  51. sglang/srt/layers/quantization/deep_gemm.py +5 -3
  52. sglang/srt/layers/quantization/fp8.py +90 -0
  53. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  54. sglang/srt/layers/quantization/gptq.py +298 -6
  55. sglang/srt/layers/quantization/int8_kernel.py +18 -5
  56. sglang/srt/layers/quantization/qoq.py +244 -0
  57. sglang/srt/lora/lora_manager.py +1 -3
  58. sglang/srt/managers/deepseek_eplb.py +278 -0
  59. sglang/srt/managers/eplb_manager.py +55 -0
  60. sglang/srt/managers/expert_distribution.py +704 -56
  61. sglang/srt/managers/expert_location.py +394 -0
  62. sglang/srt/managers/expert_location_dispatch.py +91 -0
  63. sglang/srt/managers/io_struct.py +16 -3
  64. sglang/srt/managers/mm_utils.py +293 -139
  65. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  66. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  67. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  68. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  69. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  70. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  71. sglang/srt/managers/multimodal_processors/llava.py +3 -3
  72. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  73. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  74. sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
  75. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  76. sglang/srt/managers/schedule_batch.py +49 -21
  77. sglang/srt/managers/schedule_policy.py +4 -5
  78. sglang/srt/managers/scheduler.py +92 -50
  79. sglang/srt/managers/session_controller.py +1 -1
  80. sglang/srt/managers/tokenizer_manager.py +99 -24
  81. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  82. sglang/srt/mem_cache/chunk_cache.py +3 -1
  83. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  84. sglang/srt/mem_cache/memory_pool.py +74 -52
  85. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  86. sglang/srt/mem_cache/radix_cache.py +58 -5
  87. sglang/srt/metrics/collector.py +2 -2
  88. sglang/srt/mm_utils.py +10 -0
  89. sglang/srt/model_executor/cuda_graph_runner.py +20 -9
  90. sglang/srt/model_executor/expert_location_updater.py +422 -0
  91. sglang/srt/model_executor/forward_batch_info.py +4 -0
  92. sglang/srt/model_executor/model_runner.py +144 -54
  93. sglang/srt/model_loader/loader.py +10 -6
  94. sglang/srt/models/clip.py +5 -1
  95. sglang/srt/models/deepseek_v2.py +297 -343
  96. sglang/srt/models/exaone.py +8 -3
  97. sglang/srt/models/gemma3_mm.py +70 -33
  98. sglang/srt/models/llama4.py +10 -2
  99. sglang/srt/models/llava.py +26 -18
  100. sglang/srt/models/mimo_mtp.py +220 -0
  101. sglang/srt/models/minicpmo.py +5 -12
  102. sglang/srt/models/mistral.py +71 -1
  103. sglang/srt/models/mllama.py +3 -3
  104. sglang/srt/models/qwen2.py +95 -26
  105. sglang/srt/models/qwen2_5_vl.py +8 -0
  106. sglang/srt/models/qwen2_moe.py +330 -60
  107. sglang/srt/models/qwen2_vl.py +6 -0
  108. sglang/srt/models/qwen3.py +52 -10
  109. sglang/srt/models/qwen3_moe.py +411 -48
  110. sglang/srt/models/siglip.py +294 -0
  111. sglang/srt/openai_api/adapter.py +28 -16
  112. sglang/srt/openai_api/protocol.py +6 -0
  113. sglang/srt/operations.py +154 -0
  114. sglang/srt/operations_strategy.py +31 -0
  115. sglang/srt/server_args.py +134 -24
  116. sglang/srt/speculative/eagle_utils.py +131 -0
  117. sglang/srt/speculative/eagle_worker.py +47 -2
  118. sglang/srt/utils.py +68 -12
  119. sglang/test/test_cutlass_moe.py +278 -0
  120. sglang/test/test_utils.py +2 -36
  121. sglang/utils.py +2 -2
  122. sglang/version.py +1 -1
  123. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
  124. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
  125. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  126. sglang/srt/function_call_parser.py +0 -858
  127. sglang/srt/platforms/interface.py +0 -371
  128. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  129. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  130. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
@@ -836,7 +836,6 @@ class MllamaForConditionalGeneration(nn.Module):
836
836
  prefix="multi_modal_projector",
837
837
  )
838
838
  self.logits_processor = LogitsProcessor(config.text_config)
839
- self.capture_mode = False
840
839
 
841
840
  def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
842
841
  pixel_values = torch.cat(
@@ -865,7 +864,6 @@ class MllamaForConditionalGeneration(nn.Module):
865
864
  pixel_values = torch.cat(
866
865
  [item.pixel_values for item in mm_input.mm_items], dim=0
867
866
  )
868
- # max_num_images = max(max_num_images, sum(1 if item.is_image() else 0 for item in mm_input.items))
869
867
  max_num_images = max(max_num_images, pixel_values.shape[1])
870
868
 
871
869
  max_num_tiles = max(max_num_tiles, pixel_values.shape[2])
@@ -970,6 +968,8 @@ class MllamaForConditionalGeneration(nn.Module):
970
968
  positions: torch.Tensor,
971
969
  forward_batch: ForwardBatch,
972
970
  ) -> Union[Tuple, CausalLMOutputWithPast]:
971
+ from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
972
+
973
973
  batched_images, batched_ar_ids, batched_ar_mask, encoder_lens_need = (
974
974
  self._batch_image_inputs(forward_batch)
975
975
  )
@@ -978,7 +978,7 @@ class MllamaForConditionalGeneration(nn.Module):
978
978
  cross_attention_mask = None
979
979
  cross_attention_states = None
980
980
 
981
- if self.capture_mode:
981
+ if get_is_capture_mode():
982
982
  # NOTE: when doing cuda graph capture, we do not want to skip cross attention
983
983
  # Make is a constant value to avoid cuda graph capture issue
984
984
  skip_cross_attention = False
@@ -15,12 +15,14 @@
15
15
  # Adapted from llama2.py
16
16
  # Modify details for the adaptation of Qwen2 model.
17
17
  """Inference-only Qwen2 model compatible with HuggingFace weights."""
18
- from typing import Any, Dict, Iterable, Optional, Tuple
18
+ import logging
19
+ from typing import Any, Dict, Iterable, Optional, Tuple, Union
19
20
 
20
21
  import torch
21
22
  from torch import nn
22
23
 
23
24
  from sglang.srt.distributed import (
25
+ get_pp_group,
24
26
  get_tensor_model_parallel_rank,
25
27
  get_tensor_model_parallel_world_size,
26
28
  )
@@ -36,11 +38,12 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
36
38
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
37
39
  from sglang.srt.layers.radix_attention import RadixAttention
38
40
  from sglang.srt.layers.rotary_embedding import get_rope
41
+ from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
39
42
  from sglang.srt.layers.vocab_parallel_embedding import (
40
43
  ParallelLMHead,
41
44
  VocabParallelEmbedding,
42
45
  )
43
- from sglang.srt.model_executor.forward_batch_info import ForwardBatch
46
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
44
47
  from sglang.srt.model_loader.weight_utils import (
45
48
  default_weight_loader,
46
49
  kv_cache_scales_loader,
@@ -50,6 +53,9 @@ from sglang.srt.utils import add_prefix, make_layers
50
53
  Qwen2Config = None
51
54
 
52
55
 
56
+ logger = logging.getLogger(__name__)
57
+
58
+
53
59
  class Qwen2MLP(nn.Module):
54
60
  def __init__(
55
61
  self,
@@ -245,15 +251,21 @@ class Qwen2Model(nn.Module):
245
251
  self.config = config
246
252
  self.padding_idx = config.pad_token_id
247
253
  self.vocab_size = config.vocab_size
248
- self.embed_tokens = VocabParallelEmbedding(
249
- config.vocab_size,
250
- config.hidden_size,
251
- quant_config=quant_config,
252
- prefix=add_prefix("embed_tokens", prefix),
253
- )
254
+ self.pp_group = get_pp_group()
255
+
256
+ if self.pp_group.is_first_rank:
257
+ self.embed_tokens = VocabParallelEmbedding(
258
+ config.vocab_size,
259
+ config.hidden_size,
260
+ quant_config=quant_config,
261
+ prefix=add_prefix("embed_tokens", prefix),
262
+ )
263
+ else:
264
+ self.embed_tokens = PPMissingLayer()
265
+
254
266
  # Use the provided decoder layer type or default to Qwen2DecoderLayer
255
267
  decoder_layer_type = decoder_layer_type or Qwen2DecoderLayer
256
- self.layers = make_layers(
268
+ self.layers, self.start_layer, self.end_layer = make_layers(
257
269
  config.num_hidden_layers,
258
270
  lambda idx, prefix: decoder_layer_type(
259
271
  layer_id=idx,
@@ -261,9 +273,14 @@ class Qwen2Model(nn.Module):
261
273
  quant_config=quant_config,
262
274
  prefix=prefix,
263
275
  ),
276
+ pp_rank=self.pp_group.rank_in_group,
277
+ pp_size=self.pp_group.world_size,
264
278
  prefix=add_prefix("layers", prefix),
265
279
  )
266
- self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
280
+ if self.pp_group.is_last_rank:
281
+ self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
282
+ else:
283
+ self.norm = PPMissingLayer(return_tuple=True)
267
284
 
268
285
  def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
269
286
  if hasattr(self.config, "scale_emb"):
@@ -280,13 +297,20 @@ class Qwen2Model(nn.Module):
280
297
  positions: torch.Tensor,
281
298
  forward_batch: ForwardBatch,
282
299
  input_embeds: torch.Tensor = None,
283
- ) -> torch.Tensor:
284
- if input_embeds is None:
285
- hidden_states = self.embed_tokens(input_ids)
300
+ pp_proxy_tensors: Optional[PPProxyTensors] = None,
301
+ ) -> Union[torch.Tensor, PPProxyTensors]:
302
+ if self.pp_group.is_first_rank:
303
+ if input_embeds is None:
304
+ hidden_states = self.embed_tokens(input_ids)
305
+ else:
306
+ hidden_states = input_embeds
307
+ residual = None
286
308
  else:
287
- hidden_states = input_embeds
288
- residual = None
289
- for i in range(len(self.layers)):
309
+ assert pp_proxy_tensors is not None
310
+ hidden_states = pp_proxy_tensors["hidden_states"]
311
+ residual = pp_proxy_tensors["residual"]
312
+
313
+ for i in range(self.start_layer, self.end_layer):
290
314
  layer = self.layers[i]
291
315
  hidden_states, residual = layer(
292
316
  positions,
@@ -294,7 +318,15 @@ class Qwen2Model(nn.Module):
294
318
  forward_batch,
295
319
  residual,
296
320
  )
297
- hidden_states, _ = self.norm(hidden_states, residual)
321
+ if not self.pp_group.is_last_rank:
322
+ return PPProxyTensors(
323
+ {
324
+ "hidden_states": hidden_states,
325
+ "residual": residual,
326
+ }
327
+ )
328
+ else:
329
+ hidden_states, _ = self.norm(hidden_states, residual)
298
330
  return hidden_states
299
331
 
300
332
  # If this function is called, it should always initialize KV cache scale
@@ -348,6 +380,7 @@ class Qwen2ForCausalLM(nn.Module):
348
380
  prefix: str = "",
349
381
  ) -> None:
350
382
  super().__init__()
383
+ self.pp_group = get_pp_group()
351
384
  self.config = config
352
385
  self.quant_config = quant_config
353
386
  self.model = Qwen2Model(
@@ -379,14 +412,33 @@ class Qwen2ForCausalLM(nn.Module):
379
412
  forward_batch: ForwardBatch,
380
413
  input_embeds: torch.Tensor = None,
381
414
  get_embedding: bool = False,
415
+ pp_proxy_tensors: Optional[PPProxyTensors] = None,
382
416
  ) -> torch.Tensor:
383
- hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
384
- if not get_embedding:
385
- return self.logits_processor(
386
- input_ids, hidden_states, self.lm_head, forward_batch
387
- )
417
+ hidden_states = self.model(
418
+ input_ids,
419
+ positions,
420
+ forward_batch,
421
+ input_embeds,
422
+ pp_proxy_tensors=pp_proxy_tensors,
423
+ )
424
+
425
+ if self.pp_group.is_last_rank:
426
+ if not get_embedding:
427
+ return self.logits_processor(
428
+ input_ids, hidden_states, self.lm_head, forward_batch
429
+ )
430
+ else:
431
+ return self.pooler(hidden_states, forward_batch)
388
432
  else:
389
- return self.pooler(hidden_states, forward_batch)
433
+ return hidden_states
434
+
435
+ @property
436
+ def start_layer(self):
437
+ return self.model.start_layer
438
+
439
+ @property
440
+ def end_layer(self):
441
+ return self.model.end_layer
390
442
 
391
443
  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
392
444
  stacked_params_mapping = [
@@ -400,6 +452,17 @@ class Qwen2ForCausalLM(nn.Module):
400
452
 
401
453
  params_dict = dict(self.named_parameters())
402
454
  for name, loaded_weight in weights:
455
+ layer_id = get_layer_id(name)
456
+ if (
457
+ layer_id is not None
458
+ and hasattr(self.model, "start_layer")
459
+ and (
460
+ layer_id < self.model.start_layer
461
+ or layer_id >= self.model.end_layer
462
+ )
463
+ ):
464
+ continue
465
+
403
466
  if "rotary_emb.inv_freq" in name or "projector" in name:
404
467
  continue
405
468
  if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
@@ -426,9 +489,15 @@ class Qwen2ForCausalLM(nn.Module):
426
489
  # Skip loading extra bias for GPTQ models.
427
490
  if name.endswith(".bias") and name not in params_dict:
428
491
  continue
429
- param = params_dict[name]
430
- weight_loader = getattr(param, "weight_loader", default_weight_loader)
431
- weight_loader(param, loaded_weight)
492
+
493
+ if name in params_dict.keys():
494
+ param = params_dict[name]
495
+ weight_loader = getattr(
496
+ param, "weight_loader", default_weight_loader
497
+ )
498
+ weight_loader(param, loaded_weight)
499
+ else:
500
+ logger.warning(f"Parameter {name} not found in params_dict")
432
501
 
433
502
  def get_embed_and_head(self):
434
503
  return self.model.embed_tokens.weight, self.lm_head.weight
@@ -146,6 +146,8 @@ class Qwen2_5_VisionBlock(nn.Module):
146
146
  num_heads=num_heads,
147
147
  projection_size=dim,
148
148
  use_qkv_parallel=True,
149
+ rotary_embed="normal",
150
+ proj_bias=True,
149
151
  qkv_backend=qkv_backend,
150
152
  softmax_in_single_precision=softmax_in_single_precision,
151
153
  flatten_batch=flatten_batch,
@@ -497,6 +499,12 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
497
499
  return pattern.pad_input_tokens(input_ids, mm_inputs)
498
500
 
499
501
  def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
502
+ if any(item.precomputed_features is not None for item in items):
503
+ if not all(item.precomputed_features is not None for item in items):
504
+ raise NotImplementedError(
505
+ "MM inputs where only some items are precomputed."
506
+ )
507
+ return torch.concat([item.precomputed_features for item in items])
500
508
  # in qwen-vl, last dim is the same
501
509
  pixel_values = torch.cat([item.pixel_values for item in items], dim=0).type(
502
510
  self.visual.dtype