sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_serving.py +3 -6
  4. sglang/compile_deep_gemm.py +136 -0
  5. sglang/lang/backend/anthropic.py +0 -4
  6. sglang/lang/backend/base_backend.py +1 -1
  7. sglang/lang/backend/openai.py +6 -2
  8. sglang/lang/backend/runtime_endpoint.py +5 -1
  9. sglang/lang/backend/vertexai.py +0 -1
  10. sglang/lang/compiler.py +1 -7
  11. sglang/lang/tracer.py +3 -7
  12. sglang/srt/_custom_ops.py +0 -2
  13. sglang/srt/configs/model_config.py +4 -1
  14. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  15. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  16. sglang/srt/constrained/xgrammar_backend.py +27 -4
  17. sglang/srt/custom_op.py +0 -62
  18. sglang/srt/disaggregation/decode.py +105 -6
  19. sglang/srt/disaggregation/mini_lb.py +74 -9
  20. sglang/srt/disaggregation/mooncake/conn.py +33 -63
  21. sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
  22. sglang/srt/disaggregation/nixl/__init__.py +1 -0
  23. sglang/srt/disaggregation/nixl/conn.py +622 -0
  24. sglang/srt/disaggregation/prefill.py +137 -17
  25. sglang/srt/disaggregation/utils.py +32 -0
  26. sglang/srt/entrypoints/engine.py +4 -0
  27. sglang/srt/entrypoints/http_server.py +3 -7
  28. sglang/srt/entrypoints/verl_engine.py +7 -5
  29. sglang/srt/function_call_parser.py +60 -0
  30. sglang/srt/layers/activation.py +6 -8
  31. sglang/srt/layers/attention/flashattention_backend.py +883 -209
  32. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  33. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  34. sglang/srt/layers/attention/triton_backend.py +6 -0
  35. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
  36. sglang/srt/layers/attention/triton_ops/extend_attention.py +18 -7
  37. sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
  38. sglang/srt/layers/dp_attention.py +1 -1
  39. sglang/srt/layers/layernorm.py +20 -5
  40. sglang/srt/layers/linear.py +17 -3
  41. sglang/srt/layers/moe/ep_moe/layer.py +17 -29
  42. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  43. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
  44. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  45. sglang/srt/layers/moe/topk.py +27 -30
  46. sglang/srt/layers/parameter.py +0 -2
  47. sglang/srt/layers/quantization/__init__.py +1 -0
  48. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  49. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +9 -2
  50. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  52. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
  53. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  54. sglang/srt/layers/quantization/deep_gemm.py +378 -0
  55. sglang/srt/layers/quantization/fp8.py +115 -132
  56. sglang/srt/layers/quantization/fp8_kernel.py +213 -88
  57. sglang/srt/layers/quantization/fp8_utils.py +189 -264
  58. sglang/srt/layers/quantization/gptq.py +13 -7
  59. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  60. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  61. sglang/srt/layers/quantization/utils.py +5 -11
  62. sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
  63. sglang/srt/layers/quantization/w8a8_int8.py +7 -7
  64. sglang/srt/layers/radix_attention.py +15 -0
  65. sglang/srt/layers/rotary_embedding.py +9 -8
  66. sglang/srt/layers/sampler.py +7 -12
  67. sglang/srt/lora/backend/base_backend.py +18 -2
  68. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  69. sglang/srt/lora/backend/triton_backend.py +1 -1
  70. sglang/srt/lora/layers.py +1 -1
  71. sglang/srt/lora/lora.py +1 -1
  72. sglang/srt/lora/lora_manager.py +1 -1
  73. sglang/srt/managers/data_parallel_controller.py +7 -1
  74. sglang/srt/managers/detokenizer_manager.py +0 -1
  75. sglang/srt/managers/io_struct.py +15 -3
  76. sglang/srt/managers/mm_utils.py +4 -3
  77. sglang/srt/managers/multimodal_processor.py +0 -2
  78. sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
  79. sglang/srt/managers/schedule_batch.py +15 -4
  80. sglang/srt/managers/scheduler.py +28 -77
  81. sglang/srt/managers/tokenizer_manager.py +116 -29
  82. sglang/srt/managers/tp_worker.py +1 -0
  83. sglang/srt/mem_cache/hiradix_cache.py +41 -29
  84. sglang/srt/mem_cache/memory_pool.py +38 -15
  85. sglang/srt/model_executor/cuda_graph_runner.py +15 -10
  86. sglang/srt/model_executor/model_runner.py +39 -31
  87. sglang/srt/models/bert.py +398 -0
  88. sglang/srt/models/deepseek.py +1 -1
  89. sglang/srt/models/deepseek_nextn.py +74 -70
  90. sglang/srt/models/deepseek_v2.py +292 -348
  91. sglang/srt/models/llama.py +5 -5
  92. sglang/srt/models/minicpm3.py +31 -203
  93. sglang/srt/models/minicpmo.py +17 -6
  94. sglang/srt/models/qwen2.py +4 -1
  95. sglang/srt/models/qwen2_moe.py +14 -13
  96. sglang/srt/models/qwen3.py +335 -0
  97. sglang/srt/models/qwen3_moe.py +423 -0
  98. sglang/srt/openai_api/adapter.py +71 -4
  99. sglang/srt/openai_api/protocol.py +6 -1
  100. sglang/srt/reasoning_parser.py +0 -1
  101. sglang/srt/sampling/sampling_batch_info.py +2 -3
  102. sglang/srt/server_args.py +86 -72
  103. sglang/srt/speculative/build_eagle_tree.py +2 -2
  104. sglang/srt/speculative/eagle_utils.py +2 -2
  105. sglang/srt/speculative/eagle_worker.py +6 -14
  106. sglang/srt/utils.py +62 -6
  107. sglang/test/runners.py +5 -1
  108. sglang/test/test_block_fp8.py +167 -0
  109. sglang/test/test_custom_ops.py +1 -1
  110. sglang/test/test_utils.py +3 -1
  111. sglang/version.py +1 -1
  112. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +5 -5
  113. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +116 -110
  114. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +1 -1
  115. sglang/lang/__init__.py +0 -0
  116. sglang/srt/lora/backend/__init__.py +0 -25
  117. sglang/srt/server.py +0 -18
  118. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
  119. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
