sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +6 -0
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +7 -7
- sglang/srt/disaggregation/decode.py +8 -3
- sglang/srt/disaggregation/mooncake/conn.py +43 -25
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/distributed/parallel_state.py +4 -2
- sglang/srt/entrypoints/context.py +3 -20
- sglang/srt/entrypoints/engine.py +13 -8
- sglang/srt/entrypoints/harmony_utils.py +2 -0
- sglang/srt/entrypoints/http_server.py +4 -5
- sglang/srt/entrypoints/openai/protocol.py +0 -9
- sglang/srt/entrypoints/openai/serving_chat.py +59 -265
- sglang/srt/entrypoints/openai/tool_server.py +4 -3
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/jinja_template_utils.py +6 -0
- sglang/srt/layers/attention/aiter_backend.py +370 -107
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +9 -1
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +8 -10
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/moe/cutlass_moe.py +11 -16
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +60 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +4 -1
- sglang/srt/layers/quantization/__init__.py +5 -3
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +22 -10
- sglang/srt/layers/quantization/modelopt_quant.py +6 -11
- sglang/srt/layers/quantization/mxfp4.py +4 -1
- sglang/srt/layers/quantization/w4afp8.py +20 -11
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +281 -2
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +12 -48
- sglang/srt/lora/lora_registry.py +20 -9
- sglang/srt/lora/mem_pool.py +20 -63
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +21 -29
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +6 -6
- sglang/srt/managers/mm_utils.py +1 -2
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +35 -20
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +15 -7
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/tokenizer_manager.py +25 -26
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +34 -24
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +33 -35
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +22 -3
- sglang/srt/model_executor/forward_batch_info.py +26 -5
- sglang/srt/model_executor/model_runner.py +129 -35
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/models/deepseek_v2.py +74 -35
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +8 -9
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +9 -9
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +136 -19
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/reasoning_parser.py +316 -0
- sglang/srt/server_args.py +115 -139
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +12 -4
- sglang/srt/utils.py +3 -3
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +26 -30
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +127 -115
- sglang/lang/backend/__init__.py +0 -0
- sglang/srt/function_call/harmony_tool_parser.py +0 -130
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/__init__.py
CHANGED
@@ -1,7 +1,8 @@
 # SGLang public APIs
 
 # Frontend Language APIs
-from sglang.api import (
+from sglang.global_config import global_config
+from sglang.lang.api import (
     Engine,
     Runtime,
     assistant,
@@ -25,22 +26,26 @@ from sglang.api import (
     user_end,
     video,
 )
-from sglang.global_config import global_config
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.lang.choices import (
     greedy_token_selection,
     token_length_normalized,
     unconditional_likelihood_normalized,
 )
+
+# Lazy import some libraries
 from sglang.utils import LazyImport
 from sglang.version import __version__
 
-ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
 Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
 LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
 OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
 VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")
 
+# Runtime Engine APIs
+ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
+Engine = LazyImport("sglang.srt.entrypoints.engine", "Engine")
+
 __all__ = [
     "Engine",
     "Runtime",
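Note: with this change, sglang.Engine is provided through LazyImport and resolves to the runtime engine in sglang.srt.entrypoints.engine. A minimal usage sketch (the model path is a placeholder, not part of this diff):

    import sglang as sgl

    # Resolved through LazyImport only on first access.
    llm = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")
    out = llm.generate("The capital of France is", {"temperature": 0.0, "max_new_tokens": 8})
    print(out["text"])
    llm.shutdown()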
sglang/bench_one_batch.py
CHANGED
@@ -61,6 +61,7 @@ from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.distributed.parallel_state import destroy_distributed_environment
 from sglang.srt.entrypoints.engine import _set_envs_and_config
 from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
 from sglang.srt.managers.scheduler import Scheduler
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -300,6 +301,11 @@ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
         disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
         spec_algorithm=SpeculativeAlgorithm.NONE,
         speculative_num_draft_tokens=None,
+        enable_two_batch_overlap=model_runner.server_args.enable_two_batch_overlap,
+        enable_deepep_moe=MoeA2ABackend(
+            model_runner.server_args.moe_a2a_backend
+        ).is_deepep(),
+        deepep_mode=DeepEPMode(model_runner.server_args.deepep_mode),
         require_mlp_tp_gather=require_mlp_tp_gather(model_runner.server_args),
         disable_overlap_schedule=model_runner.server_args.disable_overlap_schedule,
     )
sglang/lang/chat_template.py
CHANGED
@@ -505,6 +505,22 @@ register_chat_template(
     )
 )
 
+# Reference: https://huggingface.co/docs/transformers/main/model_doc/glm4_v#usage-example
+register_chat_template(
+    ChatTemplate(
+        name="glm-4v",
+        default_system_prompt=None,
+        role_prefix_and_suffix={
+            "system": ("<|system|>\n", "\n"),
+            "user": ("<|user|>\n", "\n"),
+            "assistant": ("<|assistant|>\n", "\n"),
+        },
+        style=ChatTemplateStyle.PLAIN,
+        stop_str=["<|user|>", "<|endoftext|>", "<|observation|>"],
+        image_token="<|image|>",
+    )
+)
+
 
 @register_chat_template_matching_function
 def match_deepseek(model_path: str):
@@ -562,6 +578,8 @@ def match_chat_ml(model_path: str):
         return "chatml"
     if re.search(r"qwen.*vl", model_path, re.IGNORECASE):
         return "qwen2-vl"
+    if re.search(r"glm[-_]?4(\.\d+)?v", model_path, re.IGNORECASE):
+        return "glm-4v"
     if re.search(r"qwen.*(chat|instruct)", model_path, re.IGNORECASE) and not re.search(
         r"llava", model_path, re.IGNORECASE
     ):
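Note: model paths matching the new regex (for example GLM-4.1V or GLM-4.5V checkpoints) now map to the "glm-4v" template registered above. A small lookup sketch, assuming this module's existing get_chat_template_by_model_path helper and an illustrative model path:

    from sglang.lang.chat_template import get_chat_template_by_model_path

    tmpl = get_chat_template_by_model_path("zai-org/GLM-4.1V-9B-Thinking")
    print(tmpl.name)         # "glm-4v"
    print(tmpl.image_token)  # "<|image|>"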
sglang/srt/bench_utils.py
ADDED
@@ -0,0 +1,137 @@
+import os
+import sys
+from contextlib import nullcontext
+
+import torch
+
+
+# NOTE copied and modified from DeepGEMM
+class suppress_stdout_stderr:
+    def __enter__(self):
+        self.outnull_file = open(os.devnull, "w")
+        self.errnull_file = open(os.devnull, "w")
+
+        self.old_stdout_fileno_undup = sys.stdout.fileno()
+        self.old_stderr_fileno_undup = sys.stderr.fileno()
+
+        self.old_stdout_fileno = os.dup(sys.stdout.fileno())
+        self.old_stderr_fileno = os.dup(sys.stderr.fileno())
+
+        self.old_stdout = sys.stdout
+        self.old_stderr = sys.stderr
+
+        os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup)
+        os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup)
+
+        sys.stdout = self.outnull_file
+        sys.stderr = self.errnull_file
+        return self
+
+    def __exit__(self, *_):
+        sys.stdout = self.old_stdout
+        sys.stderr = self.old_stderr
+
+        os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup)
+        os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup)
+
+        os.close(self.old_stdout_fileno)
+        os.close(self.old_stderr_fileno)
+
+        self.outnull_file.close()
+        self.errnull_file.close()
+
+
+# NOTE copied and modified from DeepGEMM
+def bench_kineto(
+    fn,
+    kernel_names,
+    num_tests: int = 30,
+    suppress_kineto_output: bool = False,
+    trace_path: str = None,
+    flush_l2: bool = True,
+    with_multiple_kernels: bool = False,
+):
+    # Conflict with Nsight Systems
+    using_nsys = int(os.environ.get("SGLANG_NSYS_PROFILING", 0))
+
+    # By default, flush L2 with an excessive 8GB memset to give the GPU some (literal) chill time without full idle
+    flush_l2_size = int(8e9 // 4)
+
+    # For some auto-tuning kernels with prints
+    fn()
+
+    # Profile
+    suppress = (
+        suppress_stdout_stderr
+        if suppress_kineto_output and not using_nsys
+        else nullcontext
+    )
+    with suppress():
+        schedule = (
+            torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1)
+            if not using_nsys
+            else None
+        )
+        profiler = (
+            torch.profiler.profile(
+                activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule
+            )
+            if not using_nsys
+            else nullcontext()
+        )
+        with profiler:
+            for i in range(2):
+                for _ in range(num_tests):
+                    if flush_l2:
+                        torch.empty(
+                            flush_l2_size, dtype=torch.int, device="cuda"
+                        ).zero_()
+                    fn()
+
+                if not using_nsys:
+                    profiler.step()
+
+    # Return 1 if using Nsight Systems
+    if using_nsys:
+        return 1
+
+    # Parse the profiling table
+    assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
+    is_tuple = isinstance(kernel_names, tuple)
+    prof_lines = (
+        profiler.key_averages()
+        .table(sort_by="cuda_time_total", max_name_column_width=100)
+        .split("\n")
+    )
+    kernel_names = (kernel_names,) if isinstance(kernel_names, str) else kernel_names
+    assert all([isinstance(name, str) for name in kernel_names])
+    if not with_multiple_kernels:
+        for name in kernel_names:
+            assert (
+                sum([name in line for line in prof_lines]) == 1
+            ), f"Errors of the kernel {name} in the profiling table (table: {prof_lines})"
+
+    # Save chrome traces
+    if trace_path is not None:
+        profiler.export_chrome_trace(trace_path)
+
+    # Return average kernel times
+    units = {"ms": 1e3, "us": 1e6}
+    kernel_times = []
+    for name in kernel_names:
+        total_time = 0
+        total_num = 0
+        for line in prof_lines:
+            if name in line:
+                time_str = line.split()[-2]
+                num_str = line.split()[-1]
+                for unit, scale in units.items():
+                    if unit in time_str:
+                        total_time += (
+                            float(time_str.replace(unit, "")) / scale * int(num_str)
+                        )
+                        total_num += int(num_str)
+                        break
+        kernel_times.append(total_time / total_num)
+
+    return tuple(kernel_times) if is_tuple else kernel_times[0]
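Note: a hypothetical usage sketch of the new bench_kineto helper (requires a CUDA device; the kernel-name substring is illustrative and must match exactly one row of the profiler table unless with_multiple_kernels=True):

    import torch
    from sglang.srt.bench_utils import bench_kineto

    a = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
    b = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)

    # Average time in seconds of the matched kernel over num_tests runs,
    # with an 8 GB memset between runs to flush the L2 cache.
    t = bench_kineto(lambda: a @ b, kernel_names="gemm", num_tests=10)
    print(f"{t * 1e6:.1f} us")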
sglang/srt/configs/model_config.py
CHANGED
@@ -64,13 +64,12 @@ class ModelConfig:
         hybrid_kvcache_ratio: Optional[float] = None,
         model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
     ) -> None:
-
+        # Parse args
         self.model_path = model_path
         self.revision = revision
         self.quantization = quantization
         self.model_impl = model_impl
 
-        # Parse args
         self.maybe_pull_model_tokenizer_from_remote()
         self.model_override_args = json.loads(model_override_args)
         kwargs = {}
@@ -139,6 +138,7 @@ class ModelConfig:
             and self.hf_config.architectures[0] == "Ernie4_5_MoeForCausalLM"
         ):
             self.hf_config.architectures[0] = "Ernie4_5_MoeForCausalLMMTP"
+
         # Check model type
         self.is_generation = is_generation_model(
             self.hf_config.architectures, is_embedding
@@ -282,12 +282,10 @@ class ModelConfig:
         # Cache attributes
         self.hf_eos_token_id = self.get_hf_eos_token_id()
 
-        config = self.hf_config
-
         # multimodal
-        self.image_token_id = getattr(
-
-        )
+        self.image_token_id = getattr(
+            self.hf_config, "image_token_id", None
+        ) or getattr(self.hf_config, "image_token_index", None)
 
     @staticmethod
     def from_server_args(server_args: ServerArgs, model_path: str = None, **kwargs):
@@ -661,6 +659,8 @@ multimodal_model_archs = [
     "DeepseekVL2ForCausalLM",
     "Gemma3ForConditionalGeneration",
     "Gemma3nForConditionalGeneration",
+    "Glm4vForConditionalGeneration",
+    "Glm4vMoeForConditionalGeneration",
     "Grok1VForCausalLM",
     "Grok1AForCausalLM",
     "LlavaLlamaForCausalLM",
sglang/srt/disaggregation/decode.py
CHANGED
@@ -51,7 +51,7 @@ from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
-from sglang.srt.utils import require_mlp_sync
+from sglang.srt.utils import get_int_env_var, require_mlp_sync
 
 logger = logging.getLogger(__name__)
 
@@ -59,6 +59,8 @@ if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
     from sglang.srt.managers.scheduler import Scheduler
 
+CLIP_MAX_NEW_TOKEN = get_int_env_var("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", 4096)
+
 
 class DecodeReqToTokenPool:
     """
@@ -384,7 +386,10 @@ class DecodePreallocQueue:
             max(
                 required_tokens_for_request,
                 origin_input_len
-                +
+                + min(
+                    decode_req.req.sampling_params.max_new_tokens,
+                    CLIP_MAX_NEW_TOKEN,
+                )
                 - retractable_tokens,
             )
             > allocatable_tokens
@@ -433,7 +438,7 @@ class DecodePreallocQueue:
         need_space_for_single_req = (
             max(
                 [
-                    x.sampling_params.max_new_tokens
+                    min(x.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKEN)
                     + len(x.origin_input_ids)
                     - retractable_tokens
                     for x in self.scheduler.running_batch.reqs
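Note: the effect of the new clip on pre-allocation estimates, as a small worked example (numbers are illustrative; SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION defaults to 4096 as shown above):

    # A request with a 1,000-token prompt and max_new_tokens=128_000 was previously
    # estimated to need 1_000 + 128_000 tokens; it is now estimated as
    # 1_000 + min(128_000, 4_096) = 5_096 tokens, so very large max_new_tokens
    # values no longer dominate the decode preallocation budget.
    origin_input_len = 1_000
    max_new_tokens = 128_000
    CLIP_MAX_NEW_TOKEN = 4_096
    assert origin_input_len + min(max_new_tokens, CLIP_MAX_NEW_TOKEN) == 5_096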
sglang/srt/disaggregation/mooncake/conn.py
CHANGED
@@ -257,15 +257,17 @@ class MooncakeKVManager(BaseKVManager):
         )
 
     def register_buffer_to_engine(self):
-
-
-
-
+        # Batch register KV data buffers
+        if self.kv_args.kv_data_ptrs and self.kv_args.kv_data_lens:
+            self.engine.batch_register(
+                self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
+            )
 
-
-
-
-
+        # Batch register auxiliary data buffers
+        if self.kv_args.aux_data_ptrs and self.kv_args.aux_data_lens:
+            self.engine.batch_register(
+                self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
+            )
 
     @cache
     def _connect(self, endpoint: str, is_ipv6: bool = False):
@@ -356,33 +358,49 @@ class MooncakeKVManager(BaseKVManager):
         ]
         assert layers_params is not None
 
-
-
+        def set_transfer_blocks(
+            src_ptr: int, dst_ptr: int, item_len: int
+        ) -> List[Tuple[int, int, int]]:
             transfer_blocks = []
             for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
                 src_addr = src_ptr + int(prefill_index[0]) * item_len
                 dst_addr = dst_ptr + int(decode_index[0]) * item_len
                 length = item_len * len(prefill_index)
                 transfer_blocks.append((src_addr, dst_addr, length))
+            return transfer_blocks
 
+        # Worker function for processing a single layer
+        def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
+            transfer_blocks = set_transfer_blocks(src_ptr, dst_ptr, item_len)
             return self._transfer_data(mooncake_session_id, transfer_blocks)
 
-
-
-
-
-                    dst_ptr,
-
-                )
-                for (src_ptr, dst_ptr, item_len) in layers_params
-            ]
+        # Worker function for processing all layers in a batch
+        def process_layers(layers_params: List[Tuple[int, int, int]]) -> int:
+            transfer_blocks = []
+            for src_ptr, dst_ptr, item_len in layers_params:
+                transfer_blocks.extend(set_transfer_blocks(src_ptr, dst_ptr, item_len))
+            return self._transfer_data(mooncake_session_id, transfer_blocks)
 
-
-
-
-
-
-
+        if self.enable_custom_mem_pool:
+            futures = [
+                executor.submit(
+                    process_layer,
+                    src_ptr,
+                    dst_ptr,
+                    item_len,
+                )
+                for (src_ptr, dst_ptr, item_len) in layers_params
+            ]
+            for future in concurrent.futures.as_completed(futures):
+                status = future.result()
+                if status != 0:
+                    for f in futures:
+                        f.cancel()
+                    return status
+        else:
+            # Combining all layers' params in one batch transfer is more efficient
+            # compared to using multiple threads
+            return process_layers(layers_params)
 
         return 0
 
sglang/srt/disaggregation/mooncake/transfer_engine.py
CHANGED
@@ -51,6 +51,35 @@ class MooncakeTransferEngine:
         if ret_value != 0:
             logger.debug("Mooncake memory deregistration %s failed.", ptr)
 
+    def batch_register(self, ptrs: List[int], lengths: List[int]) -> int:
+        """Batch register multiple memory regions."""
+        try:
+            ret_value = self.engine.batch_register_memory(ptrs, lengths)
+        except Exception:
+            # Mark batch register as failed
+            ret_value = -1
+            if not hasattr(self.engine, "batch_register_memory"):
+                raise RuntimeError(
+                    "Mooncake's batch register requires a newer version of mooncake-transfer-engine. "
+                    "Please upgrade Mooncake."
+                )
+
+        if ret_value != 0:
+            logger.debug("Mooncake batch memory registration failed.")
+        return ret_value
+
+    def batch_deregister(self, ptrs: List[int]) -> int:
+        """Batch deregister multiple memory regions."""
+        try:
+            ret_value = self.engine.batch_unregister_memory(ptrs)
+        except Exception:
+            # Mark batch deregister as failed
+            ret_value = -1
+
+        if ret_value != 0:
+            logger.debug("Mooncake batch memory deregistration failed.")
+        return ret_value
+
     def initialize(
         self,
         hostname: str,
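Note: a hedged sketch of the new batch registration API (buffer setup is illustrative; this mirrors how MooncakeKVManager.register_buffer_to_engine above now registers all KV pointers in one call):

    # engine: an initialized MooncakeTransferEngine
    # kv_buffers: an assumed list of contiguous CUDA tensors backing the KV cache
    ptrs = [t.data_ptr() for t in kv_buffers]
    lens = [t.numel() * t.element_size() for t in kv_buffers]

    if engine.batch_register(ptrs, lens) != 0:
        raise RuntimeError("batch memory registration failed")
    # ... perform transfers ...
    engine.batch_deregister(ptrs)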
sglang/srt/distributed/parallel_state.py
CHANGED
@@ -50,6 +50,8 @@ from sglang.srt.utils import (
     supports_custom_op,
 )
 
+_is_npu = is_npu()
+
 
 @dataclass
 class GraphCaptureContext:
@@ -591,7 +593,7 @@ class GroupCoordinator:
         )
 
     def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
-        if not supports_custom_op():
+        if _is_npu or not supports_custom_op():
             self._all_gather_into_tensor(output, input)
         else:
             torch.ops.sglang.reg_all_gather_into_tensor(
@@ -1127,7 +1129,7 @@ def init_model_parallel_group(
         group_ranks=group_ranks,
         local_rank=local_rank,
         torch_distributed_backend=backend,
-        use_pynccl=not
+        use_pynccl=not _is_npu,
         use_pymscclpp=use_mscclpp_allreduce,
         use_custom_allreduce=use_custom_allreduce,
         use_hpu_communicator=True,
sglang/srt/entrypoints/context.py
CHANGED
@@ -1,5 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
-# Copied from vLLM
+# Copied from vLLM: https://github.com/zyongye/vllm/blob/6a70830065701b163e36a86fd331b41b5feac401/vllm/entrypoints/context.py
 import json
 import logging
 from abc import ABC, abstractmethod
@@ -9,8 +9,8 @@ logger = logging.getLogger(__name__)
 
 try:
     from mcp import ClientSession
-except ImportError:
-
+except ImportError as e:
+    mcp = e
 
 from openai_harmony import Author, Message, Role, StreamState, TextContent
 
@@ -83,14 +83,6 @@ class HarmonyContext(ConversationContext):
         if isinstance(output, dict) and "output_ids" in output:
             output_token_ids = output["output_ids"]
 
-            # TODO: REMOVE here:
-            # Very hacky, find the first occurrence of token 200006 and cut from there
-            try:
-                start_index = output_token_ids.index(200006)
-                output_token_ids = output_token_ids[start_index:]
-            except ValueError:
-                pass
-
         for token_id in output_token_ids:
             self.parser.process(token_id)
         output_msgs = self.parser.messages
@@ -196,15 +188,6 @@ class StreamingHarmonyContext(HarmonyContext):
             # RequestOutput from SGLang with outputs
             output_token_ids = output["output_ids"]
 
-            # TODO: REMOVE here:
-            # Very hacky, find the first occurrence of token 200006 and cut from there
-            # Find the first occurrence of token 200006 and cut from there
-            try:
-                start_index = output_token_ids.index(200006)
-                output_token_ids = output_token_ids[start_index:]
-            except ValueError:
-                pass
-
         for token_id in output_token_ids:
             self.parser.process(token_id)
 
sglang/srt/entrypoints/engine.py
CHANGED
@@ -67,6 +67,7 @@ from sglang.srt.utils import (
     MultiprocessingSerializer,
     assert_pkg_version,
     configure_logger,
+    get_bool_env_var,
     get_zmq_socket,
     is_cuda,
     kill_process_tree,
@@ -259,7 +260,7 @@ class Engine(EngineBase):
                 f"data_parallel_rank must be in range [0, {self.server_args.dp_size-1}]"
             )
 
-        logger.
+        logger.debug(f"data_parallel_rank: {data_parallel_rank}")
         obj = GenerateReqInput(
             text=prompt,
             input_ids=input_ids,
@@ -450,15 +451,20 @@ class Engine(EngineBase):
     ):
         """Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be false
         to avoid duplicated cache cleaning operation."""
-
-        serialized_named_tensors=
+        if load_format == "flattened_bucket":
+            serialized_named_tensors = named_tensors
+        else:
+            serialized_named_tensors = [
                 MultiprocessingSerializer.serialize(named_tensors)
                 for _ in range(self.server_args.tp_size)
-            ]
+            ]
+        obj = UpdateWeightsFromTensorReqInput(
+            serialized_named_tensors=serialized_named_tensors,
             load_format=load_format,
             flush_cache=flush_cache,
         )
         loop = asyncio.get_event_loop()
+
        return loop.run_until_complete(
            self.tokenizer_manager.update_weights_from_tensor(obj, None)
        )
@@ -627,7 +633,6 @@ def _set_envs_and_config(server_args: ServerArgs):
     os.environ["NCCL_CUMEM_ENABLE"] = str(int(server_args.enable_symm_mem))
     if not server_args.enable_symm_mem:
         os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
-    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
     os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
     os.environ["CUDA_MODULE_LOADING"] = "AUTO"
 
@@ -642,15 +647,15 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.2.
+            "0.2.11.post1",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
         )
-    if _is_cuda:
+    if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
         assert_pkg_version(
             "sgl-kernel",
-            "0.3.
+            "0.3.4",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
 
sglang/srt/entrypoints/harmony_utils.py
CHANGED
@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# Adapted from vLLM: https://github.com/vllm-project/vllm/blob/1b9902806915040ac9b3029f2ab7522ec505afc3/vllm/entrypoints/harmony_utils.py
+# Slight differences in processing chat messages
 import datetime
 import json
 from collections.abc import Iterable
sglang/srt/entrypoints/http_server.py
CHANGED
@@ -26,7 +26,7 @@ import os
 import threading
 import time
 from http import HTTPStatus
-from typing import AsyncIterator, Callable, Dict, Optional
+from typing import Any, AsyncIterator, Callable, Dict, List, Optional
 
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -174,7 +174,6 @@ async def lifespan(fast_api_app: FastAPI):
             tool_server=tool_server,
         )
     except Exception as e:
-        # print stack trace
         import traceback
 
         traceback.print_exc()
@@ -277,7 +276,7 @@ async def health_generate(request: Request) -> Response:
         logger.info("Health check request received during shutdown. Returning 503.")
         return Response(status_code=503)
 
-    if
+    if _global_state.tokenizer_manager.server_status == ServerStatus.Starting:
         return Response(status_code=503)
 
     sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
@@ -317,7 +316,7 @@ async def health_generate(request: Request) -> Response:
         if _global_state.tokenizer_manager.last_receive_tstamp > tic:
             task.cancel()
             _global_state.tokenizer_manager.rid_to_state.pop(rid, None)
-            _global_state.tokenizer_manager.
+            _global_state.tokenizer_manager.server_status = ServerStatus.Up
             return Response(status_code=200)
 
         task.cancel()
@@ -331,7 +330,7 @@ async def health_generate(request: Request) -> Response:
         f"last_heartbeat time: {last_receive_time}"
     )
     _global_state.tokenizer_manager.rid_to_state.pop(rid, None)
-    _global_state.tokenizer_manager.
+    _global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy
    return Response(status_code=503)
 
 
sglang/srt/entrypoints/openai/protocol.py
CHANGED
@@ -859,15 +859,6 @@ class ResponseReasoningTextContent(BaseModel):
     type: Literal["reasoning_text"] = "reasoning_text"
 
 
-class ResponseReasoningItem(BaseModel):
-    id: str
-    content: list[ResponseReasoningTextContent] = Field(default_factory=list)
-    summary: list = Field(default_factory=list)
-    type: Literal["reasoning"] = "reasoning"
-    encrypted_content: Optional[str] = None
-    status: Optional[Literal["in_progress", "completed", "incomplete"]]
-
-
 ResponseInputOutputItem: TypeAlias = Union[
     ResponseInputItemParam, "ResponseReasoningItem", ResponseFunctionToolCall
 ]