sglang 0.5.1.post2__py3-none-any.whl → 0.5.2rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +79 -53
- sglang/bench_serving.py +186 -14
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +12 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/conversation.py +38 -5
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/launch_lb.py +0 -13
- sglang/srt/disaggregation/mini_lb.py +33 -8
- sglang/srt/disaggregation/prefill.py +1 -1
- sglang/srt/distributed/parallel_state.py +24 -14
- sglang/srt/entrypoints/engine.py +19 -12
- sglang/srt/entrypoints/http_server.py +174 -34
- sglang/srt/entrypoints/openai/protocol.py +87 -24
- sglang/srt/entrypoints/openai/serving_chat.py +50 -9
- sglang/srt/entrypoints/openai/serving_completions.py +15 -0
- sglang/srt/eplb/eplb_manager.py +26 -2
- sglang/srt/eplb/expert_distribution.py +29 -2
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/harmony_parser.py +588 -0
- sglang/srt/hf_transformers_utils.py +26 -7
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention/ascend_backend.py +374 -136
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/flashinfer_mla_backend.py +5 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
- sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
- sglang/srt/layers/communicator.py +1 -2
- sglang/srt/layers/layernorm.py +28 -3
- sglang/srt/layers/linear.py +3 -2
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +13 -13
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/topk.py +35 -12
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +133 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +5 -23
- sglang/srt/layers/quantization/fp8.py +2 -1
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/modelopt_quant.py +7 -0
- sglang/srt/layers/quantization/mxfp4.py +25 -27
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -3
- sglang/srt/layers/rotary_embedding.py +28 -1
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/managers/cache_controller.py +237 -204
- sglang/srt/managers/detokenizer_manager.py +48 -2
- sglang/srt/managers/io_struct.py +57 -0
- sglang/srt/managers/mm_utils.py +5 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +591 -0
- sglang/srt/managers/scheduler.py +94 -9
- sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/tokenizer_manager.py +122 -42
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +51 -23
- sglang/srt/mem_cache/hiradix_cache.py +87 -71
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +77 -14
- sglang/srt/mem_cache/memory_pool_host.py +4 -5
- sglang/srt/mem_cache/radix_cache.py +6 -4
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +38 -20
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +87 -82
- sglang/srt/mem_cache/swa_radix_cache.py +1 -1
- sglang/srt/model_executor/model_runner.py +6 -5
- sglang/srt/model_loader/loader.py +15 -24
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/models/deepseek_v2.py +38 -13
- sglang/srt/models/gpt_oss.py +2 -15
- sglang/srt/models/llama_eagle3.py +4 -0
- sglang/srt/models/longcat_flash.py +1015 -0
- sglang/srt/models/longcat_flash_nextn.py +691 -0
- sglang/srt/models/qwen2.py +26 -3
- sglang/srt/models/qwen2_5_vl.py +66 -41
- sglang/srt/models/qwen2_moe.py +22 -2
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/reasoning_parser.py +56 -300
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/server_args.py +122 -56
- sglang/srt/speculative/eagle_worker.py +28 -8
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +73 -5
- sglang/test/attention/test_trtllm_mla_backend.py +12 -3
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/METADATA +7 -6
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/RECORD +107 -99
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2.py
CHANGED
@@ -16,7 +16,7 @@
 # Modify details for the adaptation of Qwen2 model.
 """Inference-only Qwen2 model compatible with HuggingFace weights."""
 import logging
-from typing import Any, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -431,7 +431,6 @@ class Qwen2ForCausalLM(nn.Module):
                 quant_config=quant_config,
                 prefix=add_prefix("lm_head", prefix),
             )
-
         else:
             # ranks other than the last rank will have a placeholder layer
             self.lm_head = PPMissingLayer()
@@ -452,6 +451,8 @@ class Qwen2ForCausalLM(nn.Module):
 
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
 
     def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embedding(input_ids)
@@ -476,11 +477,18 @@ class Qwen2ForCausalLM(nn.Module):
             input_embeds,
             pp_proxy_tensors=pp_proxy_tensors,
         )
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
 
         if self.pp_group.is_last_rank:
             if not get_embedding:
                 return self.logits_processor(
-                    input_ids, hidden_states, self.lm_head, forward_batch
+                    input_ids,
+                    hidden_states,
+                    self.lm_head,
+                    forward_batch,
+                    aux_hidden_states,
                 )
             else:
                 return self.pooler(hidden_states, forward_batch)
@@ -619,5 +627,20 @@ class Qwen2ForCausalLM(nn.Module):
     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
         self.model.load_kv_cache_scales(quantization_param_path)
 
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        self.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
 EntryClass = Qwen2ForCausalLM
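The qwen2.py changes wire Qwen2ForCausalLM up for EAGLE3 speculative decoding: when `capture_aux_hidden_states` is set, the decoder forward returns a `(hidden_states, aux_hidden_states)` tuple, and the extra states are forwarded to the logits processor. The collection side lives in the shared Qwen2Model decoder rather than in this hunk; the sketch below illustrates the implied contract, with the loop body being an assumption (only the `layers_to_capture` attribute and the tuple return come from the diff).

```python
# Minimal sketch of the capture contract implied by the diff; the real
# decoder loop in sglang.srt.models.qwen2.Qwen2Model may differ.
from typing import List, Tuple, Union

import torch


class TinyDecoderStack(torch.nn.Module):
    def __init__(self, layers: torch.nn.ModuleList):
        super().__init__()
        self.layers = layers
        # Set by set_eagle3_layers_to_capture(); empty means "capture nothing".
        self.layers_to_capture: List[int] = []

    def forward(
        self, hidden: torch.Tensor
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
        aux: List[torch.Tensor] = []
        for i, layer in enumerate(self.layers):
            # Capture the residual stream entering the listed layers; the
            # diff's "[val + 1 for val in layer_ids]" offset suggests indices
            # are shifted by one relative to user-provided layer ids.
            if i in self.layers_to_capture:
                aux.append(hidden)
            hidden = layer(hidden)
        if self.layers_to_capture:
            return hidden, aux  # unpacked by the capture_aux_hidden_states branch
        return hidden
```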
sglang/srt/models/qwen2_5_vl.py
CHANGED
@@ -31,7 +31,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
 from transformers.activations import ACT2FN
-from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm
 from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
     Qwen2_5_VLConfig,
     Qwen2_5_VLVisionConfig,
@@ -43,7 +42,12 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
 
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention.vision import VisionAttention
-from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    MergedColumnParallelLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -62,7 +66,6 @@ logger = logging.getLogger(__name__)
 
 
 class Qwen2_5_VLMLP(nn.Module):
-
     def __init__(
         self,
         in_features: int,
@@ -73,19 +76,12 @@ class Qwen2_5_VLMLP(nn.Module):
         prefix: str = "",
     ):
         super().__init__()
-        self.gate_proj = ColumnParallelLinear(
-            in_features,
-            hidden_features,
-            bias=bias,
-            quant_config=quant_config,
-            prefix=add_prefix("gate_proj", prefix),
-        )
-        self.up_proj = ColumnParallelLinear(
-            in_features,
-            hidden_features,
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=in_features,
+            output_sizes=[hidden_features] * 2,  # [gate_proj, up_proj]
             bias=bias,
             quant_config=quant_config,
-            prefix=add_prefix("up_proj", prefix),
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             hidden_features,
@@ -97,12 +93,11 @@ class Qwen2_5_VLMLP(nn.Module):
         self.act = ACT2FN[hidden_act]
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x_gate, _ = self.gate_proj(x)
-        x_gate = self.act(x_gate)
-        x_up, _ = self.up_proj(x)
-        x = x_gate * x_up
-        x, _ = self.down_proj(x)
-        return x
+        gate_up, _ = self.gate_up_proj(x)
+        gate, up = gate_up.chunk(2, dim=-1)
+        x = self.act(gate) * up
+        x_down, _ = self.down_proj(x)
+        return x_down
 
 
 class Qwen2_5_VisionBlock(nn.Module):
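The Qwen2_5_VLMLP rewrite folds the separate gate_proj and up_proj GEMMs into a single MergedColumnParallelLinear whose output is split with `chunk(2, dim=-1)`. A self-contained sketch of why that is numerically equivalent, using plain `nn.Linear` in place of sglang's parallel linear layers and SiLU as a stand-in for the configured activation:

```python
# Fusing two projections into one GEMM: concatenate the weights along the
# output dimension, then split the result. Equivalence check below.
import torch
import torch.nn.functional as F

in_features, hidden_features = 8, 16
gate = torch.nn.Linear(in_features, hidden_features, bias=False)
up = torch.nn.Linear(in_features, hidden_features, bias=False)

# Fused weight: one matmul at runtime produces both halves.
fused = torch.nn.Linear(in_features, 2 * hidden_features, bias=False)
with torch.no_grad():
    fused.weight.copy_(torch.cat([gate.weight, up.weight], dim=0))

x = torch.randn(4, in_features)
g, u = fused(x).chunk(2, dim=-1)
ref = F.silu(gate(x)) * up(x)
assert torch.allclose(F.silu(g) * u, ref, atol=1e-6)
```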
@@ -122,8 +117,8 @@ class Qwen2_5_VisionBlock(nn.Module):
         super().__init__()
         if norm_layer is None:
             norm_layer = partial(nn.LayerNorm, eps=1e-6)
-        self.norm1 = Qwen2RMSNorm(dim, eps=1e-6)
-        self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
+        self.norm1 = RMSNorm(dim, eps=1e-6)
+        self.norm2 = RMSNorm(dim, eps=1e-6)
 
         if attn_implementation is None:
             softmax_in_single_precision = False
@@ -174,18 +169,29 @@ class Qwen2_5_VisionBlock(nn.Module):
         cu_seqlens: torch.Tensor,
         position_embeddings: torch.Tensor,
     ) -> torch.Tensor:
-        hidden_states = self.norm1(x)
-        hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
+        S, B, H = x.shape
+        # norm1: flatten to 2D -> [S*B, H], then reshape back
+        x2d = x.reshape(-1, H)
+        hidden_states = self.norm1(x2d).reshape(S, B, H)
+
+        # Attention expects [B, S, H]
+        hidden_states = rearrange(hidden_states, "s b h -> b s h")
         attn = self.attn(
             hidden_states,
             cu_seqlens=cu_seqlens,
             position_embeddings=position_embeddings,
         )
-        attn = rearrange(attn, "b s ... -> s b ...")
-        x = x + attn
-        norm2 = self.norm2(x)
-        mlp = self.mlp(norm2)
-        x = x + mlp
+        attn = rearrange(attn, "b s h -> s b h")
+
+        # norm2 with fused residual-add: also 2D
+        attn2d = attn.reshape(-1, H)
+        x_norm_2d, x_after_add_2d = self.norm2(x2d, residual=attn2d)
+        x_norm = x_norm_2d.reshape(S, B, H)
+        x_after_add = x_after_add_2d.reshape(S, B, H)
+
+        # MLP and final residual
+        mlp_out = self.mlp(x_norm)
+        x = x_after_add + mlp_out
         return x
 
 
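The new Qwen2_5_VisionBlock.forward leans on sglang's RMSNorm supporting a fused residual add: called as `norm(x, residual=r)`, it returns both the normalized sum and the sum itself, which is why the diff unpacks two values from `self.norm2`. A minimal reference implementation of that calling convention (the fused kernel in sglang.srt.layers.layernorm is the real thing; this only pins down the semantics):

```python
# Reference semantics of the two-output norm call used above, assuming
# sglang's RMSNorm(x, residual=r) returns (rmsnorm(x + r), x + r).
import torch


def rmsnorm_with_residual(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
):
    added = x + residual  # residual add fused into the norm kernel
    var = added.pow(2).mean(dim=-1, keepdim=True)
    normed = added * torch.rsqrt(var + eps) * weight
    return normed, added  # (input to the MLP, carry for the next residual)
```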
@@ -201,7 +207,7 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
     ) -> None:
         super().__init__()
         self.hidden_size = context_dim * (spatial_merge_size**2)
-        self.ln_q = Qwen2RMSNorm(context_dim, eps=1e-6)
+        self.ln_q = RMSNorm(context_dim, eps=1e-6)
         self.mlp = nn.ModuleList(
             [
                 ColumnParallelLinear(
@@ -223,11 +229,13 @@ class Qwen2_5_VisionPatchMerger(nn.Module):
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.ln_q(x)
-        x = x.view(-1, self.hidden_size)
-
+        # x expected shape: [S, B, context_dim]
+        S, B, D = x.shape
+        x2d = x.reshape(-1, D)
+        x2d = self.ln_q(x2d)  # RMSNorm expects 2D
+        x2d = x2d.view(-1, self.hidden_size)  # group into spatial_merge_unit
         mlp_fc1, mlp_act, mlp_fc2 = self.mlp
-        x_parallel, _ = mlp_fc1(x)
+        x_parallel, _ = mlp_fc1(x2d)
         x_parallel = mlp_act(x_parallel)
         out, _ = mlp_fc2(x_parallel)
         return out
@@ -340,7 +348,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
 
     @property
     def device(self) -> torch.device:
-        return self.blocks[0].mlp.gate_proj.weight.device
+        return self.patch_embed.proj.weight.device
 
     def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
         pos_ids = []
@@ -394,6 +402,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
         )
         cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
 
+        # Move window_index to the same device as x before using it to index x
+        window_index = window_index.to(device=x.device)
+
+        # Ensure rotary_pos_emb is on the same device/dtype as x
+        rotary_pos_emb = rotary_pos_emb.to(device=x.device, dtype=x.dtype)
+
         seq_len, _ = x.size()
 
         x = x.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
@@ -406,12 +420,19 @@ class Qwen2_5_VisionTransformer(nn.Module):
         rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
         emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
         position_embeddings = (emb.cos(), emb.sin())
+        # After building position_embeddings, make sure both cos and sin are on the same device/dtype as the attention input
+        position_embeddings = (
+            position_embeddings[0].to(x.device, x.dtype),
+            position_embeddings[1].to(x.device, x.dtype),
+        )
 
-        # compute cu_seqlens
+        # compute cu_seqlens - move cu_seqlens to GPU and make it int32
         cu_seqlens = torch.cat(
             [
-                torch.tensor([0], device=grid_thw.device, dtype=grid_thw.dtype),
-                (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2]).cumsum(dim=0),
+                torch.tensor([0], device=x.device, dtype=torch.int32),
+                (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2])
+                .cumsum(dim=0)
+                .to(device=x.device, dtype=torch.int32),
             ]
         )
         cu_seqlens = F.pad(cu_seqlens, (1, 0), "constant", 0)
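`cu_seqlens` is the cumulative sequence-length vector that varlen attention kernels index with, and the change above pins it to the input's device as int32 instead of leaving it wherever `grid_thw` lives. A small worked example of the construction for two images with grids `[1, 4, 4]` and `[1, 2, 2]` (16 and 4 patches):

```python
# Worked example of the cu_seqlens construction in the diff.
import torch
import torch.nn.functional as F

grid_thw = torch.tensor([[1, 4, 4], [1, 2, 2]])
device = torch.device("cpu")  # stands in for x.device
cu_seqlens = torch.cat(
    [
        torch.tensor([0], device=device, dtype=torch.int32),
        (grid_thw[:, 0] * grid_thw[:, 1] * grid_thw[:, 2])
        .cumsum(dim=0)
        .to(device=device, dtype=torch.int32),
    ]
)
print(cu_seqlens)                 # tensor([ 0, 16, 20], dtype=torch.int32)
print(F.pad(cu_seqlens, (1, 0)))  # leading-zero pad, as in the diff
```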
@@ -442,9 +463,8 @@ cached_get_processor = lru_cache(get_processor)
 class Qwen2_5_VLForConditionalGeneration(nn.Module):
     # BitandBytes specific attributes
     default_bitsandbytes_target_modules = [
-        ".gate_proj.",
+        ".gate_up_proj.",
         ".down_proj.",
-        ".up_proj.",
         ".q_proj.",
         ".k_proj.",
         ".v_proj.",
@@ -526,6 +546,7 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
     def get_input_embeddings(self):
         return self.model.embed_tokens
 
+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -590,7 +611,11 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         for param_name, weight_name, shard_id in stacked_params_mapping:
             if weight_name not in name:
                 continue
-            if "visual" in name:
+            if (
+                "visual" in name
+                and "up_proj" not in name
+                and "gate_proj" not in name
+            ):
                 continue
             name = name.replace(weight_name, param_name)
 
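Because the checkpoint still stores `gate_proj` and `up_proj` as separate tensors, `load_weights` must route both through the stacked-params mapping so they land in the two halves of the fused `gate_up_proj` weight; that is why the visual-weight skip condition above now exempts those two names. A simplified sketch of the mechanism (no tensor parallelism; the mapping tuples follow the diff, the loader itself is illustrative):

```python
# Sketch of how a stacked-params mapping lands two checkpoint tensors in
# one fused parameter; sglang's real weight_loader also handles TP shards.
import torch

stacked_params_mapping = [
    # (param_name, weight_name, shard_id)
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

hidden = 16
fused = torch.empty(2 * hidden, 8)  # gate_up_proj.weight


def weight_loader(param: torch.Tensor, loaded: torch.Tensor, shard_id: int):
    # Each shard owns one half of the fused output dimension.
    out = loaded.shape[0]
    param[shard_id * out : (shard_id + 1) * out].copy_(loaded)


for name, loaded in [
    ("visual.blocks.0.mlp.gate_proj.weight", torch.randn(hidden, 8)),
    ("visual.blocks.0.mlp.up_proj.weight", torch.randn(hidden, 8)),
]:
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        weight_loader(fused, loaded, shard_id)
        break
```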
sglang/srt/models/qwen2_moe.py
CHANGED
@@ -17,7 +17,7 @@
 """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
 
 import logging
-from typing import Any, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -536,6 +536,8 @@ class Qwen2MoeForCausalLM(nn.Module):
             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
 
     @torch.no_grad()
     def forward(
@@ -553,9 +555,12 @@ class Qwen2MoeForCausalLM(nn.Module):
             input_embeds,
             pp_proxy_tensors=pp_proxy_tensors,
         )
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
         if self.pp_group.is_last_rank:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
             )
         else:
             return hidden_states
@@ -705,5 +710,20 @@ class Qwen2MoeForCausalLM(nn.Module):
             num_groups=None,
         )
 
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        self.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
 EntryClass = Qwen2MoeForCausalLM
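qwen2_moe.py gains the same EAGLE3 hooks as qwen2.py. As a quick worked example of the default capture selection, a hypothetical 28-layer config would capture at layers 2, 14, and 25:

```python
# Default EAGLE3 capture layers, assuming num_hidden_layers = 28.
num_layers = 28
print([2, num_layers // 2, num_layers - 3])  # -> [2, 14, 25]
```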
sglang/srt/models/transformers.py
CHANGED
@@ -213,7 +213,7 @@ class TransformersForCausalLM(nn.Module):
         """
         tp_plan = getattr(self.model.config, "base_model_tp_plan", None) or {}
 
-        if not tp_plan and self.tp_size > 1:
+        if not tp_plan and tp_size > 1:
             raise ValueError(
                 f"{type(self.model)} does not support tensor parallel yet!"
             )
sglang/srt/multimodal/processors/base_processor.py
CHANGED
@@ -13,7 +13,9 @@ from PIL import Image
 from transformers import BaseImageProcessorFast
 
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
-from sglang.srt.utils import load_audio, load_image, load_video, logger
+from sglang.srt.utils import is_npu, load_audio, load_image, load_video, logger
+
+_is_npu = is_npu()
 
 
 @dataclasses.dataclass
@@ -232,7 +234,7 @@ class BaseMultimodalProcessor(ABC):
             and isinstance(processor.image_processor, BaseImageProcessorFast)
             and not self.server_args.disable_fast_image_processor
         ):
-            kwargs["device"] = "cuda"
+            kwargs["device"] = "cuda" if not _is_npu else "npu"
             result = processor.__call__(
                 text=[input_text],
                 padding=True,