@@ -20,6 +20,7 @@ Life cycle of a request in the prefill server
20
20
  from __future__ import annotations
21
21
 
22
22
  import logging
23
+ from collections import deque
23
24
  from typing import TYPE_CHECKING, List, Optional
24
25
 
25
26
  import torch
@@ -31,6 +32,8 @@ from sglang.srt.disaggregation.utils import (
31
32
  ReqToMetadataIdxAllocator,
32
33
  TransferBackend,
33
34
  get_kv_class,
35
+ kv_to_page_indices,
36
+ kv_to_page_num,
34
37
  poll_and_all_reduce,
35
38
  )
36
39
  from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
@@ -103,7 +106,7 @@ class PrefillBootstrapQueue:
103
106
  kv_args.aux_item_lens = [
104
107
  metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
105
108
  ]
106
- kv_args.ib_device = "mock-ib-device"
109
+ kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
107
110
  kv_args.gpu_id = self.scheduler.gpu_id
108
111
  kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
109
112
  kv_manager = kv_manager_class(
@@ -154,7 +157,8 @@ class PrefillBootstrapQueue:
154
157
  self.req_to_metadata_buffer_idx_allocator.alloc()
155
158
  )
156
159
  assert req.metadata_buffer_index is not None
157
- req.disagg_kv_sender.init(num_kv_indices, req.metadata_buffer_index)
160
+ num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
161
+ req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
158
162
 
159
163
  bootstrapped_reqs.append(req)
160
164
  indices_to_remove.add(i)
@@ -171,6 +175,70 @@ class SchedulerDisaggregationPrefillMixin:
171
175
  Mixin for Scheduler to handle disaggregation prefill
