sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185) hide show
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +26 -4
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +676 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +49 -8
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  34. sglang/srt/distributed/parallel_state.py +42 -8
  35. sglang/srt/entrypoints/engine.py +55 -5
  36. sglang/srt/entrypoints/http_server.py +78 -13
  37. sglang/srt/entrypoints/verl_engine.py +2 -0
  38. sglang/srt/function_call_parser.py +133 -55
  39. sglang/srt/hf_transformers_utils.py +28 -3
  40. sglang/srt/layers/activation.py +4 -2
  41. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  42. sglang/srt/layers/attention/flashattention_backend.py +434 -0
  43. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  44. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  45. sglang/srt/layers/attention/triton_backend.py +171 -38
  46. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  47. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  48. sglang/srt/layers/attention/utils.py +53 -0
  49. sglang/srt/layers/attention/vision.py +9 -28
  50. sglang/srt/layers/dp_attention.py +41 -19
  51. sglang/srt/layers/layernorm.py +24 -2
  52. sglang/srt/layers/linear.py +17 -5
  53. sglang/srt/layers/logits_processor.py +25 -7
  54. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  55. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  56. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  57. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  64. sglang/srt/layers/moe/topk.py +60 -20
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +80 -53
  67. sglang/srt/layers/quantization/awq.py +200 -0
  68. sglang/srt/layers/quantization/base_config.py +5 -0
  69. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  70. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  72. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  75. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  76. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  77. sglang/srt/layers/quantization/fp8.py +76 -34
  78. sglang/srt/layers/quantization/fp8_kernel.py +25 -8
  79. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  80. sglang/srt/layers/quantization/gptq.py +36 -19
  81. sglang/srt/layers/quantization/kv_cache.py +98 -0
  82. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  83. sglang/srt/layers/quantization/utils.py +153 -0
  84. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  85. sglang/srt/layers/rotary_embedding.py +78 -87
  86. sglang/srt/layers/sampler.py +1 -1
  87. sglang/srt/lora/backend/base_backend.py +4 -4
  88. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  89. sglang/srt/lora/backend/triton_backend.py +5 -8
  90. sglang/srt/lora/layers.py +87 -33
  91. sglang/srt/lora/lora.py +2 -22
  92. sglang/srt/lora/lora_manager.py +67 -30
  93. sglang/srt/lora/mem_pool.py +117 -52
  94. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  95. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  96. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  97. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  98. sglang/srt/lora/utils.py +18 -1
  99. sglang/srt/managers/cache_controller.py +2 -5
  100. sglang/srt/managers/data_parallel_controller.py +30 -8
  101. sglang/srt/managers/expert_distribution.py +81 -0
  102. sglang/srt/managers/io_struct.py +43 -5
  103. sglang/srt/managers/mm_utils.py +373 -0
  104. sglang/srt/managers/multimodal_processor.py +68 -0
  105. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  106. sglang/srt/managers/multimodal_processors/clip.py +63 -0
  107. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  108. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  109. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  110. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  111. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  112. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  113. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  114. sglang/srt/managers/schedule_batch.py +134 -30
  115. sglang/srt/managers/scheduler.py +290 -31
  116. sglang/srt/managers/session_controller.py +1 -1
  117. sglang/srt/managers/tokenizer_manager.py +59 -24
  118. sglang/srt/managers/tp_worker.py +4 -1
  119. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  120. sglang/srt/managers/utils.py +6 -1
  121. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  122. sglang/srt/mem_cache/memory_pool.py +255 -98
  123. sglang/srt/mem_cache/paged_allocator.py +2 -2
  124. sglang/srt/mem_cache/radix_cache.py +4 -4
  125. sglang/srt/model_executor/cuda_graph_runner.py +36 -21
  126. sglang/srt/model_executor/forward_batch_info.py +68 -11
  127. sglang/srt/model_executor/model_runner.py +75 -8
  128. sglang/srt/model_loader/loader.py +171 -3
  129. sglang/srt/model_loader/weight_utils.py +51 -3
  130. sglang/srt/models/clip.py +563 -0
  131. sglang/srt/models/deepseek_janus_pro.py +31 -88
  132. sglang/srt/models/deepseek_nextn.py +22 -10
  133. sglang/srt/models/deepseek_v2.py +329 -73
  134. sglang/srt/models/deepseek_vl2.py +358 -0
  135. sglang/srt/models/gemma3_causal.py +694 -0
  136. sglang/srt/models/gemma3_mm.py +468 -0
  137. sglang/srt/models/llama.py +47 -7
  138. sglang/srt/models/llama_eagle.py +1 -0
  139. sglang/srt/models/llama_eagle3.py +196 -0
  140. sglang/srt/models/llava.py +3 -3
  141. sglang/srt/models/llavavid.py +3 -3
  142. sglang/srt/models/minicpmo.py +1995 -0
  143. sglang/srt/models/minicpmv.py +62 -137
  144. sglang/srt/models/mllama.py +4 -4
  145. sglang/srt/models/phi3_small.py +1 -1
  146. sglang/srt/models/qwen2.py +3 -0
  147. sglang/srt/models/qwen2_5_vl.py +68 -146
  148. sglang/srt/models/qwen2_classification.py +75 -0
  149. sglang/srt/models/qwen2_moe.py +9 -1
  150. sglang/srt/models/qwen2_vl.py +25 -63
  151. sglang/srt/openai_api/adapter.py +201 -104
  152. sglang/srt/openai_api/protocol.py +33 -7
  153. sglang/srt/patch_torch.py +71 -0
  154. sglang/srt/sampling/sampling_batch_info.py +1 -1
  155. sglang/srt/sampling/sampling_params.py +6 -6
  156. sglang/srt/server_args.py +114 -14
  157. sglang/srt/speculative/build_eagle_tree.py +7 -347
  158. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  159. sglang/srt/speculative/eagle_utils.py +208 -252
  160. sglang/srt/speculative/eagle_worker.py +140 -54
  161. sglang/srt/speculative/spec_info.py +6 -1
  162. sglang/srt/torch_memory_saver_adapter.py +22 -0
  163. sglang/srt/utils.py +215 -21
  164. sglang/test/__init__.py +0 -0
  165. sglang/test/attention/__init__.py +0 -0
  166. sglang/test/attention/test_flashattn_backend.py +312 -0
  167. sglang/test/runners.py +29 -2
  168. sglang/test/test_activation.py +2 -1
  169. sglang/test/test_block_fp8.py +5 -4
  170. sglang/test/test_block_fp8_ep.py +2 -1
  171. sglang/test/test_dynamic_grad_mode.py +58 -0
  172. sglang/test/test_layernorm.py +3 -2
  173. sglang/test/test_utils.py +56 -5
  174. sglang/utils.py +31 -0
  175. sglang/version.py +1 -1
  176. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
  177. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
  178. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
  179. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  180. sglang/srt/managers/image_processor.py +0 -55
  181. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  182. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  183. sglang/srt/managers/multi_modality_padding.py +0 -134
  184. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
  185. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
@@ -47,10 +47,11 @@ from sglang.srt.configs.janus_pro import *
47
47
  from sglang.srt.layers.attention.vision import VisionAttention
48
48
  from sglang.srt.layers.logits_processor import LogitsProcessor
49
49
  from sglang.srt.layers.quantization import QuantizationConfig
50
- from sglang.srt.managers.multi_modality_padding import (
50
+ from sglang.srt.managers.mm_utils import (
51
51
  MultiModalityDataPaddingPatternTokenPairs,
52
+ general_mm_embed_routine,
52
53
  )
53
- from sglang.srt.managers.schedule_batch import ImageInputs
54
+ from sglang.srt.managers.schedule_batch import MultimodalInputs, global_server_args_dict
54
55
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
55
56
  from sglang.srt.model_loader.weight_utils import default_weight_loader
56
57
  from sglang.srt.models.llama import LlamaForCausalLM
@@ -251,7 +252,7 @@ def resample_patch_embed(
251
252
  try:
252
253
  from torch import vmap
253
254
  except ImportError:
254
- from functorch import vmap
255
+ from torch.func import vmap
255
256
 
256
257
  assert len(patch_embed.shape) == 4, "Four dimensions expected"
257
258
  assert len(new_size) == 2, "New shape should only be hw"
@@ -1083,7 +1084,7 @@ def create_siglip_vit(
1083
1084
  )
1084
1085
 
1085
1086
  if ckpt_path:
1086
- state_dict = torch.load(ckpt_path, map_location="cpu")
1087
+ state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
1087
1088
 
1088
1089
  incompatible_keys = model.load_state_dict(state_dict, strict=False)
1089
1090
  print(
@@ -1289,7 +1290,7 @@ class MlpProjector(nn.Module):
1289
1290
  high_x, low_x = x_or_tuple
1290
1291
  high_x = self.high_up_proj(high_x)
1291
1292
  low_x = self.low_up_proj(low_x)
1292
- x = torch.concat([high_x, low_x], dim=-1)
1293
+ x = torch.cat([high_x, low_x], dim=-1)
1293
1294
  else:
1294
1295
  x = x_or_tuple
1295
1296
 
@@ -1958,17 +1959,24 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
1958
1959
  )
1959
1960
  self.logits_processor = LogitsProcessor(config)
1960
1961
 
1961
- def prepare_images_seq_mask(
1962
- self, input_ids: torch.Tensor, image_inputs: ImageInputs
1963
- ) -> Optional[torch.LongTensor]:
1964
- images_seq_mask = torch.isin(
1965
- input_ids, torch.tensor(image_inputs.pad_values, device=input_ids.device)
1962
+ def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
1963
+ pixel_values = image_input.pixel_values
1964
+ bs, n = pixel_values.shape[0:2]
1965
+ pixel_values = pixel_values.to(
1966
+ device=self.vision_model.device, dtype=self.vision_model.dtype
1966
1967
  )
1967
- if images_seq_mask.sum() == 0:
1968
- # sometimes image_inputs is not empty, but input_ids contain no image token because of prefix-cache
1969
- return None
1970
- else:
1971
- return images_seq_mask
1968
+ images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
1969
+
1970
+ # [b x n, T2, D]
1971
+ images_embeds = self.aligner(self.vision_model(images))
1972
+
1973
+ # [b x n, T2, D] -> [b, n x T2, D]
1974
+ images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
1975
+
1976
+ return images_embeds
1977
+
1978
+ def get_input_embeddings(self) -> nn.Embedding:
1979
+ return self.language_model.model.embed_tokens
1972
1980
 
1973
1981
  @torch.no_grad()
1974
1982
  def forward(
@@ -1978,90 +1986,25 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
1978
1986
  forward_batch: ForwardBatch,
1979
1987
  ) -> torch.Tensor:
1980
1988
 
1981
- inputs_embeds = None
1982
- if (
1983
- forward_batch.image_inputs is not None
1984
- and len(forward_batch.image_inputs) != 0
1985
- and forward_batch.image_inputs[0] is not None
1986
- ):
1987
-
1988
- image_inputs = forward_batch.image_inputs[0]
1989
-
1990
- images_seq_mask = self.prepare_images_seq_mask(
1991
- input_ids=input_ids, image_inputs=image_inputs
1992
- )
1993
-
1994
- if images_seq_mask is not None:
1995
- input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
1996
- inputs_embeds = self.prepare_inputs_embeds(
1997
- input_ids=input_ids,
1998
- pixel_values=image_inputs.pixel_values,
1999
- images_seq_mask=images_seq_mask,
2000
- images_emb_mask=image_inputs.images_emb_mask,
2001
- )
2002
- input_ids = None
2003
-
2004
- if input_ids is not None:
2005
- input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
1989
+ inputs_embeds = general_mm_embed_routine(
1990
+ input_ids=input_ids,
1991
+ forward_batch=forward_batch,
1992
+ embed_tokens=self.get_input_embeddings(),
1993
+ mm_data_embedding_func=self.get_image_feature,
1994
+ )
2006
1995
 
2007
1996
  return self.language_model(
2008
- input_ids=input_ids,
1997
+ input_ids=None,
2009
1998
  positions=positions,
2010
1999
  forward_batch=forward_batch,
2011
2000
  input_embeds=inputs_embeds,
2012
2001
  get_embedding=False,
2013
2002
  )
2014
2003
 
2015
- def prepare_inputs_embeds(
2016
- self,
2017
- input_ids: torch.LongTensor,
2018
- pixel_values: torch.FloatTensor,
2019
- images_seq_mask: torch.LongTensor,
2020
- images_emb_mask: torch.BoolTensor,
2021
- **_kwargs,
2022
- ):
2023
- """
2024
-
2025
- Args:
2026
- input_ids (torch.LongTensor): [b, T]
2027
- pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
2028
- images_seq_mask (torch.BoolTensor): [b, T]
2029
- images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
2030
-
2031
- assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
2032
-
2033
- Returns:
2034
- input_embeds (torch.Tensor): [b, T, D]
2035
- """
2036
-
2037
- bs, n = pixel_values.shape[0:2]
2038
- pixel_values = pixel_values.to(
2039
- device=self.vision_model.device, dtype=self.vision_model.dtype
2040
- )
2041
- images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
2042
-
2043
- # [b x n, T2, D]
2044
- images_embeds = self.aligner(self.vision_model(images))
2045
-
2046
- # [b x n, T2, D] -> [b, n x T2, D]
2047
- images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
2048
- # [b, n, T2] -> [b, n x T2]
2049
- images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
2050
-
2051
- # [b, T, D]
2052
- # ignore the image embeddings
2053
- input_ids[input_ids < 0] = 0
2054
- inputs_embeds = self.language_model.model.embed_tokens(input_ids)
2055
-
2056
- # replace with the image embeddings
2057
- inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
2058
-
2059
- return inputs_embeds
2060
-
2061
2004
  def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
2062
2005
  return self.gen_aligner(self.gen_embed(image_ids))
2063
2006
 
2064
- def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
2007
+ def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
2065
2008
  im_start_id = image_inputs.im_start_id
2066
2009
  im_end_id = image_inputs.im_end_id
2067
2010
  media_token_pairs = [(im_start_id, im_end_id)]
@@ -18,7 +18,6 @@ from typing import Iterable, Optional, Tuple
18
18
  import torch
19
19
  from torch import nn
20
20
  from transformers import PretrainedConfig
21
- from vllm import _custom_ops as ops
22
21
 
23
22
  from sglang.srt.layers.layernorm import RMSNorm
24
23
  from sglang.srt.layers.linear import ReplicatedLinear
@@ -41,9 +40,15 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
41
40
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
42
41
  from sglang.srt.model_loader.weight_utils import default_weight_loader
43
42
  from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
44
- from sglang.srt.utils import add_prefix, is_hip
43
+ from sglang.srt.utils import add_prefix, is_cuda, is_hip
45
44
 
46
45
  _is_hip = is_hip()
46
+ _is_cuda = is_cuda()
47
+
48
+ if _is_cuda:
49
+ from sgl_kernel import awq_dequantize
50
+ else:
51
+ from vllm import _custom_ops as ops
47
52
 
48
53
 
49
54
  class DeepseekModelNextN(nn.Module):
@@ -261,14 +266,21 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
261
266
  self_attn = self.model.decoder.self_attn
262
267
  if hasattr(self_attn.kv_b_proj, "qweight"):
263
268
  # AWQ compatible
264
- w = ops.awq_dequantize(
265
- self_attn.kv_b_proj.qweight,
266
- self_attn.kv_b_proj.scales,
267
- self_attn.kv_b_proj.qzeros,
268
- 0,
269
- 0,
270
- 0,
271
- ).T
269
+ if _is_cuda:
270
+ w = awq_dequantize(
271
+ self_attn.kv_b_proj.qweight,
272
+ self_attn.kv_b_proj.scales,
273
+ self_attn.kv_b_proj.qzeros,
274
+ ).T
275
+ else:
276
+ w = ops.awq_dequantize(
277
+ self_attn.kv_b_proj.qweight,
278
+ self_attn.kv_b_proj.scales,
279
+ self_attn.kv_b_proj.qzeros,
280
+ 0,
281
+ 0,
282
+ 0,
283
+ ).T
272
284
  else:
273
285
  w = self_attn.kv_b_proj.weight
274
286
  # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.