sglang 0.4.10.post1__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +113 -17
- sglang/compile_deep_gemm.py +8 -1
- sglang/global_config.py +5 -1
- sglang/srt/configs/model_config.py +35 -0
- sglang/srt/conversation.py +9 -117
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +6 -1
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -0
- sglang/srt/disaggregation/mooncake/conn.py +243 -135
- sglang/srt/disaggregation/prefill.py +3 -0
- sglang/srt/distributed/device_communicators/pynccl.py +7 -0
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
- sglang/srt/distributed/parallel_state.py +22 -9
- sglang/srt/entrypoints/context.py +244 -0
- sglang/srt/entrypoints/engine.py +8 -5
- sglang/srt/entrypoints/harmony_utils.py +370 -0
- sglang/srt/entrypoints/http_server.py +106 -15
- sglang/srt/entrypoints/openai/protocol.py +227 -1
- sglang/srt/entrypoints/openai/serving_chat.py +278 -42
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +174 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_distribution.py +4 -2
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/harmony_tool_parser.py +130 -0
- sglang/srt/hf_transformers_utils.py +55 -13
- sglang/srt/jinja_template_utils.py +8 -1
- sglang/srt/layers/attention/aiter_backend.py +5 -8
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +7 -11
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
- sglang/srt/layers/attention/vision.py +40 -15
- sglang/srt/layers/communicator.py +35 -8
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/linear.py +9 -8
- sglang/srt/layers/logits_processor.py +9 -1
- sglang/srt/layers/moe/cutlass_moe.py +20 -6
- sglang/srt/layers/moe/ep_moe/layer.py +87 -107
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +442 -58
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +169 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
- sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
- sglang/srt/layers/moe/topk.py +12 -3
- sglang/srt/layers/moe/utils.py +59 -0
- sglang/srt/layers/quantization/__init__.py +22 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +8 -7
- sglang/srt/layers/quantization/fp8_kernel.py +0 -4
- sglang/srt/layers/quantization/fp8_utils.py +29 -0
- sglang/srt/layers/quantization/modelopt_quant.py +259 -64
- sglang/srt/layers/quantization/mxfp4.py +651 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/__init__.py +0 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +225 -1
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +15 -4
- sglang/srt/lora/lora_manager.py +70 -14
- sglang/srt/lora/lora_registry.py +10 -2
- sglang/srt/lora/mem_pool.py +43 -5
- sglang/srt/managers/cache_controller.py +61 -32
- sglang/srt/managers/data_parallel_controller.py +52 -2
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +21 -4
- sglang/srt/managers/mm_utils.py +5 -11
- sglang/srt/managers/schedule_batch.py +30 -8
- sglang/srt/managers/schedule_policy.py +3 -1
- sglang/srt/managers/scheduler.py +170 -18
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +59 -22
- sglang/srt/managers/tokenizer_manager.py +137 -67
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/managers/utils.py +45 -1
- sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
- sglang/srt/mem_cache/hicache_storage.py +13 -21
- sglang/srt/mem_cache/hiradix_cache.py +53 -5
- sglang/srt/mem_cache/memory_pool_host.py +1 -1
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
- sglang/srt/model_executor/cuda_graph_runner.py +24 -9
- sglang/srt/model_executor/forward_batch_info.py +48 -17
- sglang/srt/model_executor/model_runner.py +24 -2
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +95 -50
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma3n_mm.py +39 -0
- sglang/srt/models/glm4_moe.py +102 -27
- sglang/srt/models/gpt_oss.py +1134 -0
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/llama4.py +13 -2
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mllama4.py +428 -19
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +7 -4
- sglang/srt/models/qwen3_moe.py +39 -14
- sglang/srt/models/step3_vl.py +10 -1
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +4 -3
- sglang/srt/multimodal/processors/gemma3n.py +0 -7
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/operations_strategy.py +1 -1
- sglang/srt/reasoning_parser.py +18 -39
- sglang/srt/server_args.py +218 -23
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
- sglang/srt/two_batch_overlap.py +163 -9
- sglang/srt/utils.py +41 -26
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/runners.py +4 -4
- sglang/test/test_utils.py +4 -4
- sglang/version.py +1 -1
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +18 -15
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +143 -116
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/prefill.py
CHANGED
@@ -103,6 +103,8 @@ class PrefillBootstrapQueue:
         kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
         kv_args = kv_args_class()
         kv_args.engine_rank = self.tp_rank
+        kv_args.pp_rank = self.pp_rank
+        kv_args.system_dp_rank = self.scheduler.dp_rank
         kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size
         kv_args.prefill_pp_size = self.pp_size
         kv_data_ptrs, kv_data_lens, kv_item_lens = (
@@ -460,6 +462,7 @@ class SchedulerDisaggregationPrefillMixin:
 
         # We need to remove the sync in the following function for overlap schedule.
         self.set_next_batch_sampling_info_done(batch)
+        self.maybe_send_health_check_signal()
 
     def process_disagg_prefill_inflight_queue(
         self: Scheduler, rids_to_check: Optional[List[str]] = None
sglang/srt/distributed/device_communicators/pynccl.py
CHANGED
@@ -75,6 +75,7 @@ class PyNcclCommunicator:
         self.available = True
         self.disabled = False
 
+        self.nccl_version = self.nccl.ncclGetRawVersion()
         if self.rank == 0:
             logger.info("sglang is using nccl==%s", self.nccl.ncclGetVersion())
 
@@ -259,6 +260,12 @@ class PyNcclCommunicator:
             cudaStream_t(stream.cuda_stream),
         )
 
+    def register_comm_window_raw(self, ptr: int, size: int):
+        return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr), size, 1)
+
+    def deregister_comm_window(self, window):
+        return self.nccl.ncclCommWindowDeregister(self.comm, window)
+
     @contextmanager
     def change_state(
         self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
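
The two wrappers above expose NCCL 2.27+ window (de)registration to Python; register_comm_window_raw hard-codes winFlags=1. A minimal usage sketch, not taken from the diff (comm is an existing PyNcclCommunicator, and ptr/size are assumed to describe a buffer obtained through ncclMemAlloc, e.g. from the pluggable allocator added below):

    window = comm.register_comm_window_raw(ptr, size)  # returns an ncclWindow_t handle
    # ... run collectives that can exploit the registered (symmetric) buffer ...
    comm.deregister_comm_window(window)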
sglang/srt/distributed/device_communicators/pynccl_allocator.py
ADDED
@@ -0,0 +1,133 @@
+import tempfile
+
+import torch
+from packaging import version
+from torch.cuda.memory import CUDAPluggableAllocator
+
+from sglang.srt.distributed.parallel_state import GroupCoordinator
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+nccl_allocator_source = """
+#include <nccl.h>
+extern "C" {
+
+void* nccl_alloc_plug(size_t size, int device, void* stream) {
+  void* ptr;
+  ncclResult_t err = ncclMemAlloc(&ptr, size);
+  return ptr;
+
+}
+
+void nccl_free_plug(void* ptr, size_t size, int device, void* stream) {
+  ncclResult_t err = ncclMemFree(ptr);
+}
+
+}
+"""
+
+_allocator = None
+_mem_pool = None
+_registered_base_addrs = set()
+_graph_pool_id = None
+
+
+def is_symmetric_memory_enabled():
+    return global_server_args_dict["enable_symm_mem"]
+
+
+def set_graph_pool_id(graph_pool_id):
+    global _graph_pool_id
+    _graph_pool_id = graph_pool_id
+
+
+def get_nccl_mem_pool():
+    global _allocator, _mem_pool
+    if _mem_pool is None:
+        out_dir = tempfile.gettempdir()
+        nccl_allocator_libname = "nccl_allocator"
+        torch.utils.cpp_extension.load_inline(
+            name=nccl_allocator_libname,
+            cpp_sources=nccl_allocator_source,
+            with_cuda=True,
+            extra_ldflags=["-lnccl"],
+            verbose=True,
+            is_python_module=False,
+            build_directory=out_dir,
+        )
+        _allocator = CUDAPluggableAllocator(
+            f"{out_dir}/{nccl_allocator_libname}.so",
+            "nccl_alloc_plug",
+            "nccl_free_plug",
+        ).allocator()
+        _mem_pool = torch.cuda.MemPool(_allocator)
+    return _mem_pool
+
+
+class use_symmetric_memory:
+    def __init__(self, group_coordinator: GroupCoordinator):
+        if not is_symmetric_memory_enabled():
+            self.group_coordinator = None
+            self._mem_pool_ctx = None
+            self.is_graph_capture = None
+            self.device = None
+            self.pre_2_8_0 = None
+        else:
+            self.group_coordinator = group_coordinator
+            self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
+            self.is_graph_capture = torch.cuda.is_current_stream_capturing()
+            self.device = torch.cuda.current_device()
+            self.pre_2_8_0 = version.parse(torch.__version__) < version.parse("2.8.0")
+
+    def __enter__(self):
+        if not is_symmetric_memory_enabled():
+            return self
+        assert (
+            self.group_coordinator.pynccl_comm is not None
+        ), f"Symmetric memory requires pynccl to be enabled in group '{self.group_coordinator.group_name}'"
+        assert (
+            self.group_coordinator.pynccl_comm.nccl_version >= 22703
+        ), "NCCL version 2.27.3 or higher is required for NCCL symmetric memory"
+        if self.is_graph_capture:
+            assert (
+                _graph_pool_id is not None
+            ), "graph_pool_id is not set under graph capture"
+            # Pause graph memory pool to use symmetric memory with cuda graph
+            if self.pre_2_8_0:
+                torch._C._cuda_endAllocateCurrentStreamToPool(
+                    self.device, _graph_pool_id
+                )
+            else:
+                torch._C._cuda_endAllocateToPool(self.device, _graph_pool_id)
+        self._mem_pool_ctx.__enter__()
+        return self
+
+    def tag(self, tensor: torch.Tensor):
+        if not is_symmetric_memory_enabled():
+            return
+        tensor.symmetric_memory = True
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if not is_symmetric_memory_enabled():
+            return
+        global _registered_base_addrs
+        self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)
+        for segment in get_nccl_mem_pool().snapshot():
+            if segment["address"] not in _registered_base_addrs:
+                if segment["stream"] == 0 and self.pre_2_8_0:
+                    # PyTorch version < 2.8.0 has a multi-thread MemPool bug
+                    # See https://github.com/pytorch/pytorch/issues/152861
+                    # Fixed at https://github.com/pytorch/pytorch/commit/f01e628e3b31852983ab30b25bf251f557ba9c0b
+                    # WAR is to skip allocations on the default stream since the forward_pass thread always runs on a custom stream
+                    continue
+                self.group_coordinator.pynccl_comm.register_comm_window_raw(
+                    segment["address"], segment["total_size"]
+                )
+                _registered_base_addrs.add(segment["address"])
+
+        if self.is_graph_capture:
+            if self.pre_2_8_0:
+                torch._C._cuda_beginAllocateToPool(self.device, _graph_pool_id)
+            else:
+                torch._C._cuda_beginAllocateCurrentThreadToPool(
+                    self.device, _graph_pool_id
+                )
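
Rough usage sketch for the new allocator (not from the diff; assumes the server was started with symmetric memory enabled and that tp_group is the tensor-parallel GroupCoordinator):

    import torch
    from sglang.srt.distributed.device_communicators.pynccl_allocator import (
        use_symmetric_memory,
    )

    with use_symmetric_memory(tp_group) as sm:
        out = torch.empty(4096, dtype=torch.bfloat16, device="cuda")  # allocation comes from the NCCL mem pool
        sm.tag(out)  # sets out.symmetric_memory = True
    out = tp_group.all_reduce(out)  # tagged tensors take the pynccl symmetric-memory path (parallel_state.py below)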
sglang/srt/distributed/device_communicators/pynccl_wrapper.py
CHANGED
@@ -67,6 +67,7 @@ def find_nccl_library() -> str:
 
 ncclResult_t = ctypes.c_int
 ncclComm_t = ctypes.c_void_p
+ncclWindow_t = ctypes.c_void_p
 
 
 class ncclUniqueId(ctypes.Structure):
@@ -279,6 +280,23 @@ class NCCLLibrary:
         Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
     ]
 
+    exported_functions_symm_mem = [
+        # ncclResult_t ncclCommWindowRegister(ncclComm_t comm, void* buff, size_t size, ncclWindow_t* win, int winFlags);
+        Function(
+            "ncclCommWindowRegister",
+            ncclResult_t,
+            [
+                ncclComm_t,
+                buffer_type,
+                ctypes.c_size_t,
+                ctypes.POINTER(ncclWindow_t),
+                ctypes.c_int,
+            ],
+        ),
+        # ncclResult_t ncclCommWindowDeregister(ncclComm_t comm, ncclWindow_t win);
+        Function("ncclCommWindowDeregister", ncclResult_t, [ncclComm_t, ncclWindow_t]),
+    ]
+
     # class attribute to store the mapping from the path to the library
     # to avoid loading the same library multiple times
     path_to_library_cache: Dict[str, Any] = {}
@@ -312,7 +330,10 @@
 
         if so_file not in NCCLLibrary.path_to_dict_mapping:
             _funcs: Dict[str, Any] = {}
-            for func in NCCLLibrary.exported_functions:
+            exported_functions = NCCLLibrary.exported_functions
+            if hasattr(self.lib, "ncclCommWindowRegister"):
+                exported_functions.extend(NCCLLibrary.exported_functions_symm_mem)
+            for func in exported_functions:
                 f = getattr(self.lib, func.name)
                 f.restype = func.restype
                 f.argtypes = func.argtypes
@@ -328,10 +349,14 @@
             error_str = self.ncclGetErrorString(result)
             raise RuntimeError(f"NCCL error: {error_str}")
 
-    def ncclGetVersion(self) -> str:
+    def ncclGetRawVersion(self) -> int:
         version = ctypes.c_int()
         self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
-        version_str = str(version.value)
+        # something like 21903
+        return version.value
+
+    def ncclGetVersion(self) -> str:
+        version_str = str(self.ncclGetRawVersion())
         # something like 21903 --> "2.19.3"
         major = version_str[0].lstrip("0")
         minor = version_str[1:3].lstrip("0")
@@ -460,6 +485,20 @@
     def ncclCommDestroy(self, comm: ncclComm_t) -> None:
         self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
 
+    def ncclCommWindowRegister(
+        self, comm: ncclComm_t, buff: buffer_type, size: int, win_flags: int
+    ) -> ncclWindow_t:
+        window = ncclWindow_t()
+        self.NCCL_CHECK(
+            self._funcs["ncclCommWindowRegister"](
+                comm, buff, size, ctypes.byref(window), win_flags
+            )
+        )
+        return window
+
+    def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None:
+        self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window))
+
 
 __all__ = [
     "NCCLLibrary",
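
ncclGetRawVersion() returns NCCL's integer version code, which is what the allocator's nccl_version >= 22703 check above compares against. A small illustration of that encoding (my sketch, not part of the package):

    def decode_nccl_version(raw: int) -> str:
        # e.g. 21903 -> "2.19.3", 22703 -> "2.27.3"
        major, rest = divmod(raw, 10000)
        minor, patch = divmod(rest, 100)
        return f"{major}.{minor}.{patch}"

    assert decode_nccl_version(22703) == "2.27.3"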
sglang/srt/distributed/parallel_state.py
CHANGED
@@ -497,6 +497,17 @@ class GroupCoordinator:
         if self.npu_communicator is not None and not self.npu_communicator.disabled:
             return self.npu_communicator.all_reduce(input_)
 
+        if (
+            self.pynccl_comm is not None
+            and hasattr(input_, "symmetric_memory")
+            and input_.symmetric_memory
+        ):
+            with self.pynccl_comm.change_state(
+                enable=True, stream=torch.cuda.current_stream()
+            ):
+                self.pynccl_comm.all_reduce(input_)
+            return input_
+
         outplace_all_reduce_method = None
         if (
             self.qr_comm is not None
@@ -639,17 +650,19 @@ class GroupCoordinator:
             output_size, dtype=input_.dtype, device=input_.device
         )
 
+        # All-gather.
+        if input_.is_cpu and is_shm_available(
+            input_.dtype, self.world_size, self.local_size
+        ):
+            return torch.ops.sgl_kernel.shm_allgather(input_, dim)
+
         if input_.is_cpu:
-
-
-
-
-
-            )
-            return output_tensor
+            torch.distributed.all_gather_into_tensor(
+                output_tensor, input_, group=self.device_group
+            )
+        else:
+            self.all_gather_into_tensor(output_tensor, input_)
 
-        # All-gather.
-        self.all_gather_into_tensor(output_tensor, input_)
         # Reshape
         output_tensor = output_tensor.reshape((world_size,) + input_size)
         output_tensor = output_tensor.movedim(0, dim)
sglang/srt/entrypoints/context.py
ADDED
@@ -0,0 +1,244 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copied from vLLM
+import json
+import logging
+from abc import ABC, abstractmethod
+from typing import Union
+
+logger = logging.getLogger(__name__)
+
+try:
+    from mcp import ClientSession
+except ImportError:
+    logger.warning("Ignoring mcp import error")
+
+from openai_harmony import Author, Message, Role, StreamState, TextContent
+
+from sglang.srt.entrypoints.harmony_utils import (
+    get_encoding,
+    get_streamable_parser_for_assistant,
+    render_for_completion,
+)
+from sglang.srt.entrypoints.tool import Tool
+
+
+class ConversationContext(ABC):
+
+    @abstractmethod
+    def append_output(self, output) -> None:
+        pass
+
+    @abstractmethod
+    async def call_tool(self) -> list[Message]:
+        pass
+
+    @abstractmethod
+    def need_builtin_tool_call(self) -> bool:
+        pass
+
+    @abstractmethod
+    def render_for_completion(self) -> list[int]:
+        pass
+
+
+class SimpleContext(ConversationContext):
+
+    def __init__(self):
+        self.last_output = None
+
+    def append_output(self, output) -> None:
+        self.last_output = output
+
+    def need_builtin_tool_call(self) -> bool:
+        return False
+
+    async def call_tool(self) -> list[Message]:
+        raise NotImplementedError("Should not be called.")
+
+    def render_for_completion(self) -> list[int]:
+        raise NotImplementedError("Should not be called.")
+
+
+class HarmonyContext(ConversationContext):
+
+    def __init__(
+        self,
+        messages: list,
+        tool_sessions: dict[str, Union["ClientSession", Tool]],
+    ):
+        # TODO: Remove the hack of Union[ClientSession, Tool] by using MCP
+        # when demo.
+        self._messages = messages
+        self.tool_sessions = tool_sessions
+
+        self.parser = get_streamable_parser_for_assistant()
+        self.num_init_messages = len(messages)
+        # TODO
+        self.num_prompt_tokens = 0
+        self.num_cached_tokens = 0
+        self.num_output_tokens = 0
+        self.num_reasoning_tokens = 0
+
+    def append_output(self, output) -> None:
+        if isinstance(output, dict) and "output_ids" in output:
+            output_token_ids = output["output_ids"]
+
+            # TODO: REMOVE here:
+            # Very hacky, find the first occurrence of token 200006 and cut from there
+            try:
+                start_index = output_token_ids.index(200006)
+                output_token_ids = output_token_ids[start_index:]
+            except ValueError:
+                pass
+
+            for token_id in output_token_ids:
+                self.parser.process(token_id)
+            output_msgs = self.parser.messages
+
+            meta_info = output["meta_info"]
+
+            if isinstance(meta_info, dict):
+                if "prompt_token_ids" in meta_info:
+                    self.num_prompt_tokens = meta_info["prompt_tokens"]
+                if "cached_tokens" in meta_info:
+                    self.num_cached_tokens = meta_info["cached_tokens"]
+                if "completion_tokens" in meta_info:
+                    self.num_output_tokens += meta_info["completion_tokens"]
+
+        else:
+            output_msgs = output
+
+        self._messages.extend(output_msgs)
+
+    @property
+    def messages(self) -> list:
+        return self._messages
+
+    def need_builtin_tool_call(self) -> bool:
+        last_msg = self.messages[-1]
+        recipient = last_msg.recipient
+        return recipient is not None and (
+            recipient.startswith("browser.") or recipient.startswith("python")
+        )
+
+    async def call_tool(self) -> list[Message]:
+        if not self.messages:
+            return []
+        last_msg = self.messages[-1]
+        recipient = last_msg.recipient
+        if recipient is not None:
+            if recipient.startswith("browser."):
+                return await self.call_search_tool(
+                    self.tool_sessions["browser"], last_msg
+                )
+            elif recipient.startswith("python"):
+                return await self.call_python_tool(
+                    self.tool_sessions["python"], last_msg
+                )
+        raise ValueError("No tool call found")
+
+    def render_for_completion(self) -> list[int]:
+        return render_for_completion(self.messages)
+
+    async def call_search_tool(
+        self, tool_session: Union["ClientSession", Tool], last_msg: Message
+    ) -> list[Message]:
+        if isinstance(tool_session, Tool):
+            return await tool_session.get_result(self)
+        tool_name = last_msg.recipient.split(".")[1]
+        args = json.loads(last_msg.content[0].text)
+        result = await tool_session.call_tool(tool_name, args)
+        result_str = result.content[0].text
+        content = TextContent(text=result_str)
+        author = Author(role=Role.TOOL, name=last_msg.recipient)
+        return [Message(author=author, content=[content], recipient=Role.ASSISTANT)]
+
+    async def call_python_tool(
+        self, tool_session: Union["ClientSession", Tool], last_msg: Message
+    ) -> list[Message]:
+        if isinstance(tool_session, Tool):
+            return await tool_session.get_result(self)
+        param = {
+            "code": last_msg.content[0].text,
+        }
+        result = await tool_session.call_tool("python", param)
+        result_str = result.content[0].text
+
+        content = TextContent(text=result_str)
+        author = Author(role=Role.TOOL, name="python")
+
+        return [
+            Message(
+                author=author,
+                content=[content],
+                channel=last_msg.channel,
+                recipient=Role.ASSISTANT,
+            )
+        ]
+
+
+class StreamingHarmonyContext(HarmonyContext):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.last_output = None
+
+        self.parser = get_streamable_parser_for_assistant()
+        self.encoding = get_encoding()
+        self.last_tok = None
+
+    @property
+    def messages(self) -> list:
+        return self.parser.messages
+
+    def append_output(self, output) -> None:
+        if isinstance(output, dict) and "output_ids" in output:
+            # RequestOutput from SGLang with outputs
+            output_token_ids = output["output_ids"]
+
+            # TODO: REMOVE here:
+            # Very hacky, find the first occurrence of token 200006 and cut from there
+            # Find the first occurrence of token 200006 and cut from there
+            try:
+                start_index = output_token_ids.index(200006)
+                output_token_ids = output_token_ids[start_index:]
+            except ValueError:
+                pass
+
+            for token_id in output_token_ids:
+                self.parser.process(token_id)
+
+        else:
+            # Handle the case of tool output in direct message format
+            assert len(output) == 1, "Tool output should be a single message"
+            msg = output[0]
+            # Sometimes the recipient is not set for tool messages,
+            # so we set it to "assistant"
+            if msg.author.role == Role.TOOL and msg.recipient is None:
+                msg.recipient = "assistant"
+            toks = self.encoding.render(msg)
+            for tok in toks:
+                self.parser.process(tok)
+            self.last_tok = toks[-1]
+
+    def is_expecting_start(self) -> bool:
+        return self.parser.state == StreamState.EXPECT_START
+
+    def is_assistant_action_turn(self) -> bool:
+        return self.last_tok in self.encoding.stop_tokens_for_assistant_actions()
+
+    def render_for_completion(self) -> list[int]:
+        # now this list of tokens as next turn's starting tokens
+        # `<|start|>assistant``,
+        # we need to process them in parser.
+        rendered_tokens = super().render_for_completion()
+
+        last_n = -1
+        to_process = []
+        while rendered_tokens[last_n] != self.last_tok:
+            to_process.append(rendered_tokens[last_n])
+            last_n -= 1
+        for tok in reversed(to_process):
+            self.parser.process(tok)
+
+        return rendered_tokens
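
These context classes back the new Responses serving path (serving_responses.py, harmony_utils.py, tool_server.py in the file list). A rough sketch of the generate-then-tool-call loop they are designed for (names assumed for illustration; the real loop lives in serving_responses.py):

    async def run_agentic_turns(ctx: HarmonyContext, generate) -> list:
        # `generate` is an assumed async callable that takes rendered token ids
        # and returns a dict with "output_ids" and "meta_info".
        while True:
            output = await generate(input_ids=ctx.render_for_completion())
            ctx.append_output(output)
            if not ctx.need_builtin_tool_call():
                break
            tool_msgs = await ctx.call_tool()  # dispatches to the browser.* or python tool session
            ctx.append_output(tool_msgs)
        return ctx.messages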
sglang/srt/entrypoints/engine.py
CHANGED
@@ -492,12 +492,13 @@ class Engine(EngineBase):
             self.tokenizer_manager.get_weights_by_name(obj, None)
         )
 
-    def load_lora_adapter(self, lora_name: str, lora_path: str):
+    def load_lora_adapter(self, lora_name: str, lora_path: str, pinned: bool = False):
         """Load a new LoRA adapter without re-launching the engine."""
 
         obj = LoadLoRAAdapterReqInput(
             lora_name=lora_name,
             lora_path=lora_path,
+            pinned=pinned,
         )
 
         loop = asyncio.get_event_loop()
@@ -623,8 +624,9 @@
 def _set_envs_and_config(server_args: ServerArgs):
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-    os.environ["NCCL_CUMEM_ENABLE"] =
-
+    os.environ["NCCL_CUMEM_ENABLE"] = str(int(server_args.enable_symm_mem))
+    if not server_args.enable_symm_mem:
+        os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
     os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
     os.environ["CUDA_MODULE_LOADING"] = "AUTO"
@@ -640,7 +642,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.2.
+            "0.2.10",
             "Please uninstall the old version and "
            "reinstall the latest version by following the instructions "
            "at https://docs.flashinfer.ai/installation.html.",
@@ -648,7 +650,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda:
         assert_pkg_version(
             "sgl-kernel",
-            "0.2
+            "0.3.2",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
 
@@ -731,6 +733,7 @@
                 pp_rank,
                 None,
                 writer,
+                None,
             ),
         )
 
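
The new pinned flag is threaded through LoadLoRAAdapterReqInput down to the LoRA manager and memory pool (both touched in this release, per the file list). A usage sketch against the offline Engine API (adapter name and path are placeholders; enable_lora=True is an assumed launch option for dynamic loading):

    import sglang as sgl

    engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct", enable_lora=True)
    engine.load_lora_adapter(lora_name="demo-adapter", lora_path="/path/to/adapter", pinned=True)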