172
176
  """
173
177
 
@torch.no_grad()
def event_loop_normal_disagg_prefill(self):
    """Scheduler loop of a disaggregation-mode prefill worker, without overlap.

    Every iteration:
      1. receives new requests and passes them to ``process_input_requests``;
      2. moves the requests returned by ``pop_bootstrapped()`` into the
         waiting queue;
      3. builds a prefill batch, runs it and processes its result immediately;
      4. polls the queue of requests whose KV transfer is still in flight.
    When no batch ran and the in-flight queue is empty, memory is checked and
    the new-token ratio is reset to its initial value. The loop never returns.
    """
    while True:
        self.process_input_requests(self.recv_requests())
        ready = self.disagg_prefill_pending_queue.pop_bootstrapped()
        self.waiting_queue.extend(ready)
        self.process_prefill_chunk()

        new_batch = self.get_new_batch_prefill()
        self.cur_batch = new_batch
        if new_batch:
            outcome = self.run_batch(new_batch)
            self.process_batch_result_disagg_prefill(new_batch, outcome)

        if len(self.disagg_prefill_inflight_queue) > 0:
            self.process_disagg_prefill_inflight_queue()

        # Re-read the queue length: the poll above may have drained it.
        idle = new_batch is None and len(self.disagg_prefill_inflight_queue) == 0
        if idle:
            self.check_memory()
            self.new_token_ratio = self.init_new_token_ratio

        self.last_batch = new_batch
        # HACK (byronhsu): this loop never enters update_running_batch, which is
        # what normally resets batch_is_full; without clearing it here the
        # scheduler hangs under high concurrency.
        self.running_batch.batch_is_full = False
207
+
@torch.no_grad()
def event_loop_overlap_disagg_prefill(self):
    """Scheduler loop of a disaggregation-mode prefill worker, with overlap.

    Same steps as the non-overlapped loop, except that each batch's result is
    handled one iteration late. The result of the batch launched in this
    iteration is queued in ``self.result_queue``, together with a copy of the
    batch. It is processed on the following iteration, after that iteration's
    batch (if any) has been launched. The loop never returns.
    """
    self.result_queue = deque()

    while True:
        self.process_input_requests(self.recv_requests())
        ready = self.disagg_prefill_pending_queue.pop_bootstrapped()
        self.waiting_queue.extend(ready)
        self.process_prefill_chunk()

        new_batch = self.get_new_batch_prefill()
        self.cur_batch = new_batch
        if new_batch:
            outcome = self.run_batch(new_batch)
            self.result_queue.append((new_batch.copy(), outcome))

        # The previous iteration launched a batch: handle its deferred result.
        if self.last_batch:
            prev_batch, prev_outcome = self.result_queue.popleft()
            self.process_batch_result_disagg_prefill(prev_batch, prev_outcome)

        if len(self.disagg_prefill_inflight_queue) > 0:
            self.process_disagg_prefill_inflight_queue()

        # Re-read the queue length: the poll above may have drained it.
        idle = new_batch is None and len(self.disagg_prefill_inflight_queue) == 0
        if idle:
            self.check_memory()
            self.new_token_ratio = self.init_new_token_ratio

        self.last_batch = new_batch
        # HACK (byronhsu): this loop never enters update_running_batch, which is
        # what normally resets batch_is_full; without clearing it here the
        # scheduler hangs under high concurrency.
        self.running_batch.batch_is_full = False
241
+
174
242
  def process_batch_result_disagg_prefill(
175
243
  self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult
176
244
  ) -> None:
@@ -179,7 +247,26 @@ class SchedulerDisaggregationPrefillMixin:
179
247
  Adapted from process_batch_result_prefill
180
248
  """
181
249
 
182
- next_token_ids = result.next_token_ids.tolist()
250
+ (
251
+ logits_output,
252
+ next_token_ids,
253
+ extend_input_len_per_req,
254
+ extend_logprob_start_len_per_req,
255
+ bid,
256
+ ) = (
257
+ result.logits_output,
258
+ result.next_token_ids,
259
+ result.extend_input_len_per_req,
260
+ result.extend_logprob_start_len_per_req,
261
+ result.bid,
262
+ )
263
+
264
+ # Transfer kv for prefill completed requests and add it into disagg_prefill_infight_queue
265
+ if self.enable_overlap:
266
+ # wait
267
+ _, next_token_ids = self.tp_worker.resolve_batch_result(bid)
268
+ else:
269
+ next_token_ids = result.next_token_ids.tolist()
183
270
 
