sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +1 -11
- sglang/bench_serving.py +149 -1
- sglang/lang/chat_template.py +44 -0
- sglang/srt/configs/deepseekvl2.py +3 -0
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +17 -0
- sglang/srt/constrained/xgrammar_backend.py +11 -19
- sglang/srt/conversation.py +30 -3
- sglang/srt/disaggregation/decode.py +4 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +9 -18
- sglang/srt/disaggregation/nixl/conn.py +241 -71
- sglang/srt/disaggregation/utils.py +44 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +14 -2
- sglang/srt/entrypoints/http_server.py +28 -1
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/hf_transformers_utils.py +20 -1
- sglang/srt/layers/attention/flashattention_backend.py +146 -50
- sglang/srt/layers/attention/flashinfer_backend.py +23 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
- sglang/srt/layers/moe/ep_moe/layer.py +120 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +4 -1
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +5 -0
- sglang/srt/layers/quantization/fp8.py +108 -95
- sglang/srt/layers/quantization/fp8_kernel.py +79 -60
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/lora/lora_manager.py +10 -13
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/io_struct.py +10 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/schedule_batch.py +19 -1
- sglang/srt/managers/schedule_policy.py +11 -5
- sglang/srt/managers/scheduler.py +28 -13
- sglang/srt/managers/tokenizer_manager.py +24 -13
- sglang/srt/managers/tp_worker.py +9 -12
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +2 -2
- sglang/srt/model_executor/model_runner.py +44 -33
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_nextn.py +1 -20
- sglang/srt/models/deepseek_v2.py +55 -20
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/llama.py +1 -1
- sglang/srt/models/llama4.py +53 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +24 -40
- sglang/srt/openai_api/protocol.py +28 -16
- sglang/srt/reasoning_parser.py +2 -2
- sglang/srt/sampling/sampling_batch_info.py +54 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +30 -6
- sglang/srt/utils.py +35 -1
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_utils.py +3 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +14 -6
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +90 -80
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/llama4.py
CHANGED
@@ -46,7 +46,11 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
 from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers
 
@@ -81,6 +85,7 @@ class Llama4MoE(nn.Module):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.top_k = config.num_experts_per_tok
+        self.device_module = torch.get_device_module()
 
         intermediate_size_moe = config.intermediate_size
         self.router = ReplicatedLinear(
@@ -113,7 +118,25 @@ class Llama4MoE(nn.Module):
             reduce_results=False,  # We need to do scatter before reduce
         )
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states, forward_batch: ForwardBatch):
+        shared_out, routed_out = self._forward_core(
+            hidden_states, forward_batch.forward_mode
+        )
+
+        out_aD = routed_out + shared_out
+
+        if self.tp_size > 1:
+            out_aD = tensor_model_parallel_all_reduce(out_aD)
+
+        return out_aD
+
+    def _forward_core(self, hidden_states, forward_mode: ForwardMode):
+        if hidden_states.shape[0] < 4:
+            return self._forward_core_shared_routed_overlap(hidden_states)
+        else:
+            return self._forward_core_normal(hidden_states)
+
+    def _forward_core_normal(self, hidden_states):
         # router_scores: [num_tokens, num_experts]
         router_logits, _ = self.router(hidden_states)
         shared_out = self.shared_expert(hidden_states)
@@ -121,12 +144,35 @@ class Llama4MoE(nn.Module):
             hidden_states=hidden_states,
             router_logits=router_logits,
         )
-        out_aD = routed_out + shared_out
+        return shared_out, routed_out
 
-        if self.tp_size > 1:
-            out_aD = tensor_model_parallel_all_reduce(out_aD)
+    def _forward_core_shared_routed_overlap(self, hidden_states):
+        alt_stream = _get_or_create_alt_stream(self.device_module)
 
-        return out_aD
+        alt_stream.wait_stream(self.device_module.current_stream())
+
+        shared_out = self.shared_expert(hidden_states)
+
+        with self.device_module.stream(alt_stream):
+            # router_scores: [num_tokens, num_experts]
+            router_logits, _ = self.router(hidden_states)
+            routed_out = self.experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+            )
+        self.device_module.current_stream().wait_stream(alt_stream)
+
+        return shared_out, routed_out
+
+
+_alt_stream = None
+
+
+def _get_or_create_alt_stream(device_module):
+    global _alt_stream
+    if _alt_stream is None:
+        _alt_stream = device_module.Stream()
+    return _alt_stream
 
 
 class Llama4Attention(nn.Module):
@@ -380,7 +426,7 @@ class Llama4DecoderLayer(nn.Module):
         )
 
         # Fully Connected
-        hidden_states = self.feed_forward(hidden_states)
+        hidden_states = self.feed_forward(hidden_states, forward_batch)
 
         # TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
         # Scatter
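Note on the change above: for very small batches (fewer than 4 tokens), Llama4MoE now runs the shared expert on the default stream while the router and routed experts run on a lazily created alternate device stream, so the two branches overlap. The snippet below is a minimal standalone sketch of that overlap pattern under the assumption of a CUDA device; the module and function names are illustrative and not sglang's.

import torch
import torch.nn as nn

_alt_stream = None  # lazily-created secondary stream, mirroring the diff above

def _get_or_create_alt_stream() -> torch.cuda.Stream:
    global _alt_stream
    if _alt_stream is None:
        _alt_stream = torch.cuda.Stream()
    return _alt_stream

def overlapped_moe_forward(shared: nn.Module, routed: nn.Module, x: torch.Tensor) -> torch.Tensor:
    alt = _get_or_create_alt_stream()
    # The alternate stream must see all work that produced `x` before reading it.
    alt.wait_stream(torch.cuda.current_stream())
    shared_out = shared(x)  # runs on the default stream
    with torch.cuda.stream(alt):
        routed_out = routed(x)  # runs concurrently on the alternate stream
    # The default stream must wait for the routed branch before consuming its output.
    torch.cuda.current_stream().wait_stream(alt)
    return shared_out + routed_out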
sglang/srt/models/minicpmv.py
CHANGED
@@ -197,7 +197,7 @@ class Idefics2EncoderLayer(nn.Module):
             use_qkv_parallel=True,
             quant_config=quant_config,
             dropout=config.attention_dropout,
-            use_context_forward=False,
+            qkv_backend="sdpa",
             softmax_in_single_precision=True,
             flatten_batch=False,
             prefix=add_prefix("self_attn", prefix),
sglang/srt/models/mllama.py
CHANGED
@@ -203,7 +203,7 @@ class MllamaVisionEncoderLayer(nn.Module):
             use_qkv_parallel=True,
             quant_config=quant_config,
             dropout=0.0,
-            use_context_forward=False,
+            qkv_backend="sdpa",
             softmax_in_single_precision=False,
             flatten_batch=False,
             prefix=add_prefix("self_attn", prefix),
sglang/srt/models/phi3_small.py
CHANGED
@@ -6,7 +6,7 @@ from torch import nn
 from transformers import Phi3Config
 from transformers.configuration_utils import PretrainedConfig
 
-from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -17,6 +17,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer
 from sglang.srt.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
@@ -294,13 +295,24 @@ class Phi3SmallModel(nn.Module):
         super().__init__()
 
         self.config = config
+
+        self.pp_group = get_pp_group()
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
             prefix=add_prefix("embed_tokens", prefix),
         )
         self.mup_embedding_multiplier = config.mup_embedding_multiplier
-        self.
+        self.layers, self.start_layer, self.end_layer = make_layers(
             config.num_hidden_layers,
             lambda idx, prefix: Phi3SmallDecoderLayer(
                 config,
@@ -308,6 +320,8 @@ class Phi3SmallModel(nn.Module):
                 quant_config,
                 prefix=prefix,
             ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
             prefix=add_prefix("layers", prefix),
         )
 
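The phi3_small change wires the model for pipeline parallelism: the token embedding is materialized only on the first PP rank and replaced by a placeholder elsewhere, and make_layers is told which slice of decoder layers this rank owns. A rough sketch of the first-rank-only embedding idea follows; PPMissingLayerStub and build_embedding are stand-in names, not sglang's API.

import torch.nn as nn

class PPMissingLayerStub(nn.Module):
    """Placeholder for a layer that lives on a different pipeline rank."""
    def forward(self, *args, **kwargs):
        raise RuntimeError("this pipeline rank does not own the embedding")

def build_embedding(is_first_rank: bool, vocab_size: int, hidden_size: int) -> nn.Module:
    if is_first_rank:
        # Only the first stage maps token ids to hidden states and holds the weights.
        return nn.Embedding(vocab_size, hidden_size)
    # Later stages receive hidden states from the previous rank, never token ids.
    return PPMissingLayerStub()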
sglang/srt/models/qwen2_5_vl.py
CHANGED
@@ -125,16 +125,20 @@ class Qwen2_5_VisionBlock(nn.Module):
         self.norm1 = Qwen2RMSNorm(dim, eps=1e-6)
         self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
         if attn_implementation == "sdpa":
-            use_context_forward = False
             softmax_in_single_precision = False
+            qkv_backend = "sdpa"
             flatten_batch = True
         elif attn_implementation == "flash_attention_2":
             softmax_in_single_precision = False
-            use_context_forward = True
+            qkv_backend = "triton_attn"
             flatten_batch = True
         elif attn_implementation == "eager":
             softmax_in_single_precision = True
-            use_context_forward = False
+            qkv_backend = "sdpa"
+            flatten_batch = True
+        elif attn_implementation == "flash_attention_3":
+            softmax_in_single_precision = False
+            qkv_backend = "fa3"
             flatten_batch = True
 
         self.attn = VisionAttention(
@@ -142,7 +146,7 @@ class Qwen2_5_VisionBlock(nn.Module):
             num_heads=num_heads,
             projection_size=dim,
             use_qkv_parallel=True,
-            use_context_forward=use_context_forward,
+            qkv_backend=qkv_backend,
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=flatten_batch,
             quant_config=quant_config,
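In both Qwen2-VL vision blocks the boolean use_context_forward flag is replaced by an explicit qkv_backend string, and qwen2_5_vl gains a flash_attention_3 branch. As a rough summary drawn from the diff (an assumption, not a table published by the library), the mapping is:

# attn_implementation -> VisionAttention arguments after this change
# (flatten_batch stays True for every branch in Qwen2_5_VisionBlock)
ATTN_IMPL_TO_VISION_ARGS = {
    "sdpa": dict(qkv_backend="sdpa", softmax_in_single_precision=False),
    "flash_attention_2": dict(qkv_backend="triton_attn", softmax_in_single_precision=False),
    "eager": dict(qkv_backend="sdpa", softmax_in_single_precision=True),
    "flash_attention_3": dict(qkv_backend="fa3", softmax_in_single_precision=False),
}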
sglang/srt/models/qwen2_vl.py
CHANGED
@@ -139,21 +139,21 @@ class Qwen2VisionBlock(nn.Module):
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         if attn_implementation == "sdpa":
-            use_context_forward = False
+            qkv_backend = "sdpa"
             softmax_in_single_precision = False
         elif attn_implementation == "flash_attention_2":
+            qkv_backend = "triton_attn"
             softmax_in_single_precision = False
-            use_context_forward = True
         elif attn_implementation == "eager":
+            qkv_backend = "sdpa"
             softmax_in_single_precision = True
-            use_context_forward = False
 
         self.attn = VisionAttention(
             embed_dim=dim,
             num_heads=num_heads,
             projection_size=dim,
             use_qkv_parallel=True,
-            use_context_forward=use_context_forward,
+            qkv_backend=qkv_backend,
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=True,
             quant_config=quant_config,
sglang/srt/models/xiaomi_mimo.py
ADDED
@@ -0,0 +1,171 @@
+# Adapted from qwen2.py
+
+from functools import partial
+from typing import Any, Dict, Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    split_tensor_along_last_dim,
+    tensor_model_parallel_all_gather,
+)
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.qwen2 import Qwen2DecoderLayer, Qwen2MLP, Qwen2Model
+from sglang.srt.utils import add_prefix
+
+MiMoConfig = None
+
+
+class MiMoModel(Qwen2Model):
+    def __init__(
+        self,
+        config: MiMoConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(
+            config=config,
+            quant_config=quant_config,
+            prefix=prefix,
+            decoder_layer_type=Qwen2DecoderLayer,
+        )
+
+
+class MiMoForCausalLM(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
+    def __init__(
+        self,
+        config: MiMoConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = MiMoModel(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
+            )
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        if not get_embedding:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return self.pooler(hidden_states, forward_batch)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if (
+                "rotary_emb.inv_freq" in name
+                or "projector" in name
+                or "mtp_layers" in name
+            ):
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
+
+
+EntryClass = MiMoForCausalLM
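The new MiMo model reuses Qwen2Model/Qwen2DecoderLayer and the usual stacked-parameter weight loading: per-projection checkpoint tensors (q_proj/k_proj/v_proj, gate_proj/up_proj) are remapped onto the fused qkv_proj and gate_up_proj parameters, with a shard id telling the loader which slice to fill. A toy, standalone illustration of that remapping (not sglang code):

stacked_params_mapping = [
    # (fused_param_name, checkpoint_shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

def remap(name: str):
    for fused, shard_name, shard_id in stacked_params_mapping:
        if shard_name in name:
            return name.replace(shard_name, fused), shard_id
    return name, None  # weight loads 1:1, no fused destination

print(remap("model.layers.0.self_attn.k_proj.weight"))
# ('model.layers.0.self_attn.qkv_proj.weight', 'k')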
sglang/srt/openai_api/adapter.py
CHANGED
@@ -14,6 +14,7 @@
 """Conversion between OpenAI APIs and native SRT APIs"""
 
 import asyncio
+import base64
 import json
 import logging
 import os
@@ -528,6 +529,7 @@ def v1_generate_request(
         "temperature": request.temperature,
         "max_new_tokens": request.max_tokens,
         "min_new_tokens": request.min_tokens,
+        "thinking_budget": request.thinking_budget,
         "stop": request.stop,
         "stop_token_ids": request.stop_token_ids,
         "top_p": request.top_p,
@@ -966,47 +968,23 @@ def v1_chat_generate_request(
 
     if chat_template_name is None:
         openai_compatible_messages = []
-        if (
-            tools
-            and tokenizer_manager.server_args.tool_call_parser == "deepseekv3"
-        ):
-            # add function call prompt to deepseekv3
-            openai_compatible_messages.append(
-                {
-                    "role": "system",
-                    "content": """You are a helpful Assistant.
-## Tools
-### Function
-You have the following functions available:
-"""
-                    + "".join(
-                        [
-                            f"""
-- `{tool['name']}`:
-```json
-{json.dumps(tool)}
-```
-"""
-                            for tool in tools
-                        ]
-                    ),
-                }
-            )
 
         for message in request.messages:
             if message.content is None:
                 message.content = ""
-            if isinstance(message.content, str):
-                openai_compatible_messages.append(
-                    {"role": message.role, "content": message.content}
-                )
+            msg_dict = message.dict()
+            if isinstance(msg_dict.get("content"), list):
+                for chunk in msg_dict["content"]:
+                    if isinstance(chunk, dict) and chunk.get("type") == "text":
+                        new_msg = msg_dict.copy()
+                        new_msg["content"] = chunk["text"]
+                        new_msg = {
+                            k: v for k, v in new_msg.items() if v is not None
+                        }
+                        openai_compatible_messages.append(new_msg)
             else:
-                content_list = message.dict()["content"]
-                for content in content_list:
-                    if content["type"] == "text":
-                        openai_compatible_messages.append(
-                            {"role": message.role, "content": content["text"]}
-                        )
+                msg_dict = {k: v for k, v in msg_dict.items() if v is not None}
+                openai_compatible_messages.append(msg_dict)
         if (
             openai_compatible_messages
             and openai_compatible_messages[-1]["role"] == "assistant"
@@ -1124,6 +1102,7 @@ def v1_chat_generate_request(
         "temperature": request.temperature,
         "max_new_tokens": request.max_tokens or request.max_completion_tokens,
         "min_new_tokens": request.min_tokens,
+        "thinking_budget": request.thinking_budget,
         "stop": stop,
         "stop_token_ids": request.stop_token_ids,
         "top_p": request.top_p,
@@ -1316,7 +1295,8 @@ def v1_chat_generate_response(
             text, call_info_list = parser.parse_non_stream(text)
             tool_calls = [
                 ToolCall(
-                    id=
+                    id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}",
+                    index=call_info.tool_index,
                     function=FunctionResponse(
                         name=call_info.name, arguments=call_info.parameters
                     ),
@@ -1432,6 +1412,7 @@ async def v1_chat_completions(
     reasoning_parser_dict = {}
 
     async def generate_stream_resp():
+        tool_call_first = True
        is_firsts = {}
        stream_buffers = {}
        n_prev_tokens = {}
@@ -1598,7 +1579,6 @@ async def v1_chat_completions(
                    # 2) if we found calls, we output them as separate chunk(s)
                    for call_item in calls:
                        # transform call_item -> FunctionResponse + ToolCall
-
                        if finish_reason_type == "stop":
                            latest_delta_len = 0
                            if isinstance(call_item.parameters, str):
@@ -1621,15 +1601,19 @@ async def v1_chat_completions(
                            call_item.parameters = remaining_call

                            finish_reason_type = "tool_calls"
-
                        tool_call = ToolCall(
-                            id=
+                            id=(
+                                f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}"
+                                if tool_call_first
+                                else None
+                            ),
                            index=call_item.tool_index,
                            function=FunctionResponse(
                                name=call_item.name,
                                arguments=call_item.parameters,
                            ),
                        )
+                        tool_call_first = False
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=index,
                            delta=DeltaMessage(tool_calls=[tool_call]),
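The adapter change above drops the hard-coded DeepSeek-V3 function-calling system prompt, forwards the new thinking_budget parameter, and gives tool calls OpenAI-style identifiers: a "call_"-prefixed, URL-safe base64 encoding of a random UUID, emitted only on the first streamed chunk of each call. The id format can be reproduced standalone:

import base64
import uuid

def new_tool_call_id() -> str:
    # 16 random bytes -> 22 URL-safe base64 characters after stripping '=' padding
    return f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}"

print(new_tool_call_id())  # e.g. call_q1hXhVfkQkS3nq0bUvYdGg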
sglang/srt/openai_api/protocol.py
CHANGED
@@ -172,6 +172,7 @@ class CompletionRequest(BaseModel):
     top_k: int = -1
     min_p: float = 0.0
     min_tokens: int = 0
+    thinking_budget: Optional[int] = None
     json_schema: Optional[str] = None
     regex: Optional[str] = None
     ebnf: Optional[str] = None
@@ -250,9 +251,29 @@ ChatCompletionMessageContentPart = Union[
 ]
 
 
+class FunctionResponse(BaseModel):
+    """Function response."""
+
+    name: Optional[str] = None
+    arguments: Optional[str] = None
+
+
+class ToolCall(BaseModel):
+    """Tool call response."""
+
+    id: Optional[str] = None
+    index: Optional[int] = None
+    type: Literal["function"] = "function"
+    function: FunctionResponse
+
+
 class ChatCompletionMessageGenericParam(BaseModel):
     role: Literal["system", "assistant", "tool"]
     content: Union[str, List[ChatCompletionMessageContentTextPart], None]
+    tool_call_id: Optional[str] = None
+    name: Optional[str] = None
+    reasoning_content: Optional[str] = None
+    tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
 
 
 class ChatCompletionMessageUserParam(BaseModel):
@@ -330,6 +351,13 @@ class ChatCompletionRequest(BaseModel):
         description="The maximum number of completion tokens for a chat completion request, "
         "including visible output tokens and reasoning tokens. Input tokens are not included. ",
     )
+    thinking_budget: Optional[int] = Field(
+        default=None,
+        description="The maximum number of reasoning tokens that can be generated for a request. "
+        "This setting of does not affect the thinking process of models. "
+        "If the number of tokens generated by the model's thinking process exceeds thinking_budget, "
+        "the reasoning content will be truncated and the final response content will be generated immediately.",
+    )
     n: int = 1
     presence_penalty: float = 0.0
     response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
@@ -378,22 +406,6 @@ class ChatCompletionRequest(BaseModel):
     bootstrap_room: Optional[int] = None
 
 
-class FunctionResponse(BaseModel):
-    """Function response."""
-
-    name: Optional[str] = None
-    arguments: Optional[str] = None
-
-
-class ToolCall(BaseModel):
-    """Tool call response."""
-
-    id: str
-    index: Optional[int] = None
-    type: Literal["function"] = "function"
-    function: FunctionResponse
-
-
 class ChatMessage(BaseModel):
     role: Optional[str] = None
     content: Optional[str] = None
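With thinking_budget now a first-class field on both CompletionRequest and ChatCompletionRequest (and FunctionResponse/ToolCall moved ahead of the message models so assistant and tool messages can carry tool_calls and tool_call_id), a chat request can cap reasoning tokens directly. A hedged usage example follows; the endpoint, port, and model name are placeholders for a locally running sglang server.

import requests

payload = {
    "model": "deepseek-r1",  # placeholder model name
    "messages": [{"role": "user", "content": "Prove that sqrt(2) is irrational."}],
    "max_tokens": 1024,
    "thinking_budget": 256,  # truncate reasoning after ~256 tokens and force the final answer
}
resp = requests.post("http://localhost:30000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])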
sglang/srt/reasoning_parser.py
CHANGED
@@ -32,7 +32,7 @@ class BaseReasoningFormatDetector:
         One-time parsing: Detects and parses reasoning sections in the provided text.
         Returns both reasoning content and normal text separately.
         """
-        text = text.replace(self.think_start_token, "")
+        text = text.replace(self.think_start_token, "")
         if self.think_end_token not in text:
             # Assume reasoning was truncated before `</think>` token
             return StreamingParseResult(reasoning_text=text)
@@ -73,7 +73,7 @@ class BaseReasoningFormatDetector:
         normal_text = current_text[end_idx + len(self.think_end_token) :]
 
         return StreamingParseResult(
-            normal_text=normal_text, reasoning_text=reasoning_text
+            normal_text=normal_text, reasoning_text=reasoning_text
         )
 
         # Continue with reasoning content
|