sglang 0.4.10.post1__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (143)
  1. sglang/bench_one_batch.py +113 -17
  2. sglang/compile_deep_gemm.py +8 -1
  3. sglang/global_config.py +5 -1
  4. sglang/srt/configs/model_config.py +35 -0
  5. sglang/srt/conversation.py +9 -117
  6. sglang/srt/disaggregation/base/conn.py +5 -2
  7. sglang/srt/disaggregation/decode.py +6 -1
  8. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -0
  9. sglang/srt/disaggregation/mooncake/conn.py +243 -135
  10. sglang/srt/disaggregation/prefill.py +3 -0
  11. sglang/srt/distributed/device_communicators/pynccl.py +7 -0
  12. sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
  13. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
  14. sglang/srt/distributed/parallel_state.py +22 -9
  15. sglang/srt/entrypoints/context.py +244 -0
  16. sglang/srt/entrypoints/engine.py +8 -5
  17. sglang/srt/entrypoints/harmony_utils.py +370 -0
  18. sglang/srt/entrypoints/http_server.py +106 -15
  19. sglang/srt/entrypoints/openai/protocol.py +227 -1
  20. sglang/srt/entrypoints/openai/serving_chat.py +278 -42
  21. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  22. sglang/srt/entrypoints/openai/tool_server.py +174 -0
  23. sglang/srt/entrypoints/tool.py +87 -0
  24. sglang/srt/eplb/expert_distribution.py +4 -2
  25. sglang/srt/eplb/expert_location.py +5 -1
  26. sglang/srt/function_call/harmony_tool_parser.py +130 -0
  27. sglang/srt/hf_transformers_utils.py +55 -13
  28. sglang/srt/jinja_template_utils.py +8 -1
  29. sglang/srt/layers/attention/aiter_backend.py +5 -8
  30. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  31. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  32. sglang/srt/layers/attention/flashattention_backend.py +7 -11
  33. sglang/srt/layers/attention/triton_backend.py +85 -14
  34. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  35. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  36. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  37. sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
  38. sglang/srt/layers/attention/vision.py +40 -15
  39. sglang/srt/layers/communicator.py +35 -8
  40. sglang/srt/layers/dp_attention.py +12 -0
  41. sglang/srt/layers/linear.py +9 -8
  42. sglang/srt/layers/logits_processor.py +9 -1
  43. sglang/srt/layers/moe/cutlass_moe.py +20 -6
  44. sglang/srt/layers/moe/ep_moe/layer.py +87 -107
  45. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  47. sglang/srt/layers/moe/fused_moe_triton/layer.py +442 -58
  48. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +169 -15
  49. sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
  50. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
  51. sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
  52. sglang/srt/layers/moe/topk.py +12 -3
  53. sglang/srt/layers/moe/utils.py +59 -0
  54. sglang/srt/layers/quantization/__init__.py +22 -0
  55. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
  56. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  57. sglang/srt/layers/quantization/fp4.py +557 -0
  58. sglang/srt/layers/quantization/fp8.py +8 -7
  59. sglang/srt/layers/quantization/fp8_kernel.py +0 -4
  60. sglang/srt/layers/quantization/fp8_utils.py +29 -0
  61. sglang/srt/layers/quantization/modelopt_quant.py +259 -64
  62. sglang/srt/layers/quantization/mxfp4.py +651 -0
  63. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  64. sglang/srt/layers/quantization/quark/__init__.py +0 -0
  65. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  66. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  67. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  68. sglang/srt/layers/quantization/quark/utils.py +107 -0
  69. sglang/srt/layers/quantization/unquant.py +60 -6
  70. sglang/srt/layers/quantization/w4afp8.py +1 -1
  71. sglang/srt/layers/rotary_embedding.py +225 -1
  72. sglang/srt/layers/utils.py +9 -0
  73. sglang/srt/layers/vocab_parallel_embedding.py +15 -4
  74. sglang/srt/lora/lora_manager.py +70 -14
  75. sglang/srt/lora/lora_registry.py +10 -2
  76. sglang/srt/lora/mem_pool.py +43 -5
  77. sglang/srt/managers/cache_controller.py +61 -32
  78. sglang/srt/managers/data_parallel_controller.py +52 -2
  79. sglang/srt/managers/detokenizer_manager.py +1 -1
  80. sglang/srt/managers/io_struct.py +21 -4
  81. sglang/srt/managers/mm_utils.py +5 -11
  82. sglang/srt/managers/schedule_batch.py +30 -8
  83. sglang/srt/managers/schedule_policy.py +3 -1
  84. sglang/srt/managers/scheduler.py +170 -18
  85. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  86. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  87. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  88. sglang/srt/managers/template_manager.py +59 -22
  89. sglang/srt/managers/tokenizer_manager.py +137 -67
  90. sglang/srt/managers/tp_worker.py +3 -0
  91. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  92. sglang/srt/managers/utils.py +45 -1
  93. sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
  94. sglang/srt/mem_cache/hicache_storage.py +13 -21
  95. sglang/srt/mem_cache/hiradix_cache.py +53 -5
  96. sglang/srt/mem_cache/memory_pool_host.py +1 -1
  97. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  98. sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
  99. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  100. sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
  101. sglang/srt/model_executor/cuda_graph_runner.py +24 -9
  102. sglang/srt/model_executor/forward_batch_info.py +48 -17
  103. sglang/srt/model_executor/model_runner.py +24 -2
  104. sglang/srt/model_loader/weight_utils.py +10 -0
  105. sglang/srt/models/bailing_moe.py +425 -0
  106. sglang/srt/models/deepseek_v2.py +95 -50
  107. sglang/srt/models/ernie4.py +426 -0
  108. sglang/srt/models/ernie4_eagle.py +203 -0
  109. sglang/srt/models/gemma3n_mm.py +39 -0
  110. sglang/srt/models/glm4_moe.py +102 -27
  111. sglang/srt/models/gpt_oss.py +1134 -0
  112. sglang/srt/models/grok.py +3 -3
  113. sglang/srt/models/llama4.py +13 -2
  114. sglang/srt/models/mixtral.py +3 -3
  115. sglang/srt/models/mllama4.py +428 -19
  116. sglang/srt/models/qwen2.py +6 -0
  117. sglang/srt/models/qwen2_moe.py +7 -4
  118. sglang/srt/models/qwen3_moe.py +39 -14
  119. sglang/srt/models/step3_vl.py +10 -1
  120. sglang/srt/models/transformers.py +2 -5
  121. sglang/srt/multimodal/processors/base_processor.py +4 -3
  122. sglang/srt/multimodal/processors/gemma3n.py +0 -7
  123. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  124. sglang/srt/operations_strategy.py +1 -1
  125. sglang/srt/reasoning_parser.py +18 -39
  126. sglang/srt/server_args.py +218 -23
  127. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
  128. sglang/srt/two_batch_overlap.py +163 -9
  129. sglang/srt/utils.py +41 -26
  130. sglang/srt/weight_sync/utils.py +1 -1
  131. sglang/test/runners.py +4 -4
  132. sglang/test/test_utils.py +4 -4
  133. sglang/version.py +1 -1
  134. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +18 -15
  135. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +143 -116
  136. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
  137. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
  138. /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
  139. /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
  140. /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
  141. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
  142. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
  143. {sglang-0.4.10.post1.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
@@ -103,6 +103,8 @@ class PrefillBootstrapQueue:
         kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
         kv_args = kv_args_class()
         kv_args.engine_rank = self.tp_rank
+        kv_args.pp_rank = self.pp_rank
+        kv_args.system_dp_rank = self.scheduler.dp_rank
         kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size
         kv_args.prefill_pp_size = self.pp_size
         kv_data_ptrs, kv_data_lens, kv_item_lens = (
@@ -460,6 +462,7 @@ class SchedulerDisaggregationPrefillMixin:

         # We need to remove the sync in the following function for overlap schedule.
         self.set_next_batch_sampling_info_done(batch)
+        self.maybe_send_health_check_signal()

     def process_disagg_prefill_inflight_queue(
         self: Scheduler, rids_to_check: Optional[List[str]] = None
@@ -75,6 +75,7 @@ class PyNcclCommunicator:
         self.available = True
         self.disabled = False

+        self.nccl_version = self.nccl.ncclGetRawVersion()
         if self.rank == 0:
             logger.info("sglang is using nccl==%s", self.nccl.ncclGetVersion())

@@ -259,6 +260,12 @@ class PyNcclCommunicator:
             cudaStream_t(stream.cuda_stream),
         )

+    def register_comm_window_raw(self, ptr: int, size: int):
+        return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr), size, 1)
+
+    def deregister_comm_window(self, window):
+        return self.nccl.ncclCommWindowDeregister(self.comm, window)
+
     @contextmanager
     def change_state(
         self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
@@ -0,0 +1,133 @@
+import tempfile
+
+import torch
+from packaging import version
+from torch.cuda.memory import CUDAPluggableAllocator
+
+from sglang.srt.distributed.parallel_state import GroupCoordinator
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+nccl_allocator_source = """
+#include <nccl.h>
+extern "C" {
+
+void* nccl_alloc_plug(size_t size, int device, void* stream) {
+  void* ptr;
+  ncclResult_t err = ncclMemAlloc(&ptr, size);
+  return ptr;
+
+}
+
+void nccl_free_plug(void* ptr, size_t size, int device, void* stream) {
+  ncclResult_t err = ncclMemFree(ptr);
+}
+
+}
+"""
+
+_allocator = None
+_mem_pool = None
+_registered_base_addrs = set()
+_graph_pool_id = None
+
+
+def is_symmetric_memory_enabled():
+    return global_server_args_dict["enable_symm_mem"]
+
+
+def set_graph_pool_id(graph_pool_id):
+    global _graph_pool_id
+    _graph_pool_id = graph_pool_id
+
+
+def get_nccl_mem_pool():
+    global _allocator, _mem_pool
+    if _mem_pool is None:
+        out_dir = tempfile.gettempdir()
+        nccl_allocator_libname = "nccl_allocator"
+        torch.utils.cpp_extension.load_inline(
+            name=nccl_allocator_libname,
+            cpp_sources=nccl_allocator_source,
+            with_cuda=True,
+            extra_ldflags=["-lnccl"],
+            verbose=True,
+            is_python_module=False,
+            build_directory=out_dir,
+        )
+        _allocator = CUDAPluggableAllocator(
+            f"{out_dir}/{nccl_allocator_libname}.so",
+            "nccl_alloc_plug",
+            "nccl_free_plug",
+        ).allocator()
+        _mem_pool = torch.cuda.MemPool(_allocator)
+    return _mem_pool
+
+
+class use_symmetric_memory:
+    def __init__(self, group_coordinator: GroupCoordinator):
+        if not is_symmetric_memory_enabled():
+            self.group_coordinator = None
+            self._mem_pool_ctx = None
+            self.is_graph_capture = None
+            self.device = None
+            self.pre_2_8_0 = None
+        else:
+            self.group_coordinator = group_coordinator
+            self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
+            self.is_graph_capture = torch.cuda.is_current_stream_capturing()
+            self.device = torch.cuda.current_device()
+            self.pre_2_8_0 = version.parse(torch.__version__) < version.parse("2.8.0")
+
+    def __enter__(self):
+        if not is_symmetric_memory_enabled():
+            return self
+        assert (
+            self.group_coordinator.pynccl_comm is not None
+        ), f"Symmetric memory requires pynccl to be enabled in group '{self.group_coordinator.group_name}'"
+        assert (
+            self.group_coordinator.pynccl_comm.nccl_version >= 22703
+        ), "NCCL version 2.27.3 or higher is required for NCCL symmetric memory"
+        if self.is_graph_capture:
+            assert (
+                _graph_pool_id is not None
+            ), "graph_pool_id is not set under graph capture"
+            # Pause graph memory pool to use symmetric memory with cuda graph
+            if self.pre_2_8_0:
+                torch._C._cuda_endAllocateCurrentStreamToPool(
+                    self.device, _graph_pool_id
+                )
+            else:
+                torch._C._cuda_endAllocateToPool(self.device, _graph_pool_id)
+        self._mem_pool_ctx.__enter__()
+        return self
+
+    def tag(self, tensor: torch.Tensor):
+        if not is_symmetric_memory_enabled():
+            return
+        tensor.symmetric_memory = True
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if not is_symmetric_memory_enabled():
+            return
+        global _registered_base_addrs
+        self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)
+        for segment in get_nccl_mem_pool().snapshot():
+            if segment["address"] not in _registered_base_addrs:
+                if segment["stream"] == 0 and self.pre_2_8_0:
+                    # PyTorch version < 2.8.0 has a multi-thread MemPool bug
+                    # See https://github.com/pytorch/pytorch/issues/152861
+                    # Fixed at https://github.com/pytorch/pytorch/commit/f01e628e3b31852983ab30b25bf251f557ba9c0b
+                    # WAR is to skip allocations on the default stream since the forward_pass thread always runs on a custom stream
+                    continue
+                self.group_coordinator.pynccl_comm.register_comm_window_raw(
+                    segment["address"], segment["total_size"]
+                )
+                _registered_base_addrs.add(segment["address"])
+
+        if self.is_graph_capture:
+            if self.pre_2_8_0:
+                torch._C._cuda_beginAllocateToPool(self.device, _graph_pool_id)
+            else:
+                torch._C._cuda_beginAllocateCurrentThreadToPool(
+                    self.device, _graph_pool_id
+                )
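Note (illustration, not part of the diff): a minimal sketch of how this new allocator module is meant to be used, assuming a GroupCoordinator whose pynccl communicator is initialized and a server started with the new enable_symm_mem option so that is_symmetric_memory_enabled() returns True. The function name is hypothetical.

import torch

from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)


def all_reduce_with_symm_mem(group, hidden_states: torch.Tensor) -> torch.Tensor:
    # Allocations made inside the context come from the NCCL MemPool and are
    # window-registered with the communicator when the context exits.
    with use_symmetric_memory(group) as sm:
        buf = torch.empty_like(hidden_states)
        buf.copy_(hidden_states)
        sm.tag(buf)  # sets buf.symmetric_memory = True
    # GroupCoordinator.all_reduce (see the parallel_state.py hunk below) takes
    # the pynccl symmetric-memory path for tagged tensors.
    return group.all_reduce(buf)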
@@ -67,6 +67,7 @@ def find_nccl_library() -> str:

 ncclResult_t = ctypes.c_int
 ncclComm_t = ctypes.c_void_p
+ncclWindow_t = ctypes.c_void_p


 class ncclUniqueId(ctypes.Structure):
@@ -279,6 +280,23 @@ class NCCLLibrary:
         Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
     ]

+    exported_functions_symm_mem = [
+        # ncclResult_t ncclCommWindowRegister(ncclComm_t comm, void* buff, size_t size, ncclWindow_t* win, int winFlags);
+        Function(
+            "ncclCommWindowRegister",
+            ncclResult_t,
+            [
+                ncclComm_t,
+                buffer_type,
+                ctypes.c_size_t,
+                ctypes.POINTER(ncclWindow_t),
+                ctypes.c_int,
+            ],
+        ),
+        # ncclResult_t ncclCommWindowDeregister(ncclComm_t comm, ncclWindow_t win);
+        Function("ncclCommWindowDeregister", ncclResult_t, [ncclComm_t, ncclWindow_t]),
+    ]
+
     # class attribute to store the mapping from the path to the library
     # to avoid loading the same library multiple times
     path_to_library_cache: Dict[str, Any] = {}
@@ -312,7 +330,10 @@ class NCCLLibrary:

         if so_file not in NCCLLibrary.path_to_dict_mapping:
             _funcs: Dict[str, Any] = {}
-            for func in NCCLLibrary.exported_functions:
+            exported_functions = NCCLLibrary.exported_functions
+            if hasattr(self.lib, "ncclCommWindowRegister"):
+                exported_functions.extend(NCCLLibrary.exported_functions_symm_mem)
+            for func in exported_functions:
                 f = getattr(self.lib, func.name)
                 f.restype = func.restype
                 f.argtypes = func.argtypes
@@ -328,10 +349,14 @@ class NCCLLibrary:
             error_str = self.ncclGetErrorString(result)
             raise RuntimeError(f"NCCL error: {error_str}")

-    def ncclGetVersion(self) -> str:
+    def ncclGetRawVersion(self) -> int:
         version = ctypes.c_int()
         self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
-        version_str = str(version.value)
+        # something like 21903
+        return version.value
+
+    def ncclGetVersion(self) -> str:
+        version_str = str(self.ncclGetRawVersion())
         # something like 21903 --> "2.19.3"
         major = version_str[0].lstrip("0")
         minor = version_str[1:3].lstrip("0")
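Note (illustration, not part of the diff): for recent NCCL releases the raw value is encoded as major*10000 + minor*100 + patch, which is what the new ncclGetRawVersion() returns and what the symmetric-memory code compares against (2.27.3 -> 22703), while ncclGetVersion() still returns the dotted string. The helper below is only a sketch of that encoding.

def nccl_version_tuple(raw: int) -> tuple[int, int, int]:
    # e.g. 21903 -> (2, 19, 3), 22703 -> (2, 27, 3)
    return raw // 10000, (raw // 100) % 100, raw % 100


assert nccl_version_tuple(22703) == (2, 27, 3)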
@@ -460,6 +485,20 @@ class NCCLLibrary:
     def ncclCommDestroy(self, comm: ncclComm_t) -> None:
         self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))

+    def ncclCommWindowRegister(
+        self, comm: ncclComm_t, buff: buffer_type, size: int, win_flags: int
+    ) -> ncclWindow_t:
+        window = ncclWindow_t()
+        self.NCCL_CHECK(
+            self._funcs["ncclCommWindowRegister"](
+                comm, buff, size, ctypes.byref(window), win_flags
+            )
+        )
+        return window
+
+    def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None:
+        self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window))
+

 __all__ = [
     "NCCLLibrary",
@@ -497,6 +497,17 @@ class GroupCoordinator:
         if self.npu_communicator is not None and not self.npu_communicator.disabled:
             return self.npu_communicator.all_reduce(input_)

+        if (
+            self.pynccl_comm is not None
+            and hasattr(input_, "symmetric_memory")
+            and input_.symmetric_memory
+        ):
+            with self.pynccl_comm.change_state(
+                enable=True, stream=torch.cuda.current_stream()
+            ):
+                self.pynccl_comm.all_reduce(input_)
+            return input_
+
         outplace_all_reduce_method = None
         if (
             self.qr_comm is not None
@@ -639,17 +650,19 @@ class GroupCoordinator:
             output_size, dtype=input_.dtype, device=input_.device
         )

+        # All-gather.
+        if input_.is_cpu and is_shm_available(
+            input_.dtype, self.world_size, self.local_size
+        ):
+            return torch.ops.sgl_kernel.shm_allgather(input_, dim)
+
         if input_.is_cpu:
-            if is_shm_available(input_.dtype, self.world_size, self.local_size):
-                return torch.ops.sgl_kernel.shm_allgather(input_, dim)
-            else:
-                torch.distributed.all_gather_into_tensor(
-                    output_tensor, input_, group=self.device_group
-                )
-            return output_tensor
+            torch.distributed.all_gather_into_tensor(
+                output_tensor, input_, group=self.device_group
+            )
+        else:
+            self.all_gather_into_tensor(output_tensor, input_)

-        # All-gather.
-        self.all_gather_into_tensor(output_tensor, input_)
         # Reshape
         output_tensor = output_tensor.reshape((world_size,) + input_size)
         output_tensor = output_tensor.movedim(0, dim)
@@ -0,0 +1,244 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copied from vLLM
+import json
+import logging
+from abc import ABC, abstractmethod
+from typing import Union
+
+logger = logging.getLogger(__name__)
+
+try:
+    from mcp import ClientSession
+except ImportError:
+    logger.warning("Ignoring mcp import error")
+
+from openai_harmony import Author, Message, Role, StreamState, TextContent
+
+from sglang.srt.entrypoints.harmony_utils import (
+    get_encoding,
+    get_streamable_parser_for_assistant,
+    render_for_completion,
+)
+from sglang.srt.entrypoints.tool import Tool
+
+
+class ConversationContext(ABC):
+
+    @abstractmethod
+    def append_output(self, output) -> None:
+        pass
+
+    @abstractmethod
+    async def call_tool(self) -> list[Message]:
+        pass
+
+    @abstractmethod
+    def need_builtin_tool_call(self) -> bool:
+        pass
+
+    @abstractmethod
+    def render_for_completion(self) -> list[int]:
+        pass
+
+
+class SimpleContext(ConversationContext):
+
+    def __init__(self):
+        self.last_output = None
+
+    def append_output(self, output) -> None:
+        self.last_output = output
+
+    def need_builtin_tool_call(self) -> bool:
+        return False
+
+    async def call_tool(self) -> list[Message]:
+        raise NotImplementedError("Should not be called.")
+
+    def render_for_completion(self) -> list[int]:
+        raise NotImplementedError("Should not be called.")
+
+
+class HarmonyContext(ConversationContext):
+
+    def __init__(
+        self,
+        messages: list,
+        tool_sessions: dict[str, Union["ClientSession", Tool]],
+    ):
+        # TODO: Remove the hack of Union[ClientSession, Tool] by using MCP
+        # when demo.
+        self._messages = messages
+        self.tool_sessions = tool_sessions
+
+        self.parser = get_streamable_parser_for_assistant()
+        self.num_init_messages = len(messages)
+        # TODO
+        self.num_prompt_tokens = 0
+        self.num_cached_tokens = 0
+        self.num_output_tokens = 0
+        self.num_reasoning_tokens = 0
+
+    def append_output(self, output) -> None:
+        if isinstance(output, dict) and "output_ids" in output:
+            output_token_ids = output["output_ids"]
+
+            # TODO: REMOVE here:
+            # Very hacky, find the first occurrence of token 200006 and cut from there
+            try:
+                start_index = output_token_ids.index(200006)
+                output_token_ids = output_token_ids[start_index:]
+            except ValueError:
+                pass
+
+            for token_id in output_token_ids:
+                self.parser.process(token_id)
+            output_msgs = self.parser.messages
+
+            meta_info = output["meta_info"]
+
+            if isinstance(meta_info, dict):
+                if "prompt_token_ids" in meta_info:
+                    self.num_prompt_tokens = meta_info["prompt_tokens"]
+                if "cached_tokens" in meta_info:
+                    self.num_cached_tokens = meta_info["cached_tokens"]
+                if "completion_tokens" in meta_info:
+                    self.num_output_tokens += meta_info["completion_tokens"]
+
+        else:
+            output_msgs = output
+
+        self._messages.extend(output_msgs)
+
+    @property
+    def messages(self) -> list:
+        return self._messages
+
+    def need_builtin_tool_call(self) -> bool:
+        last_msg = self.messages[-1]
+        recipient = last_msg.recipient
+        return recipient is not None and (
+            recipient.startswith("browser.") or recipient.startswith("python")
+        )
+
+    async def call_tool(self) -> list[Message]:
+        if not self.messages:
+            return []
+        last_msg = self.messages[-1]
+        recipient = last_msg.recipient
+        if recipient is not None:
+            if recipient.startswith("browser."):
+                return await self.call_search_tool(
+                    self.tool_sessions["browser"], last_msg
+                )
+            elif recipient.startswith("python"):
+                return await self.call_python_tool(
+                    self.tool_sessions["python"], last_msg
+                )
+        raise ValueError("No tool call found")
+
+    def render_for_completion(self) -> list[int]:
+        return render_for_completion(self.messages)
+
+    async def call_search_tool(
+        self, tool_session: Union["ClientSession", Tool], last_msg: Message
+    ) -> list[Message]:
+        if isinstance(tool_session, Tool):
+            return await tool_session.get_result(self)
+        tool_name = last_msg.recipient.split(".")[1]
+        args = json.loads(last_msg.content[0].text)
+        result = await tool_session.call_tool(tool_name, args)
+        result_str = result.content[0].text
+        content = TextContent(text=result_str)
+        author = Author(role=Role.TOOL, name=last_msg.recipient)
+        return [Message(author=author, content=[content], recipient=Role.ASSISTANT)]
+
+    async def call_python_tool(
+        self, tool_session: Union["ClientSession", Tool], last_msg: Message
+    ) -> list[Message]:
+        if isinstance(tool_session, Tool):
+            return await tool_session.get_result(self)
+        param = {
+            "code": last_msg.content[0].text,
+        }
+        result = await tool_session.call_tool("python", param)
+        result_str = result.content[0].text
+
+        content = TextContent(text=result_str)
+        author = Author(role=Role.TOOL, name="python")
+
+        return [
+            Message(
+                author=author,
+                content=[content],
+                channel=last_msg.channel,
+                recipient=Role.ASSISTANT,
+            )
+        ]
+
+
+class StreamingHarmonyContext(HarmonyContext):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.last_output = None
+
+        self.parser = get_streamable_parser_for_assistant()
+        self.encoding = get_encoding()
+        self.last_tok = None
+
+    @property
+    def messages(self) -> list:
+        return self.parser.messages
+
+    def append_output(self, output) -> None:
+        if isinstance(output, dict) and "output_ids" in output:
+            # RequestOutput from SGLang with outputs
+            output_token_ids = output["output_ids"]
+
+            # TODO: REMOVE here:
+            # Very hacky, find the first occurrence of token 200006 and cut from there
+            # Find the first occurrence of token 200006 and cut from there
+            try:
+                start_index = output_token_ids.index(200006)
+                output_token_ids = output_token_ids[start_index:]
+            except ValueError:
+                pass
+
+            for token_id in output_token_ids:
+                self.parser.process(token_id)
+
+        else:
+            # Handle the case of tool output in direct message format
+            assert len(output) == 1, "Tool output should be a single message"
+            msg = output[0]
+            # Sometimes the recipient is not set for tool messages,
+            # so we set it to "assistant"
+            if msg.author.role == Role.TOOL and msg.recipient is None:
+                msg.recipient = "assistant"
+            toks = self.encoding.render(msg)
+            for tok in toks:
+                self.parser.process(tok)
+            self.last_tok = toks[-1]
+
+    def is_expecting_start(self) -> bool:
+        return self.parser.state == StreamState.EXPECT_START
+
+    def is_assistant_action_turn(self) -> bool:
+        return self.last_tok in self.encoding.stop_tokens_for_assistant_actions()
+
+    def render_for_completion(self) -> list[int]:
+        # now this list of tokens as next turn's starting tokens
+        # `<|start|>assistant``,
+        # we need to process them in parser.
+        rendered_tokens = super().render_for_completion()
+
+        last_n = -1
+        to_process = []
+        while rendered_tokens[last_n] != self.last_tok:
+            to_process.append(rendered_tokens[last_n])
+            last_n -= 1
+        for tok in reversed(to_process):
+            self.parser.process(tok)
+
+        return rendered_tokens
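Note (illustration, not part of the diff): a plausible driver loop for these new context classes, based only on the methods shown above. generate_once is hypothetical; it stands in for an SGLang generation call and is assumed to return the {"output_ids": ..., "meta_info": ...} dict that append_output expects.

async def run_until_final(ctx, generate_once):
    while True:
        # Render the conversation so far into prompt token ids.
        prompt_ids = ctx.render_for_completion()
        output = await generate_once(prompt_ids)
        ctx.append_output(output)
        if not ctx.need_builtin_tool_call():
            return ctx.messages
        # The last message is addressed to browser.* or python; run the tool
        # and feed its messages back into the context.
        tool_msgs = await ctx.call_tool()
        ctx.append_output(tool_msgs)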
@@ -492,12 +492,13 @@ class Engine(EngineBase):
             self.tokenizer_manager.get_weights_by_name(obj, None)
         )

-    def load_lora_adapter(self, lora_name: str, lora_path: str):
+    def load_lora_adapter(self, lora_name: str, lora_path: str, pinned: bool = False):
         """Load a new LoRA adapter without re-launching the engine."""

         obj = LoadLoRAAdapterReqInput(
             lora_name=lora_name,
             lora_path=lora_path,
+            pinned=pinned,
         )

         loop = asyncio.get_event_loop()
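Note (illustration, not part of the diff): the new pinned flag on Engine.load_lora_adapter. The adapter name and path below are hypothetical, the engine is assumed to be already constructed with LoRA enabled, and pinned=True presumably keeps the adapter resident rather than subject to eviction (the exact semantics live in lora_manager.py and mem_pool.py, which also change in this release).

def pin_adapter(engine) -> None:
    # `engine` is an existing sglang Engine instance with LoRA enabled.
    engine.load_lora_adapter(
        lora_name="my-adapter",          # hypothetical adapter name
        lora_path="/models/my-adapter",  # hypothetical local path
        pinned=True,                     # new flag in this release
    )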
@@ -623,8 +624,9 @@ class Engine(EngineBase):
 def _set_envs_and_config(server_args: ServerArgs):
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-    os.environ["NCCL_CUMEM_ENABLE"] = "0"
-    os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
+    os.environ["NCCL_CUMEM_ENABLE"] = str(int(server_args.enable_symm_mem))
+    if not server_args.enable_symm_mem:
+        os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
     os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
     os.environ["CUDA_MODULE_LOADING"] = "AUTO"
@@ -640,7 +642,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.2.9rc2",
+            "0.2.10",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
@@ -648,7 +650,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda:
         assert_pkg_version(
             "sgl-kernel",
-            "0.2.8",
+            "0.3.2",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )

@@ -731,6 +733,7 @@ def _launch_subprocesses(
                 pp_rank,
                 None,
                 writer,
+                None,
             ),
         )