184
271
  for req, next_token_id in zip(batch.reqs, next_token_ids, strict=True):
185
272
  req: Req
@@ -193,12 +280,8 @@ class SchedulerDisaggregationPrefillMixin:
193
280
  # being chunked reqs' prefill is not finished
194
281
  req.is_chunked -= 1
195
282
 
196
- # TODO: Not sure if this is necessary
197
- if batch.next_batch_sampling_info:
198
- batch.next_batch_sampling_info.update_regex_vocab_mask()
199
- # We need to remove this for overlap schedule.
200
- self.current_stream.synchronize()
201
- batch.next_batch_sampling_info.sampling_info_done.set()
283
+ if self.enable_overlap:
284
+ self.send_kv_chunk(req, end_idx=req.tmp_end_idx)
202
285
 
203
286
  def process_disagg_prefill_inflight_queue(self: Scheduler) -> None:
204
287
  """
@@ -210,7 +293,7 @@ class SchedulerDisaggregationPrefillMixin:
210
293
 
211
294
  polls = poll_and_all_reduce(
212
295
  [req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
213
- self.tp_worker.get_tp_cpu_group(),
296
+ self.attn_tp_cpu_group,
214
297
  )
215
298
 
216
299
  undone_reqs: List[Req] = []
@@ -243,31 +326,68 @@ class SchedulerDisaggregationPrefillMixin:
243
326
  # only finished requests to running_batch.
244
327
  self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
245
328
  self.tree_cache.cache_unfinished_req(self.chunked_req)
246
- self.send_kv_chunk(self.chunked_req)
329
+ if (
330
+ self.enable_overlap
331
+ ): # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
332
+ self.chunked_req.tmp_end_idx = min(
333
+ len(self.chunked_req.fill_ids),
334
+ len(self.chunked_req.origin_input_ids),
335
+ )
336
+ else:
337
+ self.send_kv_chunk(self.chunked_req)
247
338
  # chunked request keeps its rid but will get a new req_pool_idx
248
339
  self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
249
340
  self.running_batch.batch_is_full = False
250
341
 
def send_kv_chunk(
    self: Scheduler,
    req: Req,
    token_id: Optional[int] = None,
    end_idx: Optional[int] = None,
) -> None:
    """Send a prefilled chunk of ``req``'s KV cache to the decode server.

    The chunk covers token positions ``[req.start_send_idx, end_idx)``. It is
    converted to page indices and handed to ``req.disagg_kv_sender.send``.

    Args:
        req: The request whose KV cache is being transferred.
        token_id: When given, this call sends the request's last chunk. The
            token is first stored via
            ``disagg_prefill_pending_queue.store_prefill_results``.
        end_idx: Exclusive end position of the chunk. Defaults to
            ``min(len(req.fill_ids), len(req.origin_input_ids))``. The overlap
            schedule passes it explicitly, because there the resolved length
            differs from ``len(req.fill_ids)``.
    """
    page_size = self.token_to_kv_pool_allocator.page_size
    start_idx = req.start_send_idx
    if end_idx is None:
        end_idx = min(len(req.fill_ids), len(req.origin_input_ids))
    last_chunk = token_id is not None

    if not last_chunk:
        # Before the last chunk only whole pages are sent: hold back a trailing
        # partial page so it goes out with the next chunk. (Subtracting 0 when
        # end_idx is already page-aligned is a no-op.)
        end_idx -= end_idx % page_size

    # The next chunk resumes where this one stops.
    req.start_send_idx = end_idx

    kv_indices = (
        self.req_to_token_pool.req_to_token[req.req_pool_idx, start_idx:end_idx]
        .cpu()
        .numpy()
    )
    if last_chunk:
        self.disagg_prefill_pending_queue.store_prefill_results(
            req.metadata_buffer_index, token_id
        )

    page_indices = kv_to_page_indices(kv_indices, page_size)
    if len(page_indices) == 0:
        logger.info(
            "Skip sending kv chunk for request req.rid=%r req.bootstrap_room=%r "
            "because page_indices is empty",
            req.rid,
            req.bootstrap_room,
        )
        return

    # start_idx is page-aligned: earlier non-final sends stop on page boundaries.
    page_start_idx = start_idx // page_size
    page_end_idx = page_start_idx + len(page_indices)
    req.disagg_kv_sender.send(
        page_indices, slice(page_start_idx, page_end_idx), last_chunk
    )
@@ -4,6 +4,7 @@ from collections import deque
4
4
  from enum import Enum
5
5
  from typing import List
6
6
 
7
+ import numpy as np
7
8
  import torch
8
9
  import torch.distributed as dist
9
10
 
@@ -46,6 +47,7 @@ class ReqToMetadataIdxAllocator:
46
47
 
class TransferBackend(Enum):
    """KV-cache transfer backends; ``get_kv_class`` maps each one to its classes."""

    # Classes come from the Mooncake backend (MooncakeKV* in get_kv_class).
    MOONCAKE = "mooncake"
    # Classes come from sglang.srt.disaggregation.nixl (NixlKV* in get_kv_class).
    NIXL = "nixl"
    # Fake backend; its semantics are not visible here — presumably for testing.
    FAKE = "fake"
50
52
 
51
53
 
@@ -72,4 +74,34 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
72
74
  KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
73
75
  }
74
76
  return class_mapping.get(class_type)
77
+ if transfer_backend == TransferBackend.NIXL:
78
+ from sglang.srt.disaggregation.nixl import (
79
+ NixlKVBootstrapServer,
80
+ NixlKVManager,
81
+ NixlKVReceiver,
82
+ NixlKVSender,
83
+ )
84
+
85
+ class_mapping = {
86
+ KVClassType.MANAGER: NixlKVManager,
87
+ KVClassType.SENDER: NixlKVSender,
88
+ KVClassType.RECEIVER: NixlKVReceiver,
89
+ KVClassType.BOOTSTRAP_SERVER: NixlKVBootstrapServer,
90
+ }
91
+ return class_mapping.get(class_type)
75
92
  raise ValueError(f"Unsupported transfer backend: {transfer_backend}")
93
+
94
+
def kv_to_page_indices(kv_indices: np.ndarray, page_size: int):
    """Convert per-token KV indices into one page index per page.

    Assumes every page is full except possibly the last one. Under that
    assumption, every ``page_size``-th entry of ``kv_indices`` is the first
    token of a page, and that token's KV index floor-divided by ``page_size``
    is the page index, i.e. ``kv_indices[::page_size] // page_size``.
    With ``page_size == 1`` every token is its own page, so the input array
    itself is returned (no copy).
    """
    if page_size != 1:
        first_token_of_each_page = kv_indices[::page_size]
        return first_token_of_each_page // page_size
    return kv_indices
103
+
104
+
def kv_to_page_num(num_kv_indices: int, page_size: int):
    """Return the number of pages needed to hold ``num_kv_indices`` tokens.

    Integer ceiling division, ``ceil(num_kv_indices / page_size)``: a trailing,
    partially filled page still counts as a whole page.
    """
    padded = num_kv_indices + (page_size - 1)
    return padded // page_size
@@ -279,6 +279,10 @@ class Engine(EngineBase):
279
279
  self.shutdown()
280
280
  return False
281
281
 
def flush_cache(self):
    """Synchronously flush the cache through the tokenizer manager.

    Drives the ``tokenizer_manager.flush_cache()`` coroutine to completion on
    the event loop returned by ``asyncio.get_event_loop()`` and returns its
    result.
    """
    return asyncio.get_event_loop().run_until_complete(
        self.tokenizer_manager.flush_cache()
    )
285
+
282
286
  def start_profile(self):
283
287
  loop = asyncio.get_event_loop()
284
288
  loop.run_until_complete(self.tokenizer_manager.start_profile())
@@ -25,11 +25,8 @@ import multiprocessing as multiprocessing
25
25
  import os
26
26
  import threading
27
27
  import time
28
- from ast import Mult
29
28
  from http import HTTPStatus
30
- from typing import AsyncIterator, Callable, Dict, Optional, Union
31
-
32
- from sglang.srt.model_executor.model_runner import LocalSerializedTensor
29
+ from typing import AsyncIterator, Callable, Dict, Optional
33
30
 
34
31
  # Fix a bug of Python threading
35
32
  setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -84,7 +81,6 @@ from sglang.srt.openai_api.protocol import ModelCard, ModelList
84
81
  from sglang.srt.reasoning_parser import ReasoningParser
85
82
  from sglang.srt.server_args import ServerArgs
86
83
  from sglang.srt.utils import (
87
- MultiprocessingSerializer,
88
84
  add_api_key_middleware,
89
85
  add_prometheus_middleware,
90
86
  delete_directory,
@@ -315,11 +311,11 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
@app.api_route("/flush_cache", methods=["GET", "POST"])
async def flush_cache():
    """Flush the radix cache.

    Awaits ``tokenizer_manager.flush_cache()``. Responds 200 with a success
    message when ``ret.success`` is true. Otherwise responds 400 (BAD_REQUEST)
    with a message stating that the cache was not flushed.
    """
    ret = await _global_state.tokenizer_manager.flush_cache()
    hint = (
        "Please check backend logs for more details. "
        "(When there are running or waiting requests, the operation will not be performed.)\n"
    )
    if ret.success:
        return Response(content="Cache flushed.\n" + hint, status_code=200)
    # Do not report "Cache flushed." alongside a failure status.
    return Response(
        content="Cache not flushed.\n" + hint,
        status_code=HTTPStatus.BAD_REQUEST,
    )
324
320
 
325
321
 
@@ -12,18 +12,17 @@
12
12
  # limitations under the License.
13
13
  # ==============================================================================
14
14
  import os
15
- from typing import Dict, List, Literal, Optional, Tuple, Union
15
+ from typing import Dict, Iterable, List, Literal, Optional, Tuple, Union
16
16
 
17
17
  import torch
18
18
  import torch.distributed as dist
19
19
  from PIL.Image import Image
20
20
  from torch.distributed.tensor import DeviceMesh, DTensor
21
21
 
22
+ from sglang.srt.entrypoints.engine import Engine
22
23
  from sglang.srt.entrypoints.http_server_engine import HttpServerEngineAdapter
23
24
  from sglang.srt.model_executor.model_runner import LocalSerializedTensor
24
25
  from sglang.srt.patch_torch import monkey_patch_torch_reductions
25
- from sglang.srt.server import Engine
26
- from sglang.srt.server_args import PortArgs, ServerArgs
27
26
  from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj
28
27
 
29
28
 
@@ -125,7 +124,7 @@ class VerlEngine:
125
124
 
126
125
  def update_weights_from_tensor(
127
126
  self,
128
- named_tensors: List[Tuple[str, torch.Tensor]],
127
+ named_tensors: Iterable[Tuple[str, torch.Tensor]],
129
128
  load_format: Optional[str] = None,
130
129
  ):
131
130
  # Most naive implementation, can optimize a lot if it is bottleneck
@@ -154,9 +153,12 @@ class VerlEngine:
154
153
  )
155
154
  ],
156
155
  load_format=load_format,
157
- flush_cache=tensor_index == len(named_tensors) - 1,
156
+ flush_cache=False,
158
157
  )
159
158
 
159
+ if self._tp_rank == 0:
160
+ self._engine.tokenizer_manager.flush_cache()
161
+
160
162
  def release_memory_occupation(self):
161
163
  if self._tp_rank == 0:
162
164
  self._engine.release_memory_occupation()
@@ -25,6 +25,7 @@ TOOLS_TAG_LIST = [
25
25
  "<tool_call>",
26
26
  "<|python_tag|>",
27
27
  "[TOOL_CALLS]",
28
+ "<|tool▁calls▁begin|>",
28
29
  ]
29
30
 
30
31
 
@@ -477,6 +478,64 @@ class Llama32Detector(BaseFormatDetector):
477
478
  )
478
479
 
479
480
 
class DeepSeekV3Detector(BaseFormatDetector):
    """
    Detector for DeepSeek models.
    Assumes function call format:
      '<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Tokyo"}\n```<|tool▁call▁end|>\n<|tool▁call▁begin|>function<|tool▁sep|>get_current_weather\n```json\n{"location": "Paris"}\n```<|tool▁call▁end|><|tool▁calls▁end|><|end▁of▁sentence|>

    The special tokens contain ``|``, which is the regex alternation operator.
    They are therefore passed through ``re.escape`` before being embedded in
    the patterns below. Left unescaped, the patterns would match bare
    fragments such as ``<`` instead of whole tool calls.
    """

    def __init__(self):
        super().__init__()
        self.bot_token = "<|tool▁calls▁begin|>"
        self.eot_token = "<|tool▁calls▁end|>"
        call_begin = re.escape("<|tool▁call▁begin|>")
        call_end = re.escape("<|tool▁call▁end|>")
        tool_sep = re.escape("<|tool▁sep|>")
        # One complete tool call, matched non-greedily so that re.findall
        # returns consecutive calls separately.
        self.func_call_regex = rf"{call_begin}.*?{call_end}"
        # Groups: (1) call type, e.g. "function"; (2) function name;
        # (3) JSON-encoded arguments.
        self.func_detail_regex = (
            rf"{call_begin}(.*){tool_sep}(.*)\n```json\n(.*)\n```{call_end}"
        )

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a deepseek format tool call."""
        return self.bot_token in text

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
            Text before the first ``bot_token`` (stripped) is returned as normal
            text. If any tool call cannot be parsed, the whole input is returned
            as normal text instead.
        """
        idx = text.find(self.bot_token)
        if idx == -1:
            return StreamingParseResult(normal_text=text, calls=[])
        normal_text = text[:idx].strip()
        match_result_list = re.findall(self.func_call_regex, text, re.DOTALL)
        calls = []
        try:
            for match_result in match_result_list:
                func_detail = re.search(self.func_detail_regex, match_result, re.DOTALL)
                if func_detail is None:
                    raise ValueError(f"malformed tool call: {match_result!r}")
                func_name = func_detail.group(2)
                func_args = json.loads(func_detail.group(3))
                # Build the {"name", "parameters"} dict parse_base_json expects.
                calls.extend(
                    self.parse_base_json(
                        {"name": func_name, "parameters": func_args}, tools
                    )
                )
            return StreamingParseResult(normal_text=normal_text, calls=calls)
        except Exception as e:
            logger.error(f"Error in detect_and_parse: {e}")
            # return the normal text if parsing fails
            return StreamingParseResult(normal_text=text)

    def structure_info(self) -> _GetInfoFunc:
        return lambda name: StructureInfo(
            begin="<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>"
            + name
            + "\n```json\n",
            end="\n```<|tool▁call▁end|><|tool▁calls▁end|>",
            trigger="<|tool▁calls▁begin|>",
        )
537
+
538
+
480
539
  class MultiFormatParser:
481
540
  def __init__(self, detectors: List[BaseFormatDetector]):
482
541
  """
