sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +10 -8
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +2 -1
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +93 -76
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +103 -15
- sglang/srt/entrypoints/engine.py +31 -33
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +48 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -2
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/qwen3_coder_detector.py +151 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +24 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +190 -23
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +34 -112
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +340 -9
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +162 -164
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +83 -35
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +288 -0
- sglang/srt/managers/io_struct.py +60 -30
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +163 -113
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +256 -86
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +38 -27
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +74 -23
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +168 -0
- sglang/srt/mem_cache/hiradix_cache.py +194 -5
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +44 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +66 -31
- sglang/srt/model_executor/forward_batch_info.py +210 -25
- sglang/srt/model_executor/model_runner.py +147 -42
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +192 -173
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +13 -6
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -9
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +57 -24
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +454 -270
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +10 -5
- sglang/srt/utils.py +44 -69
- sglang/test/runners.py +14 -3
- sglang/test/test_activation.py +50 -1
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py
CHANGED
@@ -271,12 +271,13 @@ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
|
|
271
271
|
batch,
|
272
272
|
dp_size=model_runner.server_args.dp_size,
|
273
273
|
attn_tp_size=1,
|
274
|
-
|
274
|
+
tp_group=model_runner.tp_group,
|
275
275
|
get_idle_batch=None,
|
276
276
|
disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
|
277
277
|
spec_algorithm=SpeculativeAlgorithm.NONE,
|
278
278
|
speculative_num_draft_tokens=None,
|
279
279
|
require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
|
280
|
+
disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule,
|
280
281
|
)
|
281
282
|
|
282
283
|
|
sglang/eval/loogle_eval.py
CHANGED
@@ -73,6 +73,8 @@ async def benchmark(args):
|
|
73
73
|
|
74
74
|
tasks: List[asyncio.Task] = []
|
75
75
|
for idx, ex in enumerate(dataset):
|
76
|
+
if idx >= args.num_prompts:
|
77
|
+
break
|
76
78
|
tasks.append(
|
77
79
|
asyncio.create_task(
|
78
80
|
fetch_response(
|
@@ -103,6 +105,8 @@ def analyse(args):
|
|
103
105
|
hyps: List[str] = []
|
104
106
|
refs: List[str] = []
|
105
107
|
for idx, ex in enumerate(tqdm(dataset, desc="Loading responses")):
|
108
|
+
if idx >= args.num_prompts:
|
109
|
+
break
|
106
110
|
pkl_file = output_dir / f"response_{idx}.pkl"
|
107
111
|
if not pkl_file.exists():
|
108
112
|
raise FileNotFoundError(pkl_file)
|
@@ -150,6 +154,9 @@ if __name__ == "__main__":
|
|
150
154
|
parser.add_argument(
|
151
155
|
"--output-dir", default="tmp-output-dir", help="Directory for cached responses"
|
152
156
|
)
|
157
|
+
parser.add_argument(
|
158
|
+
"--num-prompts", type=int, default=10000, help="Number of prompts to run"
|
159
|
+
)
|
153
160
|
args = parser.parse_args()
|
154
161
|
|
155
162
|
asyncio.run(benchmark(args))
|
sglang/srt/_custom_ops.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
|
2
2
|
import logging
|
3
|
-
from typing import List, Tuple
|
3
|
+
from typing import List, Optional, Tuple
|
4
4
|
|
5
5
|
import torch
|
6
6
|
|
@@ -114,6 +114,34 @@ else:
|
|
114
114
|
def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
|
115
115
|
return sgl_kernel.allreduce.get_meta_buffer_ipc_handle(inp)
|
116
116
|
|
117
|
+
# ROCM custom quick allreduce
|
118
|
+
|
119
|
+
def init_custom_qr(
|
120
|
+
rank: int, world_size: int, qr_max_size: Optional[int] = None
|
121
|
+
) -> int:
|
122
|
+
return sgl_kernel.allreduce.init_custom_qr(world_size, rank, qr_max_size)
|
123
|
+
|
124
|
+
def qr_get_handle(fa: int) -> torch.Tensor:
|
125
|
+
return sgl_kernel.allreduce.qr_get_handle(fa)
|
126
|
+
|
127
|
+
def qr_open_handles(fa: int, handles: list[torch.Tensor]) -> None:
|
128
|
+
sgl_kernel.allreduce.qr_open_handles(fa, handles)
|
129
|
+
|
130
|
+
def qr_all_reduce(
|
131
|
+
fa: int,
|
132
|
+
inp: torch.Tensor,
|
133
|
+
out: torch.Tensor,
|
134
|
+
quant_level: int,
|
135
|
+
cast_bf2half: bool,
|
136
|
+
) -> None:
|
137
|
+
sgl_kernel.allreduce.qr_all_reduce(fa, inp, out, quant_level, cast_bf2half)
|
138
|
+
|
139
|
+
def qr_destroy(fa: int) -> None:
|
140
|
+
sgl_kernel.allreduce.qr_destroy(fa)
|
141
|
+
|
142
|
+
def qr_max_size() -> int:
|
143
|
+
return sgl_kernel.allreduce.qr_max_size()
|
144
|
+
|
117
145
|
|
118
146
|
def mscclpp_generate_unique_id() -> bytes:
|
119
147
|
return sgl_kernel.allreduce.mscclpp_generate_unique_id()
|
@@ -42,6 +42,9 @@ def select_best_resolution(image_size, candidate_resolutions):
|
|
42
42
|
|
43
43
|
|
44
44
|
class DictOutput(object):
|
45
|
+
def items(self):
|
46
|
+
return self.__dict__.items()
|
47
|
+
|
45
48
|
def keys(self):
|
46
49
|
return self.__dict__.keys()
|
47
50
|
|
@@ -59,7 +62,9 @@ class DictOutput(object):
|
|
59
62
|
class VLChatProcessorOutput(DictOutput):
|
60
63
|
input_ids: torch.LongTensor
|
61
64
|
target_ids: torch.LongTensor
|
62
|
-
|
65
|
+
pixel_values: (
|
66
|
+
torch.Tensor
|
67
|
+
) # rename from "images" to "pixel_values" for compatibility
|
63
68
|
images_seq_mask: torch.BoolTensor
|
64
69
|
images_spatial_crop: torch.LongTensor
|
65
70
|
|
@@ -312,10 +317,14 @@ class DeepseekVLV2Processor(ProcessorMixin):
|
|
312
317
|
images = torch.stack(images_list, dim=0)
|
313
318
|
images_spatial_crop = torch.tensor(images_spatial_crop, dtype=torch.long)
|
314
319
|
|
320
|
+
images_spatial_crop = torch.stack(
|
321
|
+
[images_spatial_crop], dim=0
|
322
|
+
) # stack the tensor to make it a batch of 1
|
323
|
+
|
315
324
|
prepare = VLChatProcessorOutput(
|
316
325
|
input_ids=input_ids,
|
317
326
|
target_ids=target_ids,
|
318
|
-
|
327
|
+
pixel_values=images,
|
319
328
|
images_seq_mask=images_seq_mask,
|
320
329
|
images_spatial_crop=images_spatial_crop,
|
321
330
|
)
|
sglang/srt/configs/internvl.py
CHANGED
@@ -9,6 +9,7 @@ from transformers import (
|
|
9
9
|
LlamaConfig,
|
10
10
|
PretrainedConfig,
|
11
11
|
PreTrainedTokenizer,
|
12
|
+
Qwen2Config,
|
12
13
|
)
|
13
14
|
|
14
15
|
from sglang.utils import logger
|
@@ -311,6 +312,8 @@ class InternVLChatConfig(PretrainedConfig):
|
|
311
312
|
self.llm_config = LlamaConfig(**llm_config)
|
312
313
|
elif llm_config.get("architectures")[0] == "InternLM2ForCausalLM":
|
313
314
|
self.llm_config = InternLM2Config(**llm_config)
|
315
|
+
elif llm_config.get("architectures")[0] == "Qwen2ForCausalLM":
|
316
|
+
self.llm_config = Qwen2Config(**llm_config)
|
314
317
|
else:
|
315
318
|
raise ValueError(
|
316
319
|
"Unsupported architecture: {}".format(
|
sglang/srt/configs/janus_pro.py
CHANGED
@@ -53,7 +53,7 @@ class ModelConfig:
|
|
53
53
|
trust_remote_code: bool = True,
|
54
54
|
revision: Optional[str] = None,
|
55
55
|
context_length: Optional[int] = None,
|
56
|
-
model_override_args:
|
56
|
+
model_override_args: str = "{}",
|
57
57
|
is_embedding: Optional[bool] = None,
|
58
58
|
enable_multimodal: Optional[bool] = None,
|
59
59
|
dtype: str = "auto",
|
@@ -61,13 +61,13 @@ class ModelConfig:
|
|
61
61
|
override_config_file: Optional[str] = None,
|
62
62
|
is_draft_model: bool = False,
|
63
63
|
hybrid_kvcache_ratio: Optional[float] = None,
|
64
|
-
|
64
|
+
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
|
65
65
|
) -> None:
|
66
66
|
|
67
67
|
self.model_path = model_path
|
68
68
|
self.revision = revision
|
69
69
|
self.quantization = quantization
|
70
|
-
self.
|
70
|
+
self.model_impl = model_impl
|
71
71
|
|
72
72
|
# Parse args
|
73
73
|
self.maybe_pull_model_tokenizer_from_remote()
|
@@ -286,7 +286,7 @@ class ModelConfig:
|
|
286
286
|
dtype=server_args.dtype,
|
287
287
|
quantization=server_args.quantization,
|
288
288
|
hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
|
289
|
-
|
289
|
+
model_impl=server_args.model_impl,
|
290
290
|
**kwargs,
|
291
291
|
)
|
292
292
|
|
@@ -391,6 +391,7 @@ class ModelConfig:
|
|
391
391
|
"compressed-tensors",
|
392
392
|
"fbgemm_fp8",
|
393
393
|
"w8a8_fp8",
|
394
|
+
"petit_nvfp4",
|
394
395
|
]
|
395
396
|
optimized_quantization_methods = [
|
396
397
|
"fp8",
|
@@ -408,9 +409,11 @@ class ModelConfig:
|
|
408
409
|
"moe_wna16",
|
409
410
|
"qoq",
|
410
411
|
"w4afp8",
|
412
|
+
"petit_nvfp4",
|
411
413
|
]
|
412
414
|
compatible_quantization_methods = {
|
413
415
|
"modelopt_fp4": ["modelopt"],
|
416
|
+
"petit_nvfp4": ["modelopt"],
|
414
417
|
"w8a8_int8": ["compressed-tensors", "compressed_tensors"],
|
415
418
|
"w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
|
416
419
|
}
|
@@ -472,7 +475,7 @@ class ModelConfig:
|
|
472
475
|
|
473
476
|
def get_hf_eos_token_id(self) -> Optional[Set[int]]:
|
474
477
|
eos_ids = getattr(self.hf_config, "eos_token_id", None)
|
475
|
-
if eos_ids:
|
478
|
+
if eos_ids is not None:
|
476
479
|
# it can be either int or list of int
|
477
480
|
eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
|
478
481
|
if eos_ids is None:
|
@@ -711,7 +714,6 @@ def get_hybrid_layer_ids(model_architectures: List[str], num_hidden_layers: int)
|
|
711
714
|
i for i in range(num_hidden_layers) if (i + 1) % 4 == 0
|
712
715
|
]
|
713
716
|
else:
|
714
|
-
|
715
|
-
|
716
|
-
)
|
717
|
+
swa_attention_layer_ids = None
|
718
|
+
full_attention_layer_ids = None
|
717
719
|
return swa_attention_layer_ids, full_attention_layer_ids
|
@@ -115,5 +115,7 @@ def adjust_config_with_unaligned_cpu_tp(
|
|
115
115
|
model_config = update_intermediate_size(
|
116
116
|
model_config, "intermediate_size", intermediate_padding_size
|
117
117
|
)
|
118
|
-
|
118
|
+
model_config = update_intermediate_size(
|
119
|
+
model_config, "intermediate_size_mlp", intermediate_padding_size
|
120
|
+
)
|
119
121
|
return model_config
|
sglang/srt/conversation.py
CHANGED
@@ -729,6 +729,7 @@ register_conv_template(
|
|
729
729
|
sep="<|end|>",
|
730
730
|
stop_str="<|end|>",
|
731
731
|
image_token="<|endoftext10|>",
|
732
|
+
audio_token="<|endoftext11|>",
|
732
733
|
)
|
733
734
|
)
|
734
735
|
|
@@ -983,7 +984,7 @@ register_conv_template(
|
|
983
984
|
|
984
985
|
@register_conv_template_matching_function
|
985
986
|
def match_internvl(model_path: str):
|
986
|
-
if re.search(r"
|
987
|
+
if re.search(r"internvl", model_path, re.IGNORECASE):
|
987
988
|
return "internvl-2-5"
|
988
989
|
|
989
990
|
|
sglang/srt/custom_op.py
CHANGED
@@ -29,15 +29,18 @@ class CustomOp(nn.Module):
|
|
29
29
|
|
30
30
|
self._original_forward_method = self._forward_method
|
31
31
|
# NOTE: Temporarily workaround MoE
|
32
|
+
# The performance of torch.compile on this layer is not always good when bs > 1,
|
33
|
+
# so we decide to only use torch.compile when bs=1
|
32
34
|
if "FusedMoE" in self.__class__.__name__:
|
33
35
|
if num_tokens == 1:
|
34
36
|
from sglang.srt.layers.moe.fused_moe_native import (
|
35
37
|
fused_moe_forward_native,
|
36
38
|
)
|
37
39
|
|
38
|
-
# The performance of torch.compile on this layer is not always good when bs > 1,
|
39
|
-
# so we decide to only use torch.compile when bs =1
|
40
40
|
self._forward_method = fused_moe_forward_native
|
41
|
+
elif "TopK" in self.__class__.__name__:
|
42
|
+
if num_tokens == 1:
|
43
|
+
self._forward_method = self.forward_native
|
41
44
|
else:
|
42
45
|
self._forward_method = self.forward_native
|
43
46
|
self.is_torch_compile = True
|
@@ -23,7 +23,14 @@ from sglang.srt.disaggregation.base.conn import (
|
|
23
23
|
)
|
24
24
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
25
25
|
from sglang.srt.server_args import ServerArgs
|
26
|
-
from sglang.srt.utils import
|
26
|
+
from sglang.srt.utils import (
|
27
|
+
format_tcp_address,
|
28
|
+
get_free_port,
|
29
|
+
get_ip,
|
30
|
+
get_local_ip_by_remote,
|
31
|
+
is_valid_ipv6_address,
|
32
|
+
maybe_wrap_ipv6_address,
|
33
|
+
)
|
27
34
|
|
28
35
|
logger = logging.getLogger(__name__)
|
29
36
|
|
@@ -65,11 +72,18 @@ class CommonKVManager(BaseKVManager):
|
|
65
72
|
def _register_to_bootstrap(self):
|
66
73
|
"""Register KVSender to bootstrap server via HTTP POST."""
|
67
74
|
if self.dist_init_addr:
|
68
|
-
|
75
|
+
if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
|
76
|
+
if self.dist_init_addr.endswith("]"):
|
77
|
+
host = self.dist_init_addr
|
78
|
+
else:
|
79
|
+
host, _ = self.dist_init_addr.rsplit(":", 1)
|
80
|
+
else:
|
81
|
+
host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
|
69
82
|
else:
|
70
|
-
|
83
|
+
host = get_ip()
|
84
|
+
host = maybe_wrap_ipv6_address(host)
|
71
85
|
|
72
|
-
bootstrap_server_url = f"{
|
86
|
+
bootstrap_server_url = f"{host}:{self.bootstrap_port}"
|
73
87
|
url = f"http://{bootstrap_server_url}/route"
|
74
88
|
payload = {
|
75
89
|
"role": "Prefill",
|
@@ -92,8 +106,10 @@ class CommonKVManager(BaseKVManager):
|
|
92
106
|
logger.error(f"Prefill Failed to register to bootstrap server: {e}")
|
93
107
|
|
94
108
|
@cache
|
95
|
-
def _connect(self, endpoint: str):
|
109
|
+
def _connect(self, endpoint: str, is_ipv6: bool = False):
|
96
110
|
socket = zmq.Context().socket(zmq.PUSH)
|
111
|
+
if is_ipv6:
|
112
|
+
socket.setsockopt(zmq.IPV6, 1)
|
97
113
|
socket.connect(endpoint)
|
98
114
|
return socket
|
99
115
|
|
@@ -263,15 +279,27 @@ class CommonKVReceiver(BaseKVReceiver):
|
|
263
279
|
return None
|
264
280
|
|
265
281
|
@classmethod
|
266
|
-
def _connect(cls, endpoint: str):
|
282
|
+
def _connect(cls, endpoint: str, is_ipv6: bool = False):
|
267
283
|
with cls._global_lock:
|
268
284
|
if endpoint not in cls._socket_cache:
|
269
285
|
sock = cls._ctx.socket(zmq.PUSH)
|
286
|
+
if is_ipv6:
|
287
|
+
sock.setsockopt(zmq.IPV6, 1)
|
270
288
|
sock.connect(endpoint)
|
271
289
|
cls._socket_cache[endpoint] = sock
|
272
290
|
cls._socket_locks[endpoint] = threading.Lock()
|
273
291
|
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
|
274
292
|
|
293
|
+
@classmethod
|
294
|
+
def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
|
295
|
+
ip_address = bootstrap_info["rank_ip"]
|
296
|
+
port = bootstrap_info["rank_port"]
|
297
|
+
is_ipv6_address = is_valid_ipv6_address(ip_address)
|
298
|
+
sock, lock = cls._connect(
|
299
|
+
format_tcp_address(ip_address, port), is_ipv6=is_ipv6_address
|
300
|
+
)
|
301
|
+
return sock, lock
|
302
|
+
|
275
303
|
def _register_kv_args(self):
|
276
304
|
pass
|
277
305
|
|
@@ -439,7 +439,15 @@ class DecodePreallocQueue:
|
|
439
439
|
else 0
|
440
440
|
)
|
441
441
|
|
442
|
-
|
442
|
+
if self.scheduler.model_config.is_hybrid:
|
443
|
+
available_size = min(
|
444
|
+
self.token_to_kv_pool_allocator.full_available_size(),
|
445
|
+
self.token_to_kv_pool_allocator.swa_available_size(),
|
446
|
+
)
|
447
|
+
else:
|
448
|
+
available_size = self.token_to_kv_pool_allocator.available_size()
|
449
|
+
|
450
|
+
allocatable_tokens = available_size - max(
|
443
451
|
# preserve some space for future decode
|
444
452
|
self.num_reserved_decode_tokens
|
445
453
|
* (
|
@@ -17,6 +17,7 @@ from fastapi import FastAPI, HTTPException
|
|
17
17
|
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
18
18
|
|
19
19
|
from sglang.srt.disaggregation.utils import PDRegistryRequest
|
20
|
+
from sglang.srt.utils import maybe_wrap_ipv6_address
|
20
21
|
|
21
22
|
AIOHTTP_STREAM_READ_CHUNK_SIZE = (
|
22
23
|
1024 * 64
|
@@ -271,7 +272,7 @@ async def handle_generate_request(request_data: dict):
|
|
271
272
|
|
272
273
|
# Parse and transform prefill_server for bootstrap data
|
273
274
|
parsed_url = urllib.parse.urlparse(prefill_server)
|
274
|
-
hostname = parsed_url.hostname
|
275
|
+
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
|
275
276
|
modified_request = request_data.copy()
|
276
277
|
|
277
278
|
batch_size = _get_request_batch_size(modified_request)
|
@@ -309,7 +310,7 @@ async def _forward_to_backend(request_data: dict, endpoint_name: str):
|
|
309
310
|
|
310
311
|
# Parse and transform prefill_server for bootstrap data
|
311
312
|
parsed_url = urllib.parse.urlparse(prefill_server)
|
312
|
-
hostname = parsed_url.hostname
|
313
|
+
hostname = maybe_wrap_ipv6_address(parsed_url.hostname)
|
313
314
|
modified_request = request_data.copy()
|
314
315
|
modified_request.update(
|
315
316
|
{
|