sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +26 -4
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +676 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +49 -8
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  34. sglang/srt/distributed/parallel_state.py +42 -8
  35. sglang/srt/entrypoints/engine.py +55 -5
  36. sglang/srt/entrypoints/http_server.py +78 -13
  37. sglang/srt/entrypoints/verl_engine.py +2 -0
  38. sglang/srt/function_call_parser.py +133 -55
  39. sglang/srt/hf_transformers_utils.py +28 -3
  40. sglang/srt/layers/activation.py +4 -2
  41. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  42. sglang/srt/layers/attention/flashattention_backend.py +434 -0
  43. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  44. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  45. sglang/srt/layers/attention/triton_backend.py +171 -38
  46. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  47. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  48. sglang/srt/layers/attention/utils.py +53 -0
  49. sglang/srt/layers/attention/vision.py +9 -28
  50. sglang/srt/layers/dp_attention.py +41 -19
  51. sglang/srt/layers/layernorm.py +24 -2
  52. sglang/srt/layers/linear.py +17 -5
  53. sglang/srt/layers/logits_processor.py +25 -7
  54. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  55. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  56. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  57. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  64. sglang/srt/layers/moe/topk.py +60 -20
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +80 -53
  67. sglang/srt/layers/quantization/awq.py +200 -0
  68. sglang/srt/layers/quantization/base_config.py +5 -0
  69. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  70. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  72. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  75. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  76. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  77. sglang/srt/layers/quantization/fp8.py +76 -34
  78. sglang/srt/layers/quantization/fp8_kernel.py +25 -8
  79. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  80. sglang/srt/layers/quantization/gptq.py +36 -19
  81. sglang/srt/layers/quantization/kv_cache.py +98 -0
  82. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  83. sglang/srt/layers/quantization/utils.py +153 -0
  84. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  85. sglang/srt/layers/rotary_embedding.py +78 -87
  86. sglang/srt/layers/sampler.py +1 -1
  87. sglang/srt/lora/backend/base_backend.py +4 -4
  88. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  89. sglang/srt/lora/backend/triton_backend.py +5 -8
  90. sglang/srt/lora/layers.py +87 -33
  91. sglang/srt/lora/lora.py +2 -22
  92. sglang/srt/lora/lora_manager.py +67 -30
  93. sglang/srt/lora/mem_pool.py +117 -52
  94. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  95. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  96. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  97. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  98. sglang/srt/lora/utils.py +18 -1
  99. sglang/srt/managers/cache_controller.py +2 -5
  100. sglang/srt/managers/data_parallel_controller.py +30 -8
  101. sglang/srt/managers/expert_distribution.py +81 -0
  102. sglang/srt/managers/io_struct.py +43 -5
  103. sglang/srt/managers/mm_utils.py +373 -0
  104. sglang/srt/managers/multimodal_processor.py +68 -0
  105. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  106. sglang/srt/managers/multimodal_processors/clip.py +63 -0
  107. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  108. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  109. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  110. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  111. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  112. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  113. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  114. sglang/srt/managers/schedule_batch.py +134 -30
  115. sglang/srt/managers/scheduler.py +290 -31
  116. sglang/srt/managers/session_controller.py +1 -1
  117. sglang/srt/managers/tokenizer_manager.py +59 -24
  118. sglang/srt/managers/tp_worker.py +4 -1
  119. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  120. sglang/srt/managers/utils.py +6 -1
  121. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  122. sglang/srt/mem_cache/memory_pool.py +255 -98
  123. sglang/srt/mem_cache/paged_allocator.py +2 -2
  124. sglang/srt/mem_cache/radix_cache.py +4 -4
  125. sglang/srt/model_executor/cuda_graph_runner.py +36 -21
  126. sglang/srt/model_executor/forward_batch_info.py +68 -11
  127. sglang/srt/model_executor/model_runner.py +75 -8
  128. sglang/srt/model_loader/loader.py +171 -3
  129. sglang/srt/model_loader/weight_utils.py +51 -3
  130. sglang/srt/models/clip.py +563 -0
  131. sglang/srt/models/deepseek_janus_pro.py +31 -88
  132. sglang/srt/models/deepseek_nextn.py +22 -10
  133. sglang/srt/models/deepseek_v2.py +329 -73
  134. sglang/srt/models/deepseek_vl2.py +358 -0
  135. sglang/srt/models/gemma3_causal.py +694 -0
  136. sglang/srt/models/gemma3_mm.py +468 -0
  137. sglang/srt/models/llama.py +47 -7
  138. sglang/srt/models/llama_eagle.py +1 -0
  139. sglang/srt/models/llama_eagle3.py +196 -0
  140. sglang/srt/models/llava.py +3 -3
  141. sglang/srt/models/llavavid.py +3 -3
  142. sglang/srt/models/minicpmo.py +1995 -0
  143. sglang/srt/models/minicpmv.py +62 -137
  144. sglang/srt/models/mllama.py +4 -4
  145. sglang/srt/models/phi3_small.py +1 -1
  146. sglang/srt/models/qwen2.py +3 -0
  147. sglang/srt/models/qwen2_5_vl.py +68 -146
  148. sglang/srt/models/qwen2_classification.py +75 -0
  149. sglang/srt/models/qwen2_moe.py +9 -1
  150. sglang/srt/models/qwen2_vl.py +25 -63
  151. sglang/srt/openai_api/adapter.py +201 -104
  152. sglang/srt/openai_api/protocol.py +33 -7
  153. sglang/srt/patch_torch.py +71 -0
  154. sglang/srt/sampling/sampling_batch_info.py +1 -1
  155. sglang/srt/sampling/sampling_params.py +6 -6
  156. sglang/srt/server_args.py +114 -14
  157. sglang/srt/speculative/build_eagle_tree.py +7 -347
  158. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  159. sglang/srt/speculative/eagle_utils.py +208 -252
  160. sglang/srt/speculative/eagle_worker.py +140 -54
  161. sglang/srt/speculative/spec_info.py +6 -1
  162. sglang/srt/torch_memory_saver_adapter.py +22 -0
  163. sglang/srt/utils.py +215 -21
  164. sglang/test/__init__.py +0 -0
  165. sglang/test/attention/__init__.py +0 -0
  166. sglang/test/attention/test_flashattn_backend.py +312 -0
  167. sglang/test/runners.py +29 -2
  168. sglang/test/test_activation.py +2 -1
  169. sglang/test/test_block_fp8.py +5 -4
  170. sglang/test/test_block_fp8_ep.py +2 -1
  171. sglang/test/test_dynamic_grad_mode.py +58 -0
  172. sglang/test/test_layernorm.py +3 -2
  173. sglang/test/test_utils.py +56 -5
  174. sglang/utils.py +31 -0
  175. sglang/version.py +1 -1
  176. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
  177. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
  178. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
  179. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  180. sglang/srt/managers/image_processor.py +0 -55
  181. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  182. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  183. sglang/srt/managers/multi_modality_padding.py +0 -134
  184. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
  185. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/entrypoints/engine.py

@@ -27,12 +27,16 @@ import signal
 import threading
 from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union

+import zmq
+import zmq.asyncio
+
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

 import torch
 import uvloop

+from sglang.srt.code_completion_parser import load_completion_template_for_openai_api
 from sglang.srt.managers.data_parallel_controller import (
     run_data_parallel_controller_process,
 )
@@ -44,6 +48,8 @@ from sglang.srt.managers.io_struct import (
     InitWeightsUpdateGroupReqInput,
     ReleaseMemoryOccupationReqInput,
     ResumeMemoryOccupationReqInput,
+    RpcReqInput,
+    RpcReqOutput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
@@ -57,6 +63,7 @@ from sglang.srt.utils import (
     MultiprocessingSerializer,
     assert_pkg_version,
     configure_logger,
+    get_zmq_socket,
     kill_process_tree,
     launch_dummy_health_check_server,
     maybe_set_triton_cache_manager,
@@ -102,15 +109,25 @@ class Engine:
         # Shutdown the subprocesses automatically when the program exits
         atexit.register(self.shutdown)

+        # Allocate ports for inter-process communications
+        port_args = PortArgs.init_new(server_args)
+        logger.info(f"{server_args=}")
+
         # Launch subprocesses
         tokenizer_manager, scheduler_info = _launch_subprocesses(
-            server_args=server_args
+            server_args=server_args,
+            port_args=port_args,
         )

         self.server_args = server_args
         self.tokenizer_manager = tokenizer_manager
         self.scheduler_info = scheduler_info

+        context = zmq.Context(2)
+        self.send_to_rpc = get_zmq_socket(
+            context, zmq.DEALER, port_args.rpc_ipc_name, True
+        )
+
     def generate(
         self,
         # The input prompt. It can be a single prompt or a batch of prompts.
@@ -232,6 +249,13 @@ class Engine:
         """Shutdown the engine"""
         kill_process_tree(os.getpid(), include_parent=False)

+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.shutdown()
+        return False
+
     def start_profile(self):
         loop = asyncio.get_event_loop()
         loop.run_until_complete(self.tokenizer_manager.start_profile())
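
With `__enter__`/`__exit__` added above, `Engine` can be used as a context manager so the subprocess tree is torn down even when generation raises. A minimal usage sketch, assuming the documented offline engine API; the model path and sampling parameters are placeholders:

import sglang as sgl

with sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct") as engine:
    out = engine.generate("The capital of France is", {"max_new_tokens": 16})
    print(out)
# __exit__ has called engine.shutdown() here, killing the subprocess tree.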
@@ -296,7 +320,10 @@
         """Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be true
         to avoid duplicated operations such as clearing cache."""
         obj = UpdateWeightsFromTensorReqInput(
-            serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors),
+            serialized_named_tensors=[
+                MultiprocessingSerializer.serialize(named_tensors)
+                for _ in range(self.server_args.tp_size)
+            ],
             load_format=load_format,
             flush_cache=flush_cache,
        )
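
The change above serializes the named tensors once per tensor-parallel rank instead of sending a single shared payload. A hedged caller-side sketch; the weight name and shape are illustrative, not taken from any real model config:

import torch

new_weight = torch.zeros(4096, 4096, dtype=torch.bfloat16)  # dummy replacement tensor
engine.update_weights_from_tensor(named_tensors=[("lm_head.weight", new_weight)])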
@@ -350,6 +377,23 @@
             self.tokenizer_manager.resume_memory_occupation(obj, None)
         )

+    """
+    Execute an RPC call on all scheduler processes.
+    """
+
+    def collective_rpc(self, method: str, **kwargs):
+        obj = RpcReqInput(method=method, parameters=kwargs)
+        self.send_to_rpc.send_pyobj(obj)
+        recv_req = self.send_to_rpc.recv_pyobj(zmq.BLOCKY)
+        assert isinstance(recv_req, RpcReqOutput)
+        assert recv_req.success, recv_req.message
+
+    def save_remote_model(self, **kwargs):
+        self.collective_rpc("save_remote_model", **kwargs)
+
+    def save_sharded_model(self, **kwargs):
+        self.collective_rpc("save_sharded_model", **kwargs)
+

 def _set_envs_and_config(server_args: ServerArgs):
     # Set global environments
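
collective_rpc pushes an RpcReqInput over the new DEALER socket and blocks until the schedulers reply; the two save_* methods are thin wrappers over it. A usage sketch in which the parameter names are assumptions inferred from the new connector module (sglang/srt/connector), not a confirmed signature:

# Parameter names below are guesses; the scheduler-side handlers define the real ones.
engine.save_remote_model(url="s3://my-bucket/my-model")  # persist through a connector
engine.save_sharded_model(path="/tmp/shards")            # write per-rank shards locally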
@@ -408,7 +452,9 @@ def _set_envs_and_config(server_args: ServerArgs):
     mp.set_start_method("spawn", force=True)


-def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dict]:
+def _launch_subprocesses(
+    server_args: ServerArgs, port_args: Optional[PortArgs] = None
+) -> Tuple[TokenizerManager, Dict]:
     """
     Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
     """
@@ -418,8 +464,9 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic
     _set_envs_and_config(server_args)

     # Allocate ports for inter-process communications
-    port_args = PortArgs.init_new(server_args)
-    logger.info(f"{server_args=}")
+    if port_args is None:
+        port_args = PortArgs.init_new(server_args)
+        logger.info(f"{server_args=}")

     # If using model from www.modelscope.cn, first download the model.
     server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
@@ -502,6 +549,9 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic
         tokenizer_manager, server_args.chat_template, server_args.model_path
     )

+    if server_args.completion_template:
+        load_completion_template_for_openai_api(server_args.completion_template)
+
     # Wait for the model to finish loading
     scheduler_infos = []
     for i in range(len(scheduler_pipe_readers)):
sglang/srt/entrypoints/http_server.py

@@ -14,11 +14,12 @@
 """
 The entry point of inference server. (SRT = SGLang Runtime)

-This file implements HTTP APIs for the inferenc engine via fastapi.
+This file implements HTTP APIs for the inference engine via fastapi.
 """

 import asyncio
 import dataclasses
+import json
 import logging
 import multiprocessing as multiprocessing
 import os
@@ -259,6 +260,29 @@ async def generate_request(obj: GenerateReqInput, request: Request):
         return _create_error_response(e)


+@app.api_route("/generate_from_file", methods=["POST"])
+async def generate_from_file_request(file: UploadFile, request: Request):
+    """Handle a generate request, this is purely to work with input_embeds."""
+    content = await file.read()
+    input_embeds = json.loads(content.decode("utf-8"))
+
+    obj = GenerateReqInput(
+        input_embeds=input_embeds,
+        sampling_params={
+            "repetition_penalty": 1.2,
+            "temperature": 0.2,
+            "max_new_tokens": 512,
+        },
+    )
+
+    try:
+        ret = await _global_state.generate_request(obj, request).__anext__()
+        return ret
+    except ValueError as e:
+        logger.error(f"Error: {e}")
+        return _create_error_response(e)
+
+
 @app.api_route("/encode", methods=["POST", "PUT"])
 async def encode_request(obj: EmbeddingReqInput, request: Request):
     """Handle an embedding request."""
@@ -283,7 +307,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         return _create_error_response(e)


-@app.post("/flush_cache")
+@app.api_route("/flush_cache", methods=["GET", "POST"])
 async def flush_cache():
     """Flush the radix cache."""
     _global_state.tokenizer_manager.flush_cache()
@@ -319,6 +343,36 @@ async def stop_profile_async():
     )


+@app.api_route("/start_expert_distribution_record", methods=["GET", "POST"])
+async def start_expert_distribution_record_async():
+    """Start recording the expert distribution. Clear the previous record if any."""
+    await _global_state.tokenizer_manager.start_expert_distribution_record()
+    return Response(
+        content="Start recording the expert distribution.\n",
+        status_code=200,
+    )
+
+
+@app.api_route("/stop_expert_distribution_record", methods=["GET", "POST"])
+async def stop_expert_distribution_record_async():
+    """Stop recording the expert distribution."""
+    await _global_state.tokenizer_manager.stop_expert_distribution_record()
+    return Response(
+        content="Stop recording the expert distribution.\n",
+        status_code=200,
+    )
+
+
+@app.api_route("/dump_expert_distribution_record", methods=["GET", "POST"])
+async def dump_expert_distribution_record_async():
+    """Dump expert distribution record."""
+    await _global_state.tokenizer_manager.dump_expert_distribution_record()
+    return Response(
+        content="Dump expert distribution record.\n",
+        status_code=200,
+    )
+
+
 @app.post("/update_weights_from_disk")
 async def update_weights_from_disk(obj: UpdateWeightFromDiskReqInput, request: Request):
     """Update the weights from disk inplace without re-launching the server."""
@@ -507,7 +561,13 @@ def available_models():
     served_model_names = [_global_state.tokenizer_manager.served_model_name]
     model_cards = []
     for served_model_name in served_model_names:
-        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
+        model_cards.append(
+            ModelCard(
+                id=served_model_name,
+                root=served_model_name,
+                max_model_len=_global_state.tokenizer_manager.model_config.context_len,
+            )
+        )
     return ModelList(data=model_cards)


@@ -706,9 +766,15 @@ def _wait_and_warmup(
         },
     }
     if server_args.skip_tokenizer_init:
-        json_data["input_ids"] = [10, 11, 12]
+        json_data["input_ids"] = [[10, 11, 12] for _ in range(server_args.dp_size)]
+        # TODO Workaround the bug that embedding errors for list of size 1
+        if server_args.dp_size == 1:
+            json_data["input_ids"] = json_data["input_ids"][0]
     else:
-        json_data["text"] = "The capital city of France is"
+        json_data["text"] = ["The capital city of France is"] * server_args.dp_size
+        # TODO Workaround the bug that embedding errors for list of size 1
+        if server_args.dp_size == 1:
+            json_data["text"] = json_data["text"][0]

     # Debug dumping
     if server_args.debug_tensor_dump_input_file:
@@ -719,14 +785,13 @@
         json_data["sampling_params"]["max_new_tokens"] = 0

     try:
-        for i in range(server_args.dp_size):
-            res = requests.post(
-                url + request_name,
-                json=json_data,
-                headers=headers,
-                timeout=600,
-            )
-            assert res.status_code == 200, f"{res}"
+        res = requests.post(
+            url + request_name,
+            json=json_data,
+            headers=headers,
+            timeout=600,
+        )
+        assert res.status_code == 200, f"{res}"
     except Exception:
         last_traceback = get_exception_traceback()
         if pipe_finish_writer is not None:
sglang/srt/entrypoints/verl_engine.py

@@ -19,6 +19,7 @@ import torch.distributed as dist
 from torch.distributed.tensor import DeviceMesh, DTensor

 from sglang.srt.model_executor.model_runner import LocalSerializedTensor
+from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.server import Engine
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj
@@ -30,6 +31,7 @@ class VerlEngine:
         nnodes: int = 1,
         **kwargs,
     ):
+        monkey_patch_torch_reductions()
         self._device_mesh_cpu = device_mesh_cpu
         self._tp_rank = device_mesh_cpu.get_local_rank()
         self._tp_size = device_mesh_cpu.size()
sglang/srt/function_call_parser.py

@@ -1,12 +1,21 @@
 import json
 import logging
 import re
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
 from json import JSONDecodeError, JSONDecoder
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type

 import partial_json_parser
+from partial_json_parser.core.exceptions import MalformedJSON
 from partial_json_parser.core.options import Allow
-from pydantic import BaseModel, Field
+from pydantic import BaseModel
+
+from sglang.srt.openai_api.protocol import (
+    StructuralTagResponseFormat,
+    StructuresResponseFormat,
+    Tool,
+)

 logger = logging.getLogger(__name__)

@@ -19,14 +28,6 @@ TOOLS_TAG_LIST = [
 ]


-class Function(BaseModel):
-    """Function Tool Template."""
-
-    description: Optional[str] = Field(default=None, examples=[None])
-    name: Optional[str] = None
-    parameters: Optional[object] = None
-
-
 class ToolCallItem(BaseModel):
     """Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts."""

@@ -74,7 +75,22 @@ class StreamingParseResult:
         self.calls = calls or []


-class BaseFormatDetector:
+@dataclass
+class StructureInfo:
+    begin: str
+    end: str
+    trigger: str
+
+
+_GetInfoFunc = Callable[[str], StructureInfo]
+"""
+helper alias of function
+ususally it is a function that takes a name string and returns a StructureInfo object,
+which can be used to construct a structural_tag object
+"""
+
+
+class BaseFormatDetector(ABC):
     """Base class providing two sets of interfaces: one-time and streaming incremental."""

     def __init__(self):
@@ -90,26 +106,12 @@ class BaseFormatDetector:
         self.bot_token = ""
         self.eot_token = ""

-    def parse_base_json(self, action: Any, tools: List[Function]) -> List[ToolCallItem]:
+    def parse_base_json(self, action: Any, tools: List[Tool]) -> List[ToolCallItem]:
         tool_indices = {
             tool.function.name: i for i, tool in enumerate(tools) if tool.function.name
         }
         if not isinstance(action, list):
-            name = action.get("name")
-            if not name or name not in tool_indices:
-                logger.warning(f"Model attempted to call undefined function: {name}")
-                return []
-
-            return [
-                ToolCallItem(
-                    tool_index=tool_indices[name],
-                    name=name,
-                    parameters=json.dumps(
-                        action.get("parameters") or action.get("arguments", {}),
-                        ensure_ascii=False,
-                    ),
-                )
-            ]
+            action = [action]

         results = []
         for act in action:
@@ -125,19 +127,22 @@ class BaseFormatDetector:
                         ),
                     )
                 )
+            else:
+                logger.warning(f"Model attempted to call undefined function: {name}")

         return results

-    def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
+    @abstractmethod
+    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
         """
         Parses the text in one go. Returns success=True if the format matches, otherwise False.
         Note that leftover_text here represents "content that this parser will not consume further".
         """
         action = json.loads(text)
-        return self.parse_base_json(action, tools)
+        return StreamingParseResult(calls=self.parse_base_json(action, tools))

     def parse_streaming_increment(
-        self, new_text: str, tools: List[Function]
+        self, new_text: str, tools: List[Tool]
     ) -> StreamingParseResult:
         """
         Streaming incremental parsing with tool validation.
@@ -196,7 +201,7 @@ class BaseFormatDetector:
                        obj["arguments"] = obj["parameters"]
                    tool_call_arr.append(obj)

-            except partial_json_parser.core.exceptions.MalformedJSON:
+            except MalformedJSON:
                return StreamingParseResult()

            if len(tool_call_arr) == 0:
@@ -285,7 +290,6 @@ class BaseFormatDetector:
                    calls=[
                        ToolCallItem(
                            tool_index=self.current_tool_id,
-                            name="",
                            parameters=argument_diff,
                        )
                    ],
@@ -302,6 +306,14 @@ class BaseFormatDetector:
             logger.error(f"Error in parse_streaming_increment: {e}")
             return StreamingParseResult()

+    @abstractmethod
+    def has_tool_call(self, text: str) -> bool:
+        raise NotImplementedError()
+
+    @abstractmethod
+    def structure_info(self) -> _GetInfoFunc:
+        raise NotImplementedError()
+

 class Qwen25Detector(BaseFormatDetector):
     """
@@ -322,7 +334,7 @@ class Qwen25Detector(BaseFormatDetector):
         """Check if the text contains a Qwen 2.5 format tool call."""
         return self.bot_token in text

-    def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
+    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
         """
         One-time parsing: Detects and parses tool calls in the provided text.

@@ -330,15 +342,24 @@ class Qwen25Detector(BaseFormatDetector):
         :param tools: List of available tools.
         :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
         """
-        if "<tool_call>" not in text:
-            return []
-        pattern = r"<tool_call>(.*?)</tool_call>"
+        idx = text.find(self.bot_token)
+        normal_text = text[:idx].strip() if idx != -1 else text
+        if self.bot_token not in text:
+            return StreamingParseResult(normal_text=normal_text, calls=[])
+        pattern = rf"{self.bot_token}(.*?){self.eot_token}"
         match_result_list = re.findall(pattern, text, re.DOTALL)
         calls = []
         for match_result in match_result_list:
             match_result = json.loads(match_result)
             calls.extend(self.parse_base_json(match_result, tools))
-        return calls
+        return StreamingParseResult(normal_text=normal_text, calls=calls)
+
+    def structure_info(self) -> _GetInfoFunc:
+        return lambda name: StructureInfo(
+            begin='<tool_call>{"name":"' + name + '", "arguments":',
+            end="}</tool_call>",
+            trigger="<tool_call>",
+        )


 class MistralDetector(BaseFormatDetector):
@@ -374,7 +395,7 @@ class MistralDetector(BaseFormatDetector):
         else:
             return ""

-    def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
+    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
         """
         One-time parsing: Detects and parses tool calls in the provided text.

@@ -382,6 +403,8 @@ class MistralDetector(BaseFormatDetector):
         :param tools: List of available tools.
         :return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
         """
+        idx = text.find(self.bot_token)
+        normal_text = text[:idx].strip() if idx != -1 else text
         text = self._clean_text(text)
         tool_content = text.replace("[TOOL_CALLS]", "").strip()
         raw_tool_calls = self.tool_call_regex.findall(tool_content)
@@ -391,7 +414,14 @@ class MistralDetector(BaseFormatDetector):
             function_call_arr = json.loads(raw_tool_call)
             for match_result in function_call_arr:
                 calls.extend(self.parse_base_json(match_result, tools))
-        return calls
+        return StreamingParseResult(normal_text=normal_text, calls=calls)
+
+    def structure_info(self) -> _GetInfoFunc:
+        return lambda name: StructureInfo(
+            begin='[TOOL_CALLS] [{"name":"' + name + '", "arguments":',
+            end="}]",
+            trigger="[TOOL_CALLS]",
+        )


 class Llama32Detector(BaseFormatDetector):
@@ -411,19 +441,18 @@ class Llama32Detector(BaseFormatDetector):
         # prefix the output with the <|python_tag|> token
         return "<|python_tag|>" in text or text.startswith("{")

-    def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
+    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
         """Parse function calls from text, handling multiple JSON objects."""
         if "<|python_tag|>" not in text and not text.startswith("{"):
-            return []
+            return StreamingParseResult(normal_text=text, calls=[])

         if "<|python_tag|>" in text:
-            _, action_text = text.split("<|python_tag|>")
+            normal_text, action_text = text.split("<|python_tag|>")
         else:
-            action_text = text
+            normal_text, action_text = "", text

         # Split by semicolon and process each part
         json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
-
         all_actions = []
         for part in json_parts:
             try:
@@ -434,12 +463,18 @@ class Llama32Detector(BaseFormatDetector):
                 logger.warning(f"Failed to parse JSON part: {part}")
                 logger.warning(f"JSON parse error: {str(e)}")
                 continue
-
+        calls = []
         # Only process if we found valid JSON objects
         if all_actions:
-            return self.parse_base_json(all_actions, tools)
-
-        return []
+            calls = self.parse_base_json(all_actions, tools)
+        return StreamingParseResult(normal_text=normal_text, calls=calls)
+
+    def structure_info(self) -> _GetInfoFunc:
+        return lambda name: StructureInfo(
+            begin='<|python_tag|>{"name":"' + name + '", "arguments":',
+            end="}",
+            trigger="<|python_tag|>",
+        )


 class MultiFormatParser:
@@ -449,7 +484,9 @@
         """
         self.detectors = detectors

-    def parse_once(self, text: str, tools: List[Function]):
+    def parse_once(
+        self, text: str, tools: List[Tool]
+    ) -> Tuple[str, list[ToolCallItem]]:
         """
         One-time parsing: Loop through detectors until there are no new matches or text is exhausted
         Return: (final_text, all_calls)
@@ -459,15 +496,19 @@
         final_calls = []
         final_normal_text = text
         for detector in self.detectors:
-            tool_call_list = detector.detect_and_parse(text, tools)
+            parsed_result = detector.detect_and_parse(text, tools)
+            tool_call_list = parsed_result.calls
             if len(tool_call_list) > 0:  # parsed successfully
                 final_calls = tool_call_list
+                final_normal_text = parsed_result.normal_text
                 break

         # leftover_text is the normal text not consumed by any Detector
         return final_normal_text, final_calls

-    def parse_streaming_increment(self, new_text: str, tools: List[Function]):
+    def parse_streaming_increment(
+        self, new_text: str, tools: List[Tool]
+    ) -> Tuple[str, list[ToolCallItem]]:
         """
         Streaming incremental parsing: Feed new_text to each detector's parse_streaming_increment
         and merge their produced normal_text/calls to return.
@@ -498,13 +539,13 @@ class FunctionCallParser:
     and returns the resulting normal_text and calls to the upper layer (or SSE).
     """

-    ToolCallParserEnum: Dict[str, BaseFormatDetector] = {
+    ToolCallParserEnum: Dict[str, Type[BaseFormatDetector]] = {
         "llama3": Llama32Detector,
         "qwen25": Qwen25Detector,
         "mistral": MistralDetector,
     }

-    def __init__(self, tools: List[Function], tool_call_parser: str = None):
+    def __init__(self, tools: List[Tool], tool_call_parser: str):
         detectors = []
         if tool_call_parser:
             detector_class = self.ToolCallParserEnum.get(tool_call_parser)
@@ -532,7 +573,7 @@ class FunctionCallParser:
                 return True
         return False

-    def parse_non_stream(self, full_text: str):
+    def parse_non_stream(self, full_text: str) -> Tuple[str, list[ToolCallItem]]:
         """
         Non-streaming call: one-time parsing
         """
@@ -541,7 +582,7 @@ class FunctionCallParser:
         )
         return full_normal_text, calls

-    def parse_stream_chunk(self, chunk_text: str):
+    def parse_stream_chunk(self, chunk_text: str) -> Tuple[str, list[ToolCallItem]]:
         """
         Streaming call: incremental parsing
         """
@@ -549,3 +590,40 @@ class FunctionCallParser:
             chunk_text, self.tools
         )
         return normal_text, calls
+
+    def structure_infos(self) -> List[_GetInfoFunc]:
+        """
+        Returns a list of structure_info functions for each detector
+        """
+        return [
+            detector.structure_info() for detector in self.multi_format_parser.detectors
+        ]
+
+    def get_structure_tag(self) -> StructuralTagResponseFormat:
+        tool_structures: List[StructuresResponseFormat] = list()
+        tool_trigger_set: Set[str] = set()
+
+        for wrapper in self.structure_infos():
+            for tool in self.tools:
+                function = tool.function
+                name = function.name
+                assert name is not None
+                info = wrapper(name)
+
+                # accept all if not strict, otherwise only accept the schema
+                schema = function.parameters if function.strict else {}
+
+                tool_structures.append(
+                    StructuresResponseFormat(
+                        begin=info.begin,
+                        schema=schema,  # type: ignore
+                        end=info.end,
+                    )
+                )
+                tool_trigger_set.add(info.trigger)
+
+        return StructuralTagResponseFormat(
+            type="structural_tag",
+            structures=tool_structures,
+            triggers=list(tool_trigger_set),
+        )
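
Tying the new pieces together: each detector's structure_info() returns a factory that maps a function name to the begin/end/trigger strings of that model family's tool-call syntax, and get_structure_tag() folds those over the tool list into one StructuralTagResponseFormat for constrained decoding. A small sketch grounded in the Qwen 2.5 lambda above; the function name is illustrative:

info = Qwen25Detector().structure_info()("get_weather")
# info.begin   == '<tool_call>{"name":"get_weather", "arguments":'
# info.end     == '}</tool_call>'
# info.trigger == '<tool_call>'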