@@ -543,6 +602,7 @@ class FunctionCallParser:
543
602
  "llama3": Llama32Detector,
544
603
  "qwen25": Qwen25Detector,
545
604
  "mistral": MistralDetector,
605
+ "deepseekv3": DeepSeekV3Detector,
546
606
  }
547
607
 
548
608
  def __init__(self, tools: List[Tool], tool_call_parser: str):
@@ -21,13 +21,6 @@ import torch
21
21
  import torch.nn as nn
22
22
  import torch.nn.functional as F
23
23
 
24
- from sglang.srt.utils import is_cuda_available
25
-
26
- _is_cuda = is_cuda_available()
27
-
28
- if _is_cuda:
29
- from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
30
-
31
24
  from sglang.srt.custom_op import CustomOp
32
25
  from sglang.srt.distributed import (
33
26
  divide,
@@ -35,7 +28,12 @@ from sglang.srt.distributed import (
35
28
  get_tensor_model_parallel_world_size,
36
29
  )
37
30
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
38
- from sglang.srt.utils import set_weight_attrs
31
+ from sglang.srt.utils import is_cuda, set_weight_attrs
32
+
33
+ _is_cuda = is_cuda()
34
+
35
+ if _is_cuda:
36
+ from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
39
37
 
40
38
  logger = logging.getLogger(__name__)
41
39