sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Files changed (172)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +133 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +32 -21
  49. sglang/srt/layers/layernorm.py +24 -2
  50. sglang/srt/layers/linear.py +17 -5
  51. sglang/srt/layers/logits_processor.py +25 -7
  52. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  53. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  54. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  55. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  62. sglang/srt/layers/moe/topk.py +31 -18
  63. sglang/srt/layers/parameter.py +1 -1
  64. sglang/srt/layers/quantization/__init__.py +184 -126
  65. sglang/srt/layers/quantization/base_config.py +5 -0
  66. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  67. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  68. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  69. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  70. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  71. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  73. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  74. sglang/srt/layers/quantization/fp8.py +76 -34
  75. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  76. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  77. sglang/srt/layers/quantization/gptq.py +36 -9
  78. sglang/srt/layers/quantization/kv_cache.py +98 -0
  79. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  80. sglang/srt/layers/quantization/utils.py +153 -0
  81. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  82. sglang/srt/layers/rotary_embedding.py +66 -87
  83. sglang/srt/layers/sampler.py +1 -1
  84. sglang/srt/lora/layers.py +68 -0
  85. sglang/srt/lora/lora.py +2 -22
  86. sglang/srt/lora/lora_manager.py +47 -23
  87. sglang/srt/lora/mem_pool.py +110 -51
  88. sglang/srt/lora/utils.py +12 -1
  89. sglang/srt/managers/cache_controller.py +2 -5
  90. sglang/srt/managers/data_parallel_controller.py +30 -8
  91. sglang/srt/managers/expert_distribution.py +81 -0
  92. sglang/srt/managers/io_struct.py +39 -3
  93. sglang/srt/managers/mm_utils.py +373 -0
  94. sglang/srt/managers/multimodal_processor.py +68 -0
  95. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  96. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  97. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  98. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  99. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  100. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  101. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  102. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  103. sglang/srt/managers/schedule_batch.py +133 -30
  104. sglang/srt/managers/scheduler.py +273 -20
  105. sglang/srt/managers/session_controller.py +1 -1
  106. sglang/srt/managers/tokenizer_manager.py +59 -23
  107. sglang/srt/managers/tp_worker.py +1 -1
  108. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  109. sglang/srt/managers/utils.py +6 -1
  110. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  111. sglang/srt/mem_cache/memory_pool.py +255 -98
  112. sglang/srt/mem_cache/paged_allocator.py +2 -2
  113. sglang/srt/mem_cache/radix_cache.py +4 -4
  114. sglang/srt/model_executor/cuda_graph_runner.py +27 -13
  115. sglang/srt/model_executor/forward_batch_info.py +68 -11
  116. sglang/srt/model_executor/model_runner.py +70 -6
  117. sglang/srt/model_loader/loader.py +160 -2
  118. sglang/srt/model_loader/weight_utils.py +45 -0
  119. sglang/srt/models/deepseek_janus_pro.py +29 -86
  120. sglang/srt/models/deepseek_nextn.py +22 -10
  121. sglang/srt/models/deepseek_v2.py +208 -77
  122. sglang/srt/models/deepseek_vl2.py +358 -0
  123. sglang/srt/models/gemma3_causal.py +684 -0
  124. sglang/srt/models/gemma3_mm.py +462 -0
  125. sglang/srt/models/llama.py +47 -7
  126. sglang/srt/models/llama_eagle.py +1 -0
  127. sglang/srt/models/llama_eagle3.py +196 -0
  128. sglang/srt/models/llava.py +3 -3
  129. sglang/srt/models/llavavid.py +3 -3
  130. sglang/srt/models/minicpmo.py +1995 -0
  131. sglang/srt/models/minicpmv.py +62 -137
  132. sglang/srt/models/mllama.py +4 -4
  133. sglang/srt/models/phi3_small.py +1 -1
  134. sglang/srt/models/qwen2.py +3 -0
  135. sglang/srt/models/qwen2_5_vl.py +68 -146
  136. sglang/srt/models/qwen2_classification.py +75 -0
  137. sglang/srt/models/qwen2_moe.py +9 -1
  138. sglang/srt/models/qwen2_vl.py +25 -63
  139. sglang/srt/openai_api/adapter.py +124 -28
  140. sglang/srt/openai_api/protocol.py +23 -2
  141. sglang/srt/sampling/sampling_batch_info.py +1 -1
  142. sglang/srt/sampling/sampling_params.py +6 -6
  143. sglang/srt/server_args.py +99 -9
  144. sglang/srt/speculative/build_eagle_tree.py +7 -347
  145. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  146. sglang/srt/speculative/eagle_utils.py +208 -252
  147. sglang/srt/speculative/eagle_worker.py +139 -53
  148. sglang/srt/speculative/spec_info.py +6 -1
  149. sglang/srt/torch_memory_saver_adapter.py +22 -0
  150. sglang/srt/utils.py +182 -21
  151. sglang/test/__init__.py +0 -0
  152. sglang/test/attention/__init__.py +0 -0
  153. sglang/test/attention/test_flashattn_backend.py +312 -0
  154. sglang/test/runners.py +2 -0
  155. sglang/test/test_activation.py +2 -1
  156. sglang/test/test_block_fp8.py +5 -4
  157. sglang/test/test_block_fp8_ep.py +2 -1
  158. sglang/test/test_dynamic_grad_mode.py +58 -0
  159. sglang/test/test_layernorm.py +3 -2
  160. sglang/test/test_utils.py +55 -4
  161. sglang/utils.py +31 -0
  162. sglang/version.py +1 -1
  163. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  164. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +167 -123
  165. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  166. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  167. sglang/srt/managers/image_processor.py +0 -55
  168. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  169. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  170. sglang/srt/managers/multi_modality_padding.py +0 -134
  171. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  172. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/llama_eagle3.py (new file)
@@ -0,0 +1,196 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+from sglang.srt.utils import add_prefix
+
+# Adapted from
+# https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py
+"""Inference-only LLaMA-EAGLE model compatible with HuggingFace weights."""
+
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import LlamaConfig
+
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.llama import LlamaAttention, LlamaDecoderLayer, LlamaForCausalLM
+
+
+class LlamaDecoderLayer(LlamaDecoderLayer):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(config, layer_id, quant_config, prefix)
+
+        # override qkv
+        self.self_attn.qkv_proj = QKVParallelLinear(
+            2 * self.hidden_size,
+            self.self_attn.head_dim,
+            self.self_attn.total_num_heads,
+            self.self_attn.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
+        )
+
+        self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        embeds: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+
+        residual = hidden_states
+        embeds = self.input_layernorm(embeds)
+        hidden_states = self.hidden_norm(hidden_states)
+
+        hidden_states = torch.cat([embeds, hidden_states], dim=-1)
+        # Self Attention
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+
+        # Fully Connected
+        hidden_states = self.mlp(hidden_states)
+
+        return hidden_states, residual
+
+
+class LlamaModel(nn.Module):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.vocab_size = config.vocab_size
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            prefix=add_prefix("embed_tokens", prefix),
+        )
+        self.midlayer = LlamaDecoderLayer(config, 0, quant_config, prefix)
+        if hasattr(config, "target_hidden_size"):
+            self.fc = torch.nn.Linear(config.target_hidden_size * 3, config.hidden_size)
+        else:
+            self.fc = torch.nn.Linear(config.hidden_size * 3, config.hidden_size)
+
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            embeds = self.embed_tokens(input_ids)
+        else:
+            embeds = input_embeds
+
+        hidden_states = forward_batch.spec_info.hidden_states
+        if hidden_states.shape[-1] != embeds.shape[-1]:
+            hidden_states = self.fc(hidden_states)
+
+        residual = None
+        hidden_states, residual = self.midlayer(
+            positions,
+            embeds,
+            hidden_states,
+            forward_batch,
+            residual,
+        )
+
+        hidden_states_to_logits, hidden_states_to_aux = self.norm(
+            hidden_states, residual
+        )
+
+        # For draft decode, we capture the hidden state before norm
+        return hidden_states_to_logits, [hidden_states_to_aux]
+
+
+class LlamaForCausalLMEagle3(LlamaForCausalLM):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        nn.Module.__init__(self)
+        self.config = config
+        self.quant_config = quant_config
+
+        if self.config.num_hidden_layers != 1:
+            raise ValueError("EAGLE3 currently only supports 1 layer")
+
+        self.model = LlamaModel(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        # Llama 3.2 1B Instruct set tie_word_embeddings to True
+        # Llama 3.1 8B Instruct set tie_word_embeddings to False
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.draft_vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
+            )
+
+        self.logits_processor = LogitsProcessor(config)
+        self.capture_aux_hidden_states = True
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        for name, loaded_weight in weights:
+            if "d2t" in name:
+                # d2t stores diffs between draft id and target id
+                self.hot_token_id = loaded_weight + torch.arange(loaded_weight.shape[0])
+
+            if "d2t" not in name and "t2d" not in name and "lm_head" not in name:
+                new_name = f"model.{name}"
+                super().load_weights([(new_name, loaded_weight)])
+            elif "lm_head" in name:
+                super().load_weights([(name, loaded_weight)])
+
+    def get_hot_token_id(self):
+        return self.hot_token_id
+
+
+EntryClass = [LlamaForCausalLMEagle3]
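
The hunk above is the new EAGLE3 draft model, sglang/srt/models/llama_eagle3.py (entry 127 in the file list). Two details are easy to miss when reading the diff: the overridden qkv_proj takes 2 * self.hidden_size inputs because the draft layer attends over the token embedding concatenated with the target model's hidden state, and load_weights treats d2t as per-row offsets, so the draft-to-target vocab mapping is recovered by adding torch.arange. A minimal standalone sketch of these ideas follows (toy shapes and values, illustrative only, not sglang code):

```python
import torch

# d2t stores the difference between draft token id and target token id,
# so the mapping is recovered by adding the row index (see load_weights above).
d2t = torch.tensor([0, 5, 3, -1])                 # toy offsets
hot_token_id = d2t + torch.arange(d2t.shape[0])   # -> tensor([0, 6, 5, 2])

# The draft layer's attention input is the normed token embedding
# concatenated with the normed target hidden state, doubling the width.
seq, hidden = 4, 8
embeds = torch.randn(seq, hidden)           # after input_layernorm
target_hidden = torch.randn(seq, hidden)    # after hidden_norm
attn_input = torch.cat([embeds, target_hidden], dim=-1)
assert attn_input.shape == (seq, 2 * hidden)

# When the target model supplies three stacked auxiliary hidden states,
# the fc layer (hidden_size * 3 -> hidden_size) folds them back down
# before the single decoder layer runs.
aux = torch.randn(seq, hidden * 3)
fc = torch.nn.Linear(hidden * 3, hidden)
assert fc(aux).shape == (seq, hidden)
```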
sglang/srt/models/llava.py
@@ -31,7 +31,7 @@ from transformers import (
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.schedule_batch import ImageInputs
+from sglang.srt.managers.schedule_batch import MultimodalInputs
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,
@@ -46,7 +46,7 @@ from sglang.srt.utils import add_prefix
 
 
 class LlavaBaseForCausalLM(nn.Module):
-    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
         image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values
 
         # hardcode for spatial_unpad + anyres
@@ -134,7 +134,7 @@ class LlavaBaseForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        image_inputs = forward_batch.image_inputs
+        image_inputs = forward_batch.mm_inputs
 
         if forward_batch.forward_mode.is_extend():
             # Clamp input ids. This is because the input_ids for the image tokens are

sglang/srt/models/llavavid.py
@@ -22,7 +22,7 @@ from transformers import CLIPVisionModel, LlavaConfig
 from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.managers.schedule_batch import ImageInputs
+from sglang.srt.managers.schedule_batch import MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
@@ -57,7 +57,7 @@ class LlavaVidForCausalLM(nn.Module):
             torch.empty(config.text_config.hidden_size, dtype=torch.float16)
         )
 
-    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
         pad_values = image_inputs.pad_values
         new_image_feature_len = self.image_feature_len
 
@@ -112,7 +112,7 @@ class LlavaVidForCausalLM(nn.Module):
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        image_inputs = forward_batch.image_inputs
+        image_inputs = forward_batch.mm_inputs
         if forward_batch.forward_mode.is_extend():
             bs = forward_batch.batch_size
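
The hunks above, from sglang/srt/models/llava.py and sglang/srt/models/llavavid.py, are the mechanical side of the multimodal refactor visible in the file list (image_processors renamed to multimodal_processors): ImageInputs becomes MultimodalInputs and forward_batch.image_inputs becomes forward_batch.mm_inputs. For orientation, here is a hedged sketch of the pad_input_ids pattern those hunks touch; the helper below is illustrative only, not the sglang implementation:

```python
from typing import List

def pad_image_tokens(
    input_ids: List[int],
    pad_values: List[int],    # one pad value per image (cf. MultimodalInputs.pad_values)
    image_feature_len: int,   # slots reserved per image (cf. self.image_feature_len)
    image_token_id: int,      # placeholder token id in the prompt
) -> List[int]:
    # Expand each image placeholder into image_feature_len slots filled with
    # that image's pad value, so a later stage can locate the region and
    # overwrite it with projected vision features.
    out: List[int] = []
    image_idx = 0
    for tok in input_ids:
        if tok == image_token_id and image_idx < len(pad_values):
            out.extend([pad_values[image_idx]] * image_feature_len)
            image_idx += 1
        else:
            out.append(tok)
    return out
```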