sglang 0.3.1.post3__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +23 -1
- sglang/bench_latency.py +48 -33
- sglang/bench_server_latency.py +0 -6
- sglang/bench_serving.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +14 -1
- sglang/lang/interpreter.py +16 -6
- sglang/lang/ir.py +20 -4
- sglang/srt/configs/model_config.py +11 -9
- sglang/srt/constrained/fsm_cache.py +9 -1
- sglang/srt/constrained/jump_forward.py +15 -2
- sglang/srt/hf_transformers_utils.py +1 -0
- sglang/srt/layers/activation.py +4 -4
- sglang/srt/layers/attention/__init__.py +49 -0
- sglang/srt/layers/attention/flashinfer_backend.py +277 -0
- sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
- sglang/srt/layers/attention/triton_backend.py +161 -0
- sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
- sglang/srt/layers/fused_moe/patch.py +117 -0
- sglang/srt/layers/layernorm.py +4 -4
- sglang/srt/layers/logits_processor.py +19 -15
- sglang/srt/layers/pooler.py +3 -3
- sglang/srt/layers/quantization/__init__.py +0 -2
- sglang/srt/layers/radix_attention.py +6 -4
- sglang/srt/layers/sampler.py +6 -4
- sglang/srt/layers/torchao_utils.py +18 -0
- sglang/srt/lora/lora.py +20 -21
- sglang/srt/lora/lora_manager.py +97 -25
- sglang/srt/managers/detokenizer_manager.py +31 -18
- sglang/srt/managers/image_processor.py +187 -0
- sglang/srt/managers/io_struct.py +99 -75
- sglang/srt/managers/schedule_batch.py +187 -68
- sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
- sglang/srt/managers/scheduler.py +1021 -0
- sglang/srt/managers/tokenizer_manager.py +120 -247
- sglang/srt/managers/tp_worker.py +28 -925
- sglang/srt/mem_cache/memory_pool.py +34 -52
- sglang/srt/mem_cache/radix_cache.py +5 -5
- sglang/srt/model_executor/cuda_graph_runner.py +25 -25
- sglang/srt/model_executor/forward_batch_info.py +94 -97
- sglang/srt/model_executor/model_runner.py +76 -78
- sglang/srt/models/baichuan.py +10 -10
- sglang/srt/models/chatglm.py +12 -12
- sglang/srt/models/commandr.py +10 -10
- sglang/srt/models/dbrx.py +12 -12
- sglang/srt/models/deepseek.py +10 -10
- sglang/srt/models/deepseek_v2.py +14 -15
- sglang/srt/models/exaone.py +10 -10
- sglang/srt/models/gemma.py +10 -10
- sglang/srt/models/gemma2.py +11 -11
- sglang/srt/models/gpt_bigcode.py +10 -10
- sglang/srt/models/grok.py +10 -10
- sglang/srt/models/internlm2.py +10 -10
- sglang/srt/models/llama.py +22 -10
- sglang/srt/models/llama_classification.py +5 -5
- sglang/srt/models/llama_embedding.py +4 -4
- sglang/srt/models/llama_reward.py +142 -0
- sglang/srt/models/llava.py +39 -33
- sglang/srt/models/llavavid.py +31 -28
- sglang/srt/models/minicpm.py +10 -10
- sglang/srt/models/minicpm3.py +14 -15
- sglang/srt/models/mixtral.py +10 -10
- sglang/srt/models/mixtral_quant.py +10 -10
- sglang/srt/models/olmoe.py +10 -10
- sglang/srt/models/qwen.py +10 -10
- sglang/srt/models/qwen2.py +11 -11
- sglang/srt/models/qwen2_moe.py +10 -10
- sglang/srt/models/stablelm.py +10 -10
- sglang/srt/models/torch_native_llama.py +506 -0
- sglang/srt/models/xverse.py +10 -10
- sglang/srt/models/xverse_moe.py +10 -10
- sglang/srt/openai_api/adapter.py +7 -0
- sglang/srt/sampling/sampling_batch_info.py +36 -27
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +170 -119
- sglang/srt/server_args.py +54 -27
- sglang/srt/utils.py +101 -128
- sglang/test/runners.py +76 -33
- sglang/test/test_programs.py +38 -5
- sglang/test/test_utils.py +53 -9
- sglang/version.py +1 -1
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/METADATA +42 -23
- sglang-0.3.3.dist-info/RECORD +139 -0
- sglang/srt/layers/attention_backend.py +0 -482
- sglang/srt/managers/controller_multi.py +0 -207
- sglang/srt/managers/controller_single.py +0 -164
- sglang-0.3.1.post3.dist-info/RECORD +0 -134
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
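
The per-file diffs below all apply the same mechanical change: each model's forward() now takes a ForwardBatch (imported from sglang.srt.model_executor.forward_batch_info) and threads it through every decoder layer, the attention call, the logits processor, and the pooler. As a rough illustration of that threading pattern only — the Toy* names below are invented for this sketch, they are not sglang classes, and the real ForwardBatch carries far more state (KV-cache indices, sequence lengths, forward mode, and so on) — here is a minimal, self-contained PyTorch example:

    # Hypothetical sketch of the "pass one batch object all the way down" pattern.
    from dataclasses import dataclass

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    @dataclass
    class ToyForwardBatch:
        # Stand-in for the per-step metadata the real ForwardBatch carries.
        seq_lens: torch.Tensor


    class ToyAttention(nn.Module):
        def forward(self, q, k, v, forward_batch: ToyForwardBatch):
            # A real backend would use forward_batch to index the KV cache;
            # here we just run plain scaled dot-product attention.
            return F.scaled_dot_product_attention(q, k, v)


    class ToyDecoderLayer(nn.Module):
        def __init__(self):
            super().__init__()
            self.attn = ToyAttention()

        def forward(self, hidden_states, forward_batch: ToyForwardBatch):
            # The layer does not inspect forward_batch; it only forwards it.
            return hidden_states + self.attn(
                hidden_states, hidden_states, hidden_states, forward_batch
            )


    class ToyModel(nn.Module):
        def __init__(self, num_layers: int = 2):
            super().__init__()
            self.layers = nn.ModuleList([ToyDecoderLayer() for _ in range(num_layers)])

        def forward(self, hidden_states, forward_batch: ToyForwardBatch):
            for layer in self.layers:
                hidden_states = layer(hidden_states, forward_batch)
            return hidden_states


    if __name__ == "__main__":
        batch = ToyForwardBatch(seq_lens=torch.tensor([4]))
        out = ToyModel()(torch.randn(1, 4, 8), forward_batch=batch)
        print(out.shape)  # torch.Size([1, 4, 8])
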
sglang/srt/models/minicpm3.py
CHANGED
@@ -42,11 +42,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import
-from sglang.srt.utils import
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import is_flashinfer_available
 
-
-if not is_hip():
+if is_flashinfer_available():
     from flashinfer import bmm_fp8
 
 
@@ -193,7 +192,7 @@ class MiniCPM3Attention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         if self.q_lora_rank is not None:
             q = self.q_a_proj(hidden_states)[0]
@@ -230,7 +229,7 @@ class MiniCPM3Attention(nn.Module):
         v = torch.nn.functional.pad(v, [0, 128 - self.v_head_dim], value=0).view(
             -1, self.num_local_heads * 128
         )
-        attn_output = self.attn(q, k, v,
+        attn_output = self.attn(q, k, v, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, 128)[
             ..., : self.v_head_dim
         ].reshape(-1, self.num_local_heads * self.v_head_dim)
@@ -341,7 +340,7 @@ class MiniCPM3AttentionMLA(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         q_len = hidden_states.shape[0]
         q_input = hidden_states.new_empty(
@@ -383,7 +382,7 @@ class MiniCPM3AttentionMLA(nn.Module):
         q_input[..., self.kv_lora_rank :] = q_pe
         k_input[..., self.kv_lora_rank :] = k_pe
 
-        attn_output = self.attn(q_input, k_input, v_input,
+        attn_output = self.attn(q_input, k_input, v_input, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
 
         if self.w_vc.dtype == torch.float8_e4m3fn:
@@ -472,7 +471,7 @@ class MiniCPM3DecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -481,7 +480,7 @@ class MiniCPM3DecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-
+            forward_batch=forward_batch,
         )
         hidden_states = residual + hidden_states * (
             self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)
@@ -528,7 +527,7 @@ class MiniCPM3Model(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -542,7 +541,7 @@ class MiniCPM3Model(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-
+                forward_batch,
                 residual,
             )
         hidden_states = self.norm(hidden_states)
@@ -581,19 +580,19 @@ class MiniCPM3ForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is not None:
             input_embeds = input_embeds * self.config.scale_emb
-        hidden_states = self.model(input_ids, positions,
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         hidden_states = hidden_states / self.scale_width
         if self.config.tie_word_embeddings:
             lm_head_weight = self.model.embed_tokens.weight
         else:
             lm_head_weight = self.lm_head.weight
         return self.logits_processor(
-            input_ids, hidden_states, lm_head_weight,
+            input_ids, hidden_states, lm_head_weight, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
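
The first minicpm3.py hunk above also swaps the guard around the flashinfer import: the old platform check (not is_hip()) is replaced by a capability check (is_flashinfer_available()). A minimal sketch of that optional-import pattern, using a hypothetical has_flashinfer() helper rather than sglang's real utility (whose implementation is not shown in this diff):

    # Hypothetical sketch of the guarded optional-dependency import.
    import importlib.util


    def has_flashinfer() -> bool:
        # Assume "can flashinfer be imported at all?" is a good-enough proxy here.
        return importlib.util.find_spec("flashinfer") is not None


    if has_flashinfer():
        from flashinfer import bmm_fp8  # only imported where the package exists
    else:
        bmm_fp8 = None  # callers must check before use
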
sglang/srt/models/mixtral.py
CHANGED
@@ -43,7 +43,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
 class MixtralMoE(nn.Module):
@@ -171,12 +171,12 @@ class MixtralAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v,
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
 
@@ -220,7 +220,7 @@ class MixtralDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # Self Attention
@@ -232,7 +232,7 @@ class MixtralDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-
+            forward_batch=forward_batch,
         )
 
         # Fully Connected
@@ -270,7 +270,7 @@ class MixtralModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -281,7 +281,7 @@ class MixtralModel(nn.Module):
         for i in range(len(self.layers)):
             layer = self.layers[i]
             hidden_states, residual = layer(
-                positions, hidden_states,
+                positions, hidden_states, forward_batch, residual
             )
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
@@ -307,12 +307,12 @@ class MixtralForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions,
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight,
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/mixtral_quant.py
CHANGED
@@ -45,7 +45,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
 class MixtralMLP(nn.Module):
@@ -216,12 +216,12 @@ class MixtralAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v,
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
 
@@ -256,7 +256,7 @@ class MixtralDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # Self Attention
@@ -268,7 +268,7 @@ class MixtralDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-
+            forward_batch=forward_batch,
         )
 
         # Fully Connected
@@ -303,7 +303,7 @@ class MixtralModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -314,7 +314,7 @@ class MixtralModel(nn.Module):
         for i in range(len(self.layers)):
             layer = self.layers[i]
             hidden_states, residual = layer(
-                positions, hidden_states,
+                positions, hidden_states, forward_batch, residual
             )
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
@@ -339,12 +339,12 @@ class QuantMixtralForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions,
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight,
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
        )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/olmoe.py
CHANGED
@@ -48,7 +48,7 @@ from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
 class OlmoeMoE(nn.Module):
@@ -175,13 +175,13 @@ class OlmoeAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous())
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v,
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
 
@@ -225,7 +225,7 @@ class OlmoeDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # Self Attention
@@ -238,7 +238,7 @@ class OlmoeDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-
+            forward_batch=forward_batch,
         )
 
         # Fully Connected
@@ -274,7 +274,7 @@ class OlmoeModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -285,7 +285,7 @@ class OlmoeModel(nn.Module):
         for i in range(len(self.layers)):
             layer = self.layers[i]
             hidden_states, residual = layer(
-                positions, hidden_states,
+                positions, hidden_states, forward_batch, residual
             )
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
@@ -314,12 +314,12 @@ class OlmoeForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions,
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight,
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/qwen.py
CHANGED
@@ -39,7 +39,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
 class QWenMLP(nn.Module):
@@ -133,12 +133,12 @@ class QWenAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.c_attn(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v,
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.c_proj(attn_output)
         return output
 
@@ -177,7 +177,7 @@ class QWenBlock(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         # Self Attention
         residual = hidden_states
@@ -185,7 +185,7 @@ class QWenBlock(nn.Module):
         hidden_states = self.attn(
             positions=positions,
             hidden_states=hidden_states,
-
+            forward_batch=forward_batch,
         )
         hidden_states = residual + hidden_states
 
@@ -224,7 +224,7 @@ class QWenModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.wte(input_ids)
         for i in range(len(self.h)):
@@ -232,7 +232,7 @@ class QWenModel(nn.Module):
             hidden_states = layer(
                 positions,
                 hidden_states,
-
+                forward_batch,
             )
         hidden_states = self.ln_f(hidden_states)
         return hidden_states
@@ -257,11 +257,11 @@ class QWenLMHeadModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ):
-        hidden_states = self.transformer(input_ids, positions,
+        hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight,
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/qwen2.py
CHANGED
@@ -40,7 +40,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 Qwen2Config = None
 
@@ -149,12 +149,12 @@ class Qwen2Attention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v,
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
 
@@ -196,7 +196,7 @@ class Qwen2DecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -208,7 +208,7 @@ class Qwen2DecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-
+            forward_batch=forward_batch,
         )
 
         # Fully Connected
@@ -243,7 +243,7 @@ class Qwen2Model(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -256,7 +256,7 @@ class Qwen2Model(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-
+                forward_batch,
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -283,17 +283,17 @@ class Qwen2ForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
         get_embedding: bool = False,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions,
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head.weight,
+                input_ids, hidden_states, self.lm_head.weight, forward_batch
             )
         else:
-            return self.pooler(hidden_states,
+            return self.pooler(hidden_states, forward_batch)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/qwen2_moe.py
CHANGED
@@ -49,7 +49,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
 class Qwen2MoeMLP(nn.Module):
@@ -221,12 +221,12 @@ class Qwen2MoeAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v,
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
 
@@ -281,7 +281,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # Self Attention
@@ -293,7 +293,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-
+            forward_batch=forward_batch,
         )
 
         # Fully Connected
@@ -331,7 +331,7 @@ class Qwen2MoeModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -342,7 +342,7 @@ class Qwen2MoeModel(nn.Module):
         for i in range(len(self.layers)):
             layer = self.layers[i]
             hidden_states, residual = layer(
-                positions, hidden_states,
+                positions, hidden_states, forward_batch, residual
             )
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
@@ -373,12 +373,12 @@ class Qwen2MoeForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions,
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight,
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/stablelm.py
CHANGED
@@ -40,7 +40,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
 class StablelmMLP(nn.Module):
@@ -145,12 +145,12 @@ class StablelmAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v,
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
 
@@ -173,7 +173,7 @@ class StablelmDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
         residual = hidden_states
@@ -181,7 +181,7 @@ class StablelmDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-
+            forward_batch=forward_batch,
         )
         hidden_states = residual + hidden_states
 
@@ -218,7 +218,7 @@ class StableLMEpochModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -230,7 +230,7 @@ class StableLMEpochModel(nn.Module):
            hidden_states, residual = layer(
                positions,
                hidden_states,
-
+               forward_batch,
            )
        hidden_states = self.norm(hidden_states)
        return hidden_states
@@ -255,12 +255,12 @@ class StableLmForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions,
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight,
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):