sglang 0.3.1.post3__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/__init__.py +2 -0
- sglang/api.py +23 -1
- sglang/bench_latency.py +48 -33
- sglang/bench_server_latency.py +0 -6
- sglang/bench_serving.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +14 -1
- sglang/lang/interpreter.py +16 -6
- sglang/lang/ir.py +20 -4
- sglang/srt/configs/model_config.py +11 -9
- sglang/srt/constrained/fsm_cache.py +9 -1
- sglang/srt/constrained/jump_forward.py +15 -2
- sglang/srt/hf_transformers_utils.py +1 -0
- sglang/srt/layers/activation.py +4 -4
- sglang/srt/layers/attention/__init__.py +49 -0
- sglang/srt/layers/attention/flashinfer_backend.py +277 -0
- sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
- sglang/srt/layers/attention/triton_backend.py +161 -0
- sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
- sglang/srt/layers/fused_moe/patch.py +117 -0
- sglang/srt/layers/layernorm.py +4 -4
- sglang/srt/layers/logits_processor.py +19 -15
- sglang/srt/layers/pooler.py +3 -3
- sglang/srt/layers/quantization/__init__.py +0 -2
- sglang/srt/layers/radix_attention.py +6 -4
- sglang/srt/layers/sampler.py +6 -4
- sglang/srt/layers/torchao_utils.py +18 -0
- sglang/srt/lora/lora.py +20 -21
- sglang/srt/lora/lora_manager.py +97 -25
- sglang/srt/managers/detokenizer_manager.py +31 -18
- sglang/srt/managers/image_processor.py +187 -0
- sglang/srt/managers/io_struct.py +99 -75
- sglang/srt/managers/schedule_batch.py +187 -68
- sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
- sglang/srt/managers/scheduler.py +1021 -0
- sglang/srt/managers/tokenizer_manager.py +120 -247
- sglang/srt/managers/tp_worker.py +28 -925
- sglang/srt/mem_cache/memory_pool.py +34 -52
- sglang/srt/mem_cache/radix_cache.py +5 -5
- sglang/srt/model_executor/cuda_graph_runner.py +25 -25
- sglang/srt/model_executor/forward_batch_info.py +94 -97
- sglang/srt/model_executor/model_runner.py +76 -78
- sglang/srt/models/baichuan.py +10 -10
- sglang/srt/models/chatglm.py +12 -12
- sglang/srt/models/commandr.py +10 -10
- sglang/srt/models/dbrx.py +12 -12
- sglang/srt/models/deepseek.py +10 -10
- sglang/srt/models/deepseek_v2.py +14 -15
- sglang/srt/models/exaone.py +10 -10
- sglang/srt/models/gemma.py +10 -10
- sglang/srt/models/gemma2.py +11 -11
- sglang/srt/models/gpt_bigcode.py +10 -10
- sglang/srt/models/grok.py +10 -10
- sglang/srt/models/internlm2.py +10 -10
- sglang/srt/models/llama.py +22 -10
- sglang/srt/models/llama_classification.py +5 -5
- sglang/srt/models/llama_embedding.py +4 -4
- sglang/srt/models/llama_reward.py +142 -0
- sglang/srt/models/llava.py +39 -33
- sglang/srt/models/llavavid.py +31 -28
- sglang/srt/models/minicpm.py +10 -10
- sglang/srt/models/minicpm3.py +14 -15
- sglang/srt/models/mixtral.py +10 -10
- sglang/srt/models/mixtral_quant.py +10 -10
- sglang/srt/models/olmoe.py +10 -10
- sglang/srt/models/qwen.py +10 -10
- sglang/srt/models/qwen2.py +11 -11
- sglang/srt/models/qwen2_moe.py +10 -10
- sglang/srt/models/stablelm.py +10 -10
- sglang/srt/models/torch_native_llama.py +506 -0
- sglang/srt/models/xverse.py +10 -10
- sglang/srt/models/xverse_moe.py +10 -10
- sglang/srt/openai_api/adapter.py +7 -0
- sglang/srt/sampling/sampling_batch_info.py +36 -27
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +170 -119
- sglang/srt/server_args.py +54 -27
- sglang/srt/utils.py +101 -128
- sglang/test/runners.py +76 -33
- sglang/test/test_programs.py +38 -5
- sglang/test/test_utils.py +53 -9
- sglang/version.py +1 -1
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/METADATA +42 -23
- sglang-0.3.3.dist-info/RECORD +139 -0
- sglang/srt/layers/attention_backend.py +0 -482
- sglang/srt/managers/controller_multi.py +0 -207
- sglang/srt/managers/controller_single.py +0 -164
- sglang-0.3.1.post3.dist-info/RECORD +0 -134
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
- /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
- {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
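The most visible structural change in this release is the new sglang/srt/layers/attention/ package: the monolithic attention_backend.py is deleted and replaced by per-backend modules (flashinfer_backend.py, triton_backend.py) behind a shared attention/__init__.py, while most of the old tp_worker.py logic moves into the new managers/scheduler.py. The file list above does not show the backend interface itself, so the sketch below is only a hypothetical illustration of how such a split is commonly organized; the class and method names are assumptions, not sglang's actual API.

# Hypothetical sketch of a pluggable attention-backend interface; names and
# signatures are illustrative assumptions, not sglang's real classes.
from abc import ABC, abstractmethod

import torch


class AttentionBackendSketch(ABC):
    """Shared surface that per-backend modules (e.g. flashinfer, triton) would implement."""

    @abstractmethod
    def prepare(self, forward_batch) -> None:
        """Precompute backend-specific metadata (page tables, wrappers) for one batch."""

    @abstractmethod
    def run(
        self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, layer, forward_batch
    ) -> torch.Tensor:
        """Run attention for one layer using the metadata prepared above."""

Hiding the backends behind one interface like this is what lets the per-model code in the diffs below pass a single batch object and stay agnostic about whether FlashInfer or the Triton kernels actually execute the attention.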
sglang/srt/models/deepseek.py
CHANGED
@@ -46,7 +46,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 class DeepseekMLP(nn.Module):
@@ -246,12 +246,12 @@ class DeepseekAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v,
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output

@@ -303,7 +303,7 @@ class DeepseekDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # Self Attention
@@ -315,7 +315,7 @@ class DeepseekDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-
+            forward_batch=forward_batch,
         )

         # Fully Connected
@@ -356,14 +356,14 @@ class DeepseekModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
             hidden_states, residual = layer(
-                positions, hidden_states,
+                positions, hidden_states, forward_batch, residual
             )
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
@@ -391,11 +391,11 @@ class DeepseekForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions,
+        hidden_states = self.model(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight,
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
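The deepseek.py hunks above are representative of the change applied to every model file in this release: the forward path now threads a single ForwardBatch argument (imported from sglang.srt.model_executor.forward_batch_info) from the top-level forward() through each decoder layer into RadixAttention and the logits processor. The toy module below is a minimal sketch of that calling convention, not the real DeepseekAttention class; the ForwardBatch and RadixAttention internals are assumed from the import paths shown in the diff.

# Minimal sketch of the post-refactor calling convention; an illustrative
# stand-in, not an actual sglang class.
import torch
from torch import nn


class ToyAttention(nn.Module):
    def __init__(self, hidden_size: int, attn_impl):
        super().__init__()
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)
        # In sglang this would be a RadixAttention module bound to a layer id.
        self.attn = attn_impl

    def forward(self, positions, hidden_states, forward_batch):
        qkv = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(3, dim=-1)
        # The ForwardBatch object replaces the previous per-batch metadata
        # argument; the attention implementation reads KV-cache layout and
        # backend state from it.
        attn_output = self.attn(q, k, v, forward_batch)
        return self.o_proj(attn_output)

Every model file below (deepseek_v2, exaone, gemma, gemma2, gpt_bigcode, grok, internlm2, and the others listed in the summary) receives the same mechanical signature change.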
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -46,11 +46,10 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import
-from sglang.srt.utils import
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import is_flashinfer_available

-
-if not is_hip():
+if is_flashinfer_available():
     from flashinfer import bmm_fp8


@@ -281,7 +280,7 @@ class DeepseekV2Attention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         if self.q_lora_rank is not None:
             q = self.q_a_proj(hidden_states)[0]
@@ -314,7 +313,7 @@ class DeepseekV2Attention(nn.Module):
             v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim], value=0).view(
                 -1, self.num_local_heads * 256
             )
-        attn_output = self.attn(q, k, v,
+        attn_output = self.attn(q, k, v, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, 256)[
             ..., : self.v_head_dim
         ].reshape(-1, self.num_local_heads * self.v_head_dim)
@@ -433,7 +432,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         q_len = hidden_states.shape[0]
         q_input = hidden_states.new_empty(
@@ -471,7 +470,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_input[..., self.kv_lora_rank :] = q_pe
         k_input[..., self.kv_lora_rank :] = k_pe

-        attn_output = self.attn(q_input, k_input, v_input,
+        attn_output = self.attn(q_input, k_input, v_input, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

         if self.w_vc.dtype == torch.float8_e4m3fn:
@@ -567,7 +566,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # Self Attention
@@ -579,7 +578,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-
+            forward_batch=forward_batch,
         )

         # Fully Connected
@@ -623,14 +622,14 @@ class DeepseekV2Model(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
             hidden_states, residual = layer(
-                positions, hidden_states,
+                positions, hidden_states, forward_batch, residual
             )
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
@@ -658,11 +657,11 @@ class DeepseekV2ForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions,
+        hidden_states = self.model(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight,
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
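Besides the ForwardBatch threading, deepseek_v2.py swaps a platform check for a capability check before importing from flashinfer: the guard becomes is_flashinfer_available() instead of the previous HIP test. As a rough illustration of that kind of guard, here is a sketch of an availability check for an optional CUDA-only dependency; the helper name and its body are assumptions for illustration, not the actual implementation in sglang.srt.utils.

# Illustrative availability check for an optional CUDA-only dependency; the
# real sglang.srt.utils.is_flashinfer_available may be implemented differently.
import importlib.util

import torch


def flashinfer_usable() -> bool:
    """True if flashinfer is importable and this is a CUDA (non-HIP) build of torch."""
    if importlib.util.find_spec("flashinfer") is None:
        return False
    return torch.cuda.is_available() and torch.version.hip is None


if flashinfer_usable():
    from flashinfer import bmm_fp8  # noqa: F401  # the fp8 batched matmul imported in the diff above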
sglang/srt/models/exaone.py
CHANGED
@@ -40,7 +40,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 class ExaoneGatedMLP(nn.Module):
@@ -162,12 +162,12 @@ class ExaoneAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v,
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.out_proj(attn_output)
         return output

@@ -220,7 +220,7 @@ class ExaoneDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -232,7 +232,7 @@ class ExaoneDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-
+            forward_batch=forward_batch,
         )

         # Fully Connected
@@ -270,7 +270,7 @@ class ExaoneModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -283,7 +283,7 @@ class ExaoneModel(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-
+                forward_batch,
                 residual,
             )
         hidden_states, _ = self.ln_f(hidden_states, residual)
@@ -309,14 +309,14 @@ class ExaoneForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> LogitsProcessorOutput:
         hidden_states = self.transformer(
-            input_ids, positions,
+            input_ids, positions, forward_batch, input_embeds
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight,
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/gemma.py
CHANGED
@@ -37,7 +37,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 class GemmaMLP(nn.Module):
@@ -137,12 +137,12 @@ class GemmaAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v,
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output

@@ -180,7 +180,7 @@ class GemmaDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -192,7 +192,7 @@ class GemmaDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-
+            forward_batch=forward_batch,
         )

         # Fully Connected
@@ -226,7 +226,7 @@ class GemmaModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -243,7 +243,7 @@ class GemmaModel(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-
+                forward_batch,
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -293,12 +293,12 @@ class GemmaForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions,
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.model.embed_tokens.weight,
+            input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
        )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/gemma2.py
CHANGED
@@ -37,7 +37,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 # Aligned with HF's implementation, using sliding window inclusive with the last token
@@ -163,24 +163,24 @@ class Gemma2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_idx,
+            logit_cap=self.config.attn_logit_softcapping,
             sliding_window_size=(
                 get_attention_sliding_window_size(config)
                 if use_sliding_window
                 else None
             ),
-            logit_cap=self.config.attn_logit_softcapping,
         )

     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v,
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output

@@ -230,7 +230,7 @@ class Gemma2DecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if residual is None:
@@ -241,7 +241,7 @@ class Gemma2DecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-
+            forward_batch=forward_batch,
         )
         hidden_states = self.post_attention_layernorm(hidden_states)

@@ -286,7 +286,7 @@ class Gemma2Model(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -302,7 +302,7 @@ class Gemma2Model(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-
+                forward_batch,
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -352,12 +352,12 @@ class Gemma2ForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions,
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.model.embed_tokens.weight,
+            input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
         )

     def get_attention_sliding_window_size(self):
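The only non-mechanical edit in gemma2.py is the reordering of the RadixAttention keyword arguments so that logit_cap=self.config.attn_logit_softcapping comes before sliding_window_size; since both are passed by keyword, behavior is unchanged. For reference, the operation that logit_cap controls is Gemma 2's attention logit softcapping, sketched below in plain PyTorch; the actual capping is typically applied inside the attention backend's kernels, not in Python.

# Tanh softcapping of attention logits, the operation behind `logit_cap`
# (Gemma 2's attn_logit_softcapping). Shown here only as a reference formula.
import torch


def softcap(attn_scores: torch.Tensor, logit_cap: float) -> torch.Tensor:
    """Smoothly bound attention logits to the open interval (-logit_cap, logit_cap)."""
    return logit_cap * torch.tanh(attn_scores / logit_cap)


scores = torch.randn(4, 4) * 100.0
capped = softcap(scores, logit_cap=50.0)  # values now lie strictly within +/-50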
sglang/srt/models/gpt_bigcode.py
CHANGED
@@ -35,7 +35,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 class GPTBigCodeAttention(nn.Module):
@@ -90,7 +90,7 @@ class GPTBigCodeAttention(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.c_attn(hidden_states)
         q, k, v = qkv.split(
@@ -101,7 +101,7 @@ class GPTBigCodeAttention(nn.Module):
             ],
             dim=-1,
         )
-        attn_output = self.attn(q, k, v,
+        attn_output = self.attn(q, k, v, forward_batch)
         attn_output, _ = self.c_proj(attn_output)
         return attn_output

@@ -160,12 +160,12 @@ class GPTBigCodeBlock(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
         attn_output = self.attn(
-            hidden_states=hidden_states,
+            hidden_states=hidden_states, forward_batch=forward_batch
         )
         # residual connection
         hidden_states = attn_output + residual
@@ -214,7 +214,7 @@ class GPTBigCodeModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         inputs_embeds = self.wte(input_ids)
         position_embeds = self.wpe(position_ids)
@@ -222,7 +222,7 @@ class GPTBigCodeModel(nn.Module):

         for i in range(len(self.h)):
             layer = self.h[i]
-            hidden_states = layer(hidden_states,
+            hidden_states = layer(hidden_states, forward_batch)

         hidden_states = self.ln_f(hidden_states)
         return hidden_states
@@ -267,11 +267,11 @@ class GPTBigCodeForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        hidden_states = self.transformer(input_ids, positions,
+        hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight,
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/grok.py
CHANGED
@@ -46,7 +46,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 class Grok1MoE(nn.Module):
@@ -173,12 +173,12 @@ class Grok1Attention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v,
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output

@@ -219,7 +219,7 @@ class Grok1DecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         # Self Attention
         hidden_states = (
@@ -227,7 +227,7 @@ class Grok1DecoderLayer(nn.Module):
                 self.self_attn(
                     positions=positions,
                     hidden_states=self.pre_attn_norm(hidden_states),
-
+                    forward_batch=forward_batch,
                 )
             )
             + hidden_states
@@ -268,7 +268,7 @@ class Grok1Model(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -278,7 +278,7 @@ class Grok1Model(nn.Module):
            hidden_states = input_embeds

         for i in range(len(self.layers)):
-            hidden_states = self.layers[i](positions, hidden_states,
+            hidden_states = self.layers[i](positions, hidden_states, forward_batch)
         hidden_states = self.norm(hidden_states)
         hidden_states.mul_(self.config.output_multiplier_scale)
         return hidden_states
@@ -309,12 +309,12 @@ class Grok1ForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions,
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight,
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/internlm2.py
CHANGED
@@ -40,7 +40,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch


 class InternLM2MLP(nn.Module):
@@ -137,12 +137,12 @@ class InternLM2Attention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.wqkv(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v,
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.wo(attn_output)
         return output

@@ -182,7 +182,7 @@ class InternLMDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -194,7 +194,7 @@ class InternLMDecoderLayer(nn.Module):
         hidden_states = self.attention(
             positions=positions,
             hidden_states=hidden_states,
-
+            forward_batch=forward_batch,
         )

         # Fully Connected
@@ -229,7 +229,7 @@ class InternLM2Model(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -242,7 +242,7 @@ class InternLM2Model(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-
+                forward_batch,
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -268,12 +268,12 @@ class InternLM2ForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions,
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.output.weight,
+            input_ids, hidden_states, self.output.weight, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):