sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +119 -17
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +42 -7
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +14 -4
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +286 -160
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +15 -11
- sglang/srt/entrypoints/context.py +227 -0
- sglang/srt/entrypoints/engine.py +15 -9
- sglang/srt/entrypoints/harmony_utils.py +372 -0
- sglang/srt/entrypoints/http_server.py +74 -4
- sglang/srt/entrypoints/openai/protocol.py +218 -1
- sglang/srt/entrypoints/openai/serving_chat.py +41 -11
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +175 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +14 -1
- sglang/srt/layers/attention/aiter_backend.py +375 -115
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +22 -6
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +29 -14
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +3 -7
- sglang/srt/layers/moe/cutlass_moe.py +12 -3
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +135 -73
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +16 -4
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +27 -3
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +51 -10
- sglang/srt/layers/quantization/modelopt_quant.py +258 -68
- sglang/srt/layers/quantization/mxfp4.py +654 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +21 -12
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +506 -3
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +82 -62
- sglang/srt/lora/lora_registry.py +23 -11
- sglang/srt/lora/mem_pool.py +63 -68
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +75 -58
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -8
- sglang/srt/managers/mm_utils.py +6 -13
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +61 -25
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +41 -19
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +47 -30
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +80 -22
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +34 -36
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -9
- sglang/srt/model_executor/forward_batch_info.py +61 -19
- sglang/srt/model_executor/model_runner.py +148 -37
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +137 -59
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +38 -0
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +28 -16
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +1251 -0
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +332 -37
- sglang/srt/server_args.py +186 -75
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +169 -9
- sglang/srt/utils.py +41 -5
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/runners.py +2 -2
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/mooncake/transfer_engine.py
CHANGED
@@ -51,6 +51,35 @@ class MooncakeTransferEngine:
         if ret_value != 0:
             logger.debug("Mooncake memory deregistration %s failed.", ptr)
 
+    def batch_register(self, ptrs: List[int], lengths: List[int]) -> int:
+        """Batch register multiple memory regions."""
+        try:
+            ret_value = self.engine.batch_register_memory(ptrs, lengths)
+        except Exception:
+            # Mark batch register as failed
+            ret_value = -1
+            if not hasattr(self.engine, "batch_register_memory"):
+                raise RuntimeError(
+                    "Mooncake's batch register requires a newer version of mooncake-transfer-engine. "
+                    "Please upgrade Mooncake."
+                )
+
+        if ret_value != 0:
+            logger.debug("Mooncake batch memory registration failed.")
+        return ret_value
+
+    def batch_deregister(self, ptrs: List[int]) -> int:
+        """Batch deregister multiple memory regions."""
+        try:
+            ret_value = self.engine.batch_unregister_memory(ptrs)
+        except Exception:
+            # Mark batch deregister as failed
+            ret_value = -1
+
+        if ret_value != 0:
+            logger.debug("Mooncake batch memory deregistration failed.")
+        return ret_value
+
     def initialize(
         self,
         hostname: str,
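With these additions, callers that previously registered each KV buffer pointer one at a time can hand the whole list to the engine in a single call. Below is a minimal caller-side sketch, not part of the diff: it assumes an initialized MooncakeTransferEngine and falls back to the class's pre-existing per-region register(ptr, length) helper (the fallback helper is an assumption based on the surrounding class, not shown in this hunk).

# Illustrative only, not part of the diff.
def register_kv_buffers(engine, kv_data_ptrs, kv_data_lens):
    try:
        # One engine call for all regions (new in this release).
        return engine.batch_register(kv_data_ptrs, kv_data_lens)
    except RuntimeError:
        # Older mooncake-transfer-engine without batch_register_memory:
        # fall back to per-region registration.
        for ptr, length in zip(kv_data_ptrs, kv_data_lens):
            engine.register(ptr, length)
        return 0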
sglang/srt/disaggregation/prefill.py
CHANGED
@@ -103,6 +103,8 @@ class PrefillBootstrapQueue:
         kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
         kv_args = kv_args_class()
         kv_args.engine_rank = self.tp_rank
+        kv_args.pp_rank = self.pp_rank
+        kv_args.system_dp_rank = self.scheduler.dp_rank
         kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size
         kv_args.prefill_pp_size = self.pp_size
         kv_data_ptrs, kv_data_lens, kv_item_lens = (
sglang/srt/distributed/parallel_state.py
CHANGED
@@ -50,6 +50,8 @@ from sglang.srt.utils import (
     supports_custom_op,
 )
 
+_is_npu = is_npu()
+
 
 @dataclass
 class GraphCaptureContext:
@@ -591,7 +593,7 @@ class GroupCoordinator:
         )
 
     def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
-        if not supports_custom_op():
+        if _is_npu or not supports_custom_op():
             self._all_gather_into_tensor(output, input)
         else:
             torch.ops.sglang.reg_all_gather_into_tensor(
@@ -650,17 +652,19 @@ class GroupCoordinator:
             output_size, dtype=input_.dtype, device=input_.device
         )
 
+        # All-gather.
+        if input_.is_cpu and is_shm_available(
+            input_.dtype, self.world_size, self.local_size
+        ):
+            return torch.ops.sgl_kernel.shm_allgather(input_, dim)
+
         if input_.is_cpu:
-            if is_shm_available(input_.dtype, self.world_size, self.local_size):
-                return torch.ops.sgl_kernel.shm_allgather(input_, dim)
-            else:
-                torch.distributed.all_gather_into_tensor(
-                    output_tensor, input_, group=self.device_group
-                )
-            return output_tensor
+            torch.distributed.all_gather_into_tensor(
+                output_tensor, input_, group=self.device_group
+            )
+        else:
+            self.all_gather_into_tensor(output_tensor, input_)
 
-        # All-gather.
-        self.all_gather_into_tensor(output_tensor, input_)
         # Reshape
         output_tensor = output_tensor.reshape((world_size,) + input_size)
         output_tensor = output_tensor.movedim(0, dim)
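The rewritten method makes the dispatch order explicit: a CPU tensor takes the fused shared-memory all-gather when available, a plain CPU tensor falls back to torch.distributed on the CPU group, and device tensors keep using the custom-op wrapper. A simplified standalone restatement of that decision order follows; it is a sketch, not the full method, and it assumes is_shm_available comes from sglang.srt.utils as used in the hunk above.

# Illustrative only, not part of the diff.
import torch
from sglang.srt.utils import is_shm_available  # assumed import location

def all_gather_dispatch(group, input_, output_tensor, dim):
    # 1) CPU fast path: fused shm all-gather from sgl-kernel.
    if input_.is_cpu and is_shm_available(
        input_.dtype, group.world_size, group.local_size
    ):
        return torch.ops.sgl_kernel.shm_allgather(input_, dim)
    # 2) CPU without shm support: plain torch.distributed all-gather.
    if input_.is_cpu:
        torch.distributed.all_gather_into_tensor(
            output_tensor, input_, group=group.device_group
        )
    # 3) Device tensor: custom-op wrapper (skipped on NPU, see the hunk above).
    else:
        group.all_gather_into_tensor(output_tensor, input_)
    return output_tensor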
@@ -1125,7 +1129,7 @@ def init_model_parallel_group(
         group_ranks=group_ranks,
         local_rank=local_rank,
         torch_distributed_backend=backend,
-        use_pynccl=not is_npu(),
+        use_pynccl=not _is_npu,
        use_pymscclpp=use_mscclpp_allreduce,
        use_custom_allreduce=use_custom_allreduce,
        use_hpu_communicator=True,
sglang/srt/entrypoints/context.py
ADDED
@@ -0,0 +1,227 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copied from vLLM: https://github.com/zyongye/vllm/blob/6a70830065701b163e36a86fd331b41b5feac401/vllm/entrypoints/context.py
+import json
+import logging
+from abc import ABC, abstractmethod
+from typing import Union
+
+logger = logging.getLogger(__name__)
+
+try:
+    from mcp import ClientSession
+except ImportError as e:
+    mcp = e
+
+from openai_harmony import Author, Message, Role, StreamState, TextContent
+
+from sglang.srt.entrypoints.harmony_utils import (
+    get_encoding,
+    get_streamable_parser_for_assistant,
+    render_for_completion,
+)
+from sglang.srt.entrypoints.tool import Tool
+
+
+class ConversationContext(ABC):
+
+    @abstractmethod
+    def append_output(self, output) -> None:
+        pass
+
+    @abstractmethod
+    async def call_tool(self) -> list[Message]:
+        pass
+
+    @abstractmethod
+    def need_builtin_tool_call(self) -> bool:
+        pass
+
+    @abstractmethod
+    def render_for_completion(self) -> list[int]:
+        pass
+
+
+class SimpleContext(ConversationContext):
+
+    def __init__(self):
+        self.last_output = None
+
+    def append_output(self, output) -> None:
+        self.last_output = output
+
+    def need_builtin_tool_call(self) -> bool:
+        return False
+
+    async def call_tool(self) -> list[Message]:
+        raise NotImplementedError("Should not be called.")
+
+    def render_for_completion(self) -> list[int]:
+        raise NotImplementedError("Should not be called.")
+
+
+class HarmonyContext(ConversationContext):
+
+    def __init__(
+        self,
+        messages: list,
+        tool_sessions: dict[str, Union["ClientSession", Tool]],
+    ):
+        # TODO: Remove the hack of Union[ClientSession, Tool] by using MCP
+        # when demo.
+        self._messages = messages
+        self.tool_sessions = tool_sessions
+
+        self.parser = get_streamable_parser_for_assistant()
+        self.num_init_messages = len(messages)
+        # TODO
+        self.num_prompt_tokens = 0
+        self.num_cached_tokens = 0
+        self.num_output_tokens = 0
+        self.num_reasoning_tokens = 0
+
+    def append_output(self, output) -> None:
+        if isinstance(output, dict) and "output_ids" in output:
+            output_token_ids = output["output_ids"]
+
+            for token_id in output_token_ids:
+                self.parser.process(token_id)
+            output_msgs = self.parser.messages
+
+            meta_info = output["meta_info"]
+
+            if isinstance(meta_info, dict):
+                if "prompt_token_ids" in meta_info:
+                    self.num_prompt_tokens = meta_info["prompt_tokens"]
+                if "cached_tokens" in meta_info:
+                    self.num_cached_tokens = meta_info["cached_tokens"]
+                if "completion_tokens" in meta_info:
+                    self.num_output_tokens += meta_info["completion_tokens"]
+
+        else:
+            output_msgs = output
+
+        self._messages.extend(output_msgs)
+
+    @property
+    def messages(self) -> list:
+        return self._messages
+
+    def need_builtin_tool_call(self) -> bool:
+        last_msg = self.messages[-1]
+        recipient = last_msg.recipient
+        return recipient is not None and (
+            recipient.startswith("browser.") or recipient.startswith("python")
+        )
+
+    async def call_tool(self) -> list[Message]:
+        if not self.messages:
+            return []
+        last_msg = self.messages[-1]
+        recipient = last_msg.recipient
+        if recipient is not None:
+            if recipient.startswith("browser."):
+                return await self.call_search_tool(
+                    self.tool_sessions["browser"], last_msg
+                )
+            elif recipient.startswith("python"):
+                return await self.call_python_tool(
+                    self.tool_sessions["python"], last_msg
+                )
+        raise ValueError("No tool call found")
+
+    def render_for_completion(self) -> list[int]:
+        return render_for_completion(self.messages)
+
+    async def call_search_tool(
+        self, tool_session: Union["ClientSession", Tool], last_msg: Message
+    ) -> list[Message]:
+        if isinstance(tool_session, Tool):
+            return await tool_session.get_result(self)
+        tool_name = last_msg.recipient.split(".")[1]
+        args = json.loads(last_msg.content[0].text)
+        result = await tool_session.call_tool(tool_name, args)
+        result_str = result.content[0].text
+        content = TextContent(text=result_str)
+        author = Author(role=Role.TOOL, name=last_msg.recipient)
+        return [Message(author=author, content=[content], recipient=Role.ASSISTANT)]
+
+    async def call_python_tool(
+        self, tool_session: Union["ClientSession", Tool], last_msg: Message
+    ) -> list[Message]:
+        if isinstance(tool_session, Tool):
+            return await tool_session.get_result(self)
+        param = {
+            "code": last_msg.content[0].text,
+        }
+        result = await tool_session.call_tool("python", param)
+        result_str = result.content[0].text
+
+        content = TextContent(text=result_str)
+        author = Author(role=Role.TOOL, name="python")
+
+        return [
+            Message(
+                author=author,
+                content=[content],
+                channel=last_msg.channel,
+                recipient=Role.ASSISTANT,
+            )
+        ]
+
+
+class StreamingHarmonyContext(HarmonyContext):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.last_output = None
+
+        self.parser = get_streamable_parser_for_assistant()
+        self.encoding = get_encoding()
+        self.last_tok = None
+
+    @property
+    def messages(self) -> list:
+        return self.parser.messages
+
+    def append_output(self, output) -> None:
+        if isinstance(output, dict) and "output_ids" in output:
+            # RequestOutput from SGLang with outputs
+            output_token_ids = output["output_ids"]
+
+            for token_id in output_token_ids:
+                self.parser.process(token_id)
+
+        else:
+            # Handle the case of tool output in direct message format
+            assert len(output) == 1, "Tool output should be a single message"
+            msg = output[0]
+            # Sometimes the recipient is not set for tool messages,
+            # so we set it to "assistant"
+            if msg.author.role == Role.TOOL and msg.recipient is None:
+                msg.recipient = "assistant"
+            toks = self.encoding.render(msg)
+            for tok in toks:
+                self.parser.process(tok)
+            self.last_tok = toks[-1]
+
+    def is_expecting_start(self) -> bool:
+        return self.parser.state == StreamState.EXPECT_START
+
+    def is_assistant_action_turn(self) -> bool:
+        return self.last_tok in self.encoding.stop_tokens_for_assistant_actions()
+
+    def render_for_completion(self) -> list[int]:
+        # Rendering adds next turn's starting tokens
+        # (e.g. `<|start|>assistant`),
+        # so we need to process them in the parser.
+        rendered_tokens = super().render_for_completion()
+
+        last_n = -1
+        to_process = []
+        while rendered_tokens[last_n] != self.last_tok:
+            to_process.append(rendered_tokens[last_n])
+            last_n -= 1
+        for tok in reversed(to_process):
+            self.parser.process(tok)
+
+        return rendered_tokens
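To see how these pieces fit together, here is a hedged sketch of a serving loop built on HarmonyContext. generate_step is a hypothetical stand-in for the engine call, and the dict it returns follows the {"output_ids": ..., "meta_info": ...} shape that append_output parses above; the real serving path lives in the new serving_responses.py.

# Illustrative only, not part of the diff.
async def run_turns(context: HarmonyContext, generate_step, max_turns: int = 8):
    for _ in range(max_turns):
        token_ids = context.render_for_completion()  # token ids for the next turn
        output = await generate_step(token_ids)      # hypothetical engine call
        context.append_output(output)                # parses output["output_ids"]
        if not context.need_builtin_tool_call():
            break                                    # model produced a final answer
        tool_msgs = await context.call_tool()        # browser.* or python tool
        context.append_output(tool_msgs)             # feed tool result back
    return context.messages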
sglang/srt/entrypoints/engine.py
CHANGED
@@ -67,6 +67,7 @@ from sglang.srt.utils import (
     MultiprocessingSerializer,
     assert_pkg_version,
     configure_logger,
+    get_bool_env_var,
     get_zmq_socket,
     is_cuda,
     kill_process_tree,
@@ -259,7 +260,7 @@ class Engine(EngineBase):
             f"data_parallel_rank must be in range [0, {self.server_args.dp_size-1}]"
         )
 
-        logger.
+        logger.debug(f"data_parallel_rank: {data_parallel_rank}")
         obj = GenerateReqInput(
             text=prompt,
             input_ids=input_ids,
@@ -450,15 +451,20 @@ class Engine(EngineBase):
     ):
         """Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be false
         to avoid duplicated cache cleaning operation."""
-        obj = UpdateWeightsFromTensorReqInput(
-            serialized_named_tensors=[
+        if load_format == "flattened_bucket":
+            serialized_named_tensors = named_tensors
+        else:
+            serialized_named_tensors = [
                 MultiprocessingSerializer.serialize(named_tensors)
                 for _ in range(self.server_args.tp_size)
-            ],
+            ]
+        obj = UpdateWeightsFromTensorReqInput(
+            serialized_named_tensors=serialized_named_tensors,
             load_format=load_format,
             flush_cache=flush_cache,
         )
         loop = asyncio.get_event_loop()
+
         return loop.run_until_complete(
             self.tokenizer_manager.update_weights_from_tensor(obj, None)
        )
@@ -492,12 +498,13 @@ class Engine(EngineBase):
             self.tokenizer_manager.get_weights_by_name(obj, None)
         )
 
-    def load_lora_adapter(self, lora_name: str, lora_path: str):
+    def load_lora_adapter(self, lora_name: str, lora_path: str, pinned: bool = False):
         """Load a new LoRA adapter without re-launching the engine."""
 
         obj = LoadLoRAAdapterReqInput(
             lora_name=lora_name,
             lora_path=lora_path,
+            pinned=pinned,
         )
 
         loop = asyncio.get_event_loop()
@@ -626,7 +633,6 @@ def _set_envs_and_config(server_args: ServerArgs):
     os.environ["NCCL_CUMEM_ENABLE"] = str(int(server_args.enable_symm_mem))
     if not server_args.enable_symm_mem:
         os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
-    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
     os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
     os.environ["CUDA_MODULE_LOADING"] = "AUTO"
 
@@ -641,15 +647,15 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.2.
+            "0.2.11.post1",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
         )
-    if _is_cuda:
+    if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
         assert_pkg_version(
             "sgl-kernel",
-            "0.
+            "0.3.4",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
 
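The version pins were bumped, and the sgl-kernel check can now be bypassed with the new SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK switch (read via get_bool_env_var, so "1" or "true" enables it), which is useful when running a locally built kernel. A minimal sketch:

# Illustrative only, not part of the diff.
import os

# Must be set before the engine performs its startup checks.
os.environ["SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"] = "1"

import sglang as sgl

engine = sgl.Engine(model_path="...")  # placeholder; skips the sgl-kernel pin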