sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (119)
  1. sglang/bench_one_batch.py +3 -11
  2. sglang/bench_serving.py +149 -1
  3. sglang/check_env.py +3 -3
  4. sglang/lang/chat_template.py +44 -0
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/deepseekvl2.py +3 -0
  7. sglang/srt/configs/device_config.py +1 -1
  8. sglang/srt/configs/internvl.py +696 -0
  9. sglang/srt/configs/janus_pro.py +3 -0
  10. sglang/srt/configs/kimi_vl.py +38 -0
  11. sglang/srt/configs/kimi_vl_moonvit.py +32 -0
  12. sglang/srt/configs/model_config.py +32 -0
  13. sglang/srt/constrained/xgrammar_backend.py +11 -19
  14. sglang/srt/conversation.py +151 -3
  15. sglang/srt/disaggregation/decode.py +4 -1
  16. sglang/srt/disaggregation/mini_lb.py +74 -23
  17. sglang/srt/disaggregation/mooncake/conn.py +9 -18
  18. sglang/srt/disaggregation/nixl/conn.py +241 -71
  19. sglang/srt/disaggregation/utils.py +44 -1
  20. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  21. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  22. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  23. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  24. sglang/srt/distributed/parallel_state.py +22 -1
  25. sglang/srt/entrypoints/engine.py +58 -24
  26. sglang/srt/entrypoints/http_server.py +28 -1
  27. sglang/srt/entrypoints/verl_engine.py +3 -2
  28. sglang/srt/function_call_parser.py +97 -0
  29. sglang/srt/hf_transformers_utils.py +22 -1
  30. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
  31. sglang/srt/layers/attention/flashattention_backend.py +146 -50
  32. sglang/srt/layers/attention/flashinfer_backend.py +129 -94
  33. sglang/srt/layers/attention/flashinfer_mla_backend.py +88 -30
  34. sglang/srt/layers/attention/flashmla_backend.py +3 -0
  35. sglang/srt/layers/attention/merge_state.py +46 -0
  36. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  37. sglang/srt/layers/attention/vision.py +290 -163
  38. sglang/srt/layers/dp_attention.py +5 -2
  39. sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
  40. sglang/srt/layers/moe/ep_moe/layer.py +120 -1
  41. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +98 -57
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -5
  49. sglang/srt/layers/quantization/__init__.py +2 -2
  50. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  52. sglang/srt/layers/quantization/deep_gemm.py +6 -1
  53. sglang/srt/layers/quantization/fp8.py +108 -95
  54. sglang/srt/layers/quantization/fp8_kernel.py +79 -60
  55. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  56. sglang/srt/layers/quantization/kv_cache.py +3 -10
  57. sglang/srt/layers/quantization/utils.py +0 -5
  58. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  59. sglang/srt/layers/utils.py +35 -0
  60. sglang/srt/lora/layers.py +35 -9
  61. sglang/srt/lora/lora_manager.py +81 -35
  62. sglang/srt/managers/cache_controller.py +115 -119
  63. sglang/srt/managers/data_parallel_controller.py +52 -34
  64. sglang/srt/managers/io_struct.py +10 -0
  65. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  66. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  67. sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
  68. sglang/srt/managers/schedule_batch.py +44 -16
  69. sglang/srt/managers/schedule_policy.py +11 -5
  70. sglang/srt/managers/scheduler.py +291 -72
  71. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
  72. sglang/srt/managers/tokenizer_manager.py +24 -13
  73. sglang/srt/managers/tp_worker.py +60 -28
  74. sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
  75. sglang/srt/mem_cache/chunk_cache.py +2 -0
  76. sglang/srt/mem_cache/memory_pool.py +70 -36
  77. sglang/srt/model_executor/cuda_graph_runner.py +82 -19
  78. sglang/srt/model_executor/forward_batch_info.py +31 -1
  79. sglang/srt/model_executor/model_runner.py +159 -90
  80. sglang/srt/model_loader/loader.py +18 -11
  81. sglang/srt/models/clip.py +4 -4
  82. sglang/srt/models/deepseek_janus_pro.py +1 -1
  83. sglang/srt/models/deepseek_nextn.py +2 -277
  84. sglang/srt/models/deepseek_v2.py +132 -37
  85. sglang/srt/models/gemma3_mm.py +1 -1
  86. sglang/srt/models/internlm2.py +3 -0
  87. sglang/srt/models/internvl.py +670 -0
  88. sglang/srt/models/kimi_vl.py +308 -0
  89. sglang/srt/models/kimi_vl_moonvit.py +639 -0
  90. sglang/srt/models/llama.py +93 -31
  91. sglang/srt/models/llama4.py +54 -7
  92. sglang/srt/models/llama_eagle.py +4 -1
  93. sglang/srt/models/llama_eagle3.py +4 -1
  94. sglang/srt/models/minicpmv.py +1 -1
  95. sglang/srt/models/mllama.py +1 -1
  96. sglang/srt/models/phi3_small.py +16 -2
  97. sglang/srt/models/qwen2_5_vl.py +8 -4
  98. sglang/srt/models/qwen2_moe.py +8 -3
  99. sglang/srt/models/qwen2_vl.py +4 -16
  100. sglang/srt/models/qwen3_moe.py +8 -3
  101. sglang/srt/models/xiaomi_mimo.py +171 -0
  102. sglang/srt/openai_api/adapter.py +58 -62
  103. sglang/srt/openai_api/protocol.py +38 -16
  104. sglang/srt/reasoning_parser.py +2 -2
  105. sglang/srt/sampling/sampling_batch_info.py +54 -2
  106. sglang/srt/sampling/sampling_params.py +2 -0
  107. sglang/srt/server_args.py +93 -24
  108. sglang/srt/speculative/eagle_worker.py +3 -2
  109. sglang/srt/utils.py +123 -10
  110. sglang/test/runners.py +4 -0
  111. sglang/test/test_block_fp8.py +2 -2
  112. sglang/test/test_deepep_utils.py +219 -0
  113. sglang/test/test_utils.py +32 -1
  114. sglang/version.py +1 -1
  115. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +18 -9
  116. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +119 -99
  117. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
  118. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
  119. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
sglang/srt/distributed/parallel_state.py (+22 -1)

@@ -42,6 +42,7 @@ from torch.distributed import Backend, ProcessGroup
 from sglang.srt.utils import (
     direct_register_custom_op,
     is_cuda_alike,
+    is_npu,
     supports_custom_op,
 )
 
@@ -206,6 +207,7 @@ class GroupCoordinator:
         use_custom_allreduce: bool,
         use_hpu_communicator: bool,
         use_xpu_communicator: bool,
+        use_npu_communicator: bool,
         use_message_queue_broadcaster: bool = False,
         group_name: Optional[str] = None,
     ):
@@ -244,6 +246,7 @@ class GroupCoordinator:
         self.use_custom_allreduce = use_custom_allreduce
         self.use_hpu_communicator = use_hpu_communicator
         self.use_xpu_communicator = use_xpu_communicator
+        self.use_npu_communicator = use_npu_communicator
         self.use_message_queue_broadcaster = use_message_queue_broadcaster
 
         # lazy import to avoid documentation build error
@@ -291,6 +294,14 @@ class GroupCoordinator:
         if use_xpu_communicator and self.world_size > 1:
             self.xpu_communicator = XpuCommunicator(group=self.device_group)
 
+        from sglang.srt.distributed.device_communicators.npu_communicator import (
+            NpuCommunicator,
+        )
+
+        self.npu_communicator: Optional[NpuCommunicator] = None
+        if use_npu_communicator and self.world_size > 1:
+            self.npu_communicator = NpuCommunicator(group=self.device_group)
+
         from sglang.srt.distributed.device_communicators.shm_broadcast import (
             MessageQueue,
         )
@@ -418,6 +429,9 @@ class GroupCoordinator:
         if self.xpu_communicator is not None and not self.xpu_communicator.disabled:
             return self.xpu_communicator.all_reduce(input_)
 
+        if self.npu_communicator is not None and not self.npu_communicator.disabled:
+            return self.npu_communicator.all_reduce(input_)
+
         if (
             self.ca_comm is not None
             and not self.ca_comm.disabled
@@ -497,6 +511,11 @@ class GroupCoordinator:
         if hpu_comm is not None and not hpu_comm.disabled:
             return hpu_comm.all_gather(input_, dim)
 
+        # For NPUs, use NPU communicator.
+        npu_comm = self.npu_communicator
+        if npu_comm is not None and not npu_comm.disabled:
+            return npu_comm.all_gather(input_, dim)
+
         if dim < 0:
             # Convert negative dim to positive.
             dim += input_.dim()
@@ -941,6 +960,7 @@ def init_world_group(
         use_custom_allreduce=False,
         use_hpu_communicator=False,
         use_xpu_communicator=False,
+        use_npu_communicator=False,
         group_name="world",
     )
 
@@ -959,10 +979,11 @@ def init_model_parallel_group(
         group_ranks=group_ranks,
         local_rank=local_rank,
         torch_distributed_backend=backend,
-        use_pynccl=True,
+        use_pynccl=not is_npu(),
         use_custom_allreduce=use_custom_allreduce,
         use_hpu_communicator=True,
         use_xpu_communicator=True,
+        use_npu_communicator=True,
         use_message_queue_broadcaster=use_message_queue_broadcaster,
         group_name=group_name,
     )
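
Editor's note: the hunks above add an NPU path next to the existing HPU/XPU/custom-allreduce communicators in GroupCoordinator. A simplified sketch of the dispatch pattern being extended (stand-in types, not sglang's actual classes): each collective tries the device-specific communicators in order and falls back to the default backend.

# Illustrative stand-ins only; GroupCoordinator's real communicators live in
# sglang/srt/distributed/device_communicators/.
from typing import Callable, List, Optional


class _Comm:
    def __init__(self, name: str, disabled: bool = False):
        self.name = name
        self.disabled = disabled

    def all_reduce(self, x):
        return f"all_reduce via {self.name}"


def dispatch_all_reduce(x, comms: List[Optional[_Comm]], default: Callable):
    # Mirrors the order in the diff: xpu -> npu -> custom allreduce -> default backend.
    for comm in comms:
        if comm is not None and not comm.disabled:
            return comm.all_reduce(x)
    return default(x)


print(dispatch_all_reduce([1.0], [None, _Comm("npu")], lambda x: "all_reduce via default"))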
sglang/srt/entrypoints/engine.py (+58 -24)

@@ -58,7 +58,10 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
-from sglang.srt.openai_api.adapter import load_chat_template_for_openai_api
+from sglang.srt.openai_api.adapter import (
+    guess_chat_template_name_from_model_path,
+    load_chat_template_for_openai_api,
+)
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
@@ -123,7 +126,6 @@ class Engine(EngineBase):
             server_args=server_args,
             port_args=port_args,
         )
-
         self.server_args = server_args
         self.tokenizer_manager = tokenizer_manager
         self.scheduler_info = scheduler_info
@@ -161,6 +163,9 @@
         custom_logit_processor: Optional[Union[List[str], str]] = None,
         return_hidden_states: bool = False,
         stream: bool = False,
+        bootstrap_host: Optional[Union[List[str], str]] = None,
+        bootstrap_port: Optional[Union[List[int], int]] = None,
+        bootstrap_room: Optional[Union[List[int], int]] = None,
     ) -> Union[Dict, Iterator[Dict]]:
         """
         The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
@@ -179,6 +184,9 @@
             custom_logit_processor=custom_logit_processor,
             return_hidden_states=return_hidden_states,
             stream=stream,
+            bootstrap_host=bootstrap_host,
+            bootstrap_port=bootstrap_port,
+            bootstrap_room=bootstrap_room,
         )
         loop = asyncio.get_event_loop()
         generator = self.tokenizer_manager.generate_request(obj, None)
@@ -225,6 +233,9 @@
         lora_path: Optional[List[Optional[str]]] = None,
         custom_logit_processor: Optional[Union[List[str], str]] = None,
         stream: bool = False,
+        bootstrap_host: Optional[Union[List[str], str]] = None,
+        bootstrap_port: Optional[Union[List[int], int]] = None,
+        bootstrap_room: Optional[Union[List[int], int]] = None,
     ) -> Union[Dict, AsyncIterator[Dict]]:
         """
         The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
@@ -242,6 +253,9 @@
             lora_path=lora_path,
             stream=stream,
             custom_logit_processor=custom_logit_processor,
+            bootstrap_host=bootstrap_host,
+            bootstrap_port=bootstrap_port,
+            bootstrap_room=bootstrap_room,
         )
         generator = self.tokenizer_manager.generate_request(obj, None)
 
@@ -298,7 +312,6 @@
         internal_states = loop.run_until_complete(
             self.tokenizer_manager.get_internal_state()
         )
-
         return {
             **dataclasses.asdict(self.tokenizer_manager.server_args),
             **self.scheduler_info,
@@ -347,8 +360,8 @@
         load_format: Optional[str] = None,
         flush_cache: bool = True,
     ):
-        """Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be true
-        to avoid duplicated operations such as clearing cache."""
+        """Update weights from distributed source. If there are going to be more updates, set `flush_cache` to be false
+        to avoid duplicated cache cleaning operation."""
         obj = UpdateWeightsFromTensorReqInput(
             serialized_named_tensors=[
                 MultiprocessingSerializer.serialize(named_tensors)
@@ -450,7 +463,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if server_args.attention_backend == "flashinfer":
         assert_pkg_version(
             "flashinfer_python",
-            "0.2.3",
+            "0.2.5",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
@@ -458,7 +471,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda:
         assert_pkg_version(
             "sgl-kernel",
-            "0.1.0",
+            "0.1.1",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
 
@@ -517,25 +530,44 @@ def _launch_subprocesses(
         )
 
         scheduler_pipe_readers = []
-        tp_size_per_node = server_args.tp_size // server_args.nnodes
+
+        nnodes_per_tp_group = max(server_args.nnodes // server_args.pp_size, 1)
+        tp_size_per_node = server_args.tp_size // nnodes_per_tp_group
         tp_rank_range = range(
-            tp_size_per_node * server_args.node_rank,
-            tp_size_per_node * (server_args.node_rank + 1),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group),
+            tp_size_per_node * (server_args.node_rank % nnodes_per_tp_group + 1),
         )
-        for tp_rank in tp_rank_range:
-            reader, writer = mp.Pipe(duplex=False)
-            gpu_id = (
-                server_args.base_gpu_id
-                + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
-            )
-            proc = mp.Process(
-                target=run_scheduler_process,
-                args=(server_args, port_args, gpu_id, tp_rank, None, writer),
-            )
-            with memory_saver_adapter.configure_subprocess():
-                proc.start()
-            scheduler_procs.append(proc)
-            scheduler_pipe_readers.append(reader)
+
+        pp_size_per_node = max(server_args.pp_size // server_args.nnodes, 1)
+        pp_rank_range = range(
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group),
+            pp_size_per_node * (server_args.node_rank // nnodes_per_tp_group + 1),
+        )
+
+        for pp_rank in pp_rank_range:
+            for tp_rank in tp_rank_range:
+                reader, writer = mp.Pipe(duplex=False)
+                gpu_id = (
+                    server_args.base_gpu_id
+                    + ((pp_rank % pp_size_per_node) * tp_size_per_node)
+                    + (tp_rank % tp_size_per_node) * server_args.gpu_id_step
+                )
+                proc = mp.Process(
+                    target=run_scheduler_process,
+                    args=(
+                        server_args,
+                        port_args,
+                        gpu_id,
+                        tp_rank,
+                        pp_rank,
+                        None,
+                        writer,
+                    ),
+                )
+                with memory_saver_adapter.configure_subprocess():
+                    proc.start()
+                scheduler_procs.append(proc)
+                scheduler_pipe_readers.append(reader)
     else:
         # Launch the data parallel controller
         reader, writer = mp.Pipe(duplex=False)
@@ -584,6 +616,8 @@ _launch_subprocesses(
         load_chat_template_for_openai_api(
             tokenizer_manager, server_args.chat_template, server_args.model_path
         )
+    else:
+        guess_chat_template_name_from_model_path(server_args.model_path)
 
     if server_args.completion_template:
         load_completion_template_for_openai_api(server_args.completion_template)
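
Editor's note: the _launch_subprocesses hunk above generalizes the scheduler launch loop from TP-only to TP x PP. Below is a standalone sketch of the same rank-to-GPU arithmetic, evaluated for one assumed example configuration (nnodes=2, pp_size=2, tp_size=8, base_gpu_id=0, gpu_id_step=1); the numbers are illustrative, not part of the diff.

# Reproduces the gpu_id / rank-range arithmetic from _launch_subprocesses for an example config.
nnodes, pp_size, tp_size = 2, 2, 8
base_gpu_id, gpu_id_step = 0, 1

nnodes_per_tp_group = max(nnodes // pp_size, 1)    # 1: each TP group fits on a single node
tp_size_per_node = tp_size // nnodes_per_tp_group  # 8 TP ranks per node
pp_size_per_node = max(pp_size // nnodes, 1)       # 1 PP stage per node

for node_rank in range(nnodes):
    tp_rank_range = range(
        tp_size_per_node * (node_rank % nnodes_per_tp_group),
        tp_size_per_node * (node_rank % nnodes_per_tp_group + 1),
    )
    pp_rank_range = range(
        pp_size_per_node * (node_rank // nnodes_per_tp_group),
        pp_size_per_node * (node_rank // nnodes_per_tp_group + 1),
    )
    for pp_rank in pp_rank_range:
        for tp_rank in tp_rank_range:
            gpu_id = (
                base_gpu_id
                + (pp_rank % pp_size_per_node) * tp_size_per_node
                + (tp_rank % tp_size_per_node) * gpu_id_step
            )
            print(f"node {node_rank}: pp_rank={pp_rank} tp_rank={tp_rank} -> gpu {gpu_id}")
# Node 0 runs pp_rank 0 on GPUs 0-7, node 1 runs pp_rank 1 on GPUs 0-7.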
sglang/srt/entrypoints/http_server.py (+28 -1)

@@ -42,7 +42,10 @@ from fastapi import FastAPI, File, Form, Request, UploadFile
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import ORJSONResponse, Response, StreamingResponse
 
-from sglang.srt.disaggregation.utils import FakeBootstrapHost
+from sglang.srt.disaggregation.utils import (
+    FakeBootstrapHost,
+    register_disaggregation_server,
+)
 from sglang.srt.entrypoints.engine import _launch_subprocesses
 from sglang.srt.function_call_parser import FunctionCallParser
 from sglang.srt.managers.io_struct import (
@@ -59,6 +62,7 @@ from sglang.srt.managers.io_struct import (
     ResumeMemoryOccupationReqInput,
     SeparateReasoningReqInput,
     SetInternalStateReq,
+    SlowDownReqInput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightsFromDistributedReqInput,
     UpdateWeightsFromTensorReqInput,
@@ -491,6 +495,19 @@
         return _create_error_response(e)
 
 
+@app.api_route("/slow_down", methods=["GET", "POST"])
+async def slow_down(obj: SlowDownReqInput, request: Request):
+    """Slow down the system deliberately. Only for testing. Example scenario:
+    when we want to test performance of D in large-scale PD disaggregation and have no enough nodes for P,
+    we can use this to slow down D to let it have enough running sequences, and then disable slowdown
+    to let it run in full batch size.
+    """
+    try:
+        await _global_state.tokenizer_manager.slow_down(obj, request)
+    except Exception as e:
+        return _create_error_response(e)
+
+
 @app.api_route("/open_session", methods=["GET", "POST"])
 async def open_session(obj: OpenSessionReqInput, request: Request):
     """Open a session, and return its unique session id."""
@@ -675,6 +692,8 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Request):
         **(vertex_req.parameters or {}),
     )
     ret = await generate_request(req, raw_request)
+    if isinstance(ret, Response):
+        return ret
     return ORJSONResponse({"predictions": ret})
 
 
@@ -869,5 +888,13 @@
     if server_args.debug_tensor_dump_input_file:
         kill_process_tree(os.getpid())
 
+    if server_args.pdlb_url is not None:
+        register_disaggregation_server(
+            server_args.disaggregation_mode,
+            server_args.port,
+            server_args.disaggregation_bootstrap_port,
+            server_args.pdlb_url,
+        )
+
     if launch_callback is not None:
         launch_callback()
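
Editor's note: the new /slow_down route takes a SlowDownReqInput body, which is defined in sglang/srt/managers/io_struct.py and not shown in this diff. A hypothetical client call might look like the following; the JSON field name is an assumption for illustration only.

import requests

# Hypothetical request against the testing-only endpoint; check SlowDownReqInput in
# io_struct.py for the real payload fields, which are not part of this diff excerpt.
resp = requests.post(
    "http://localhost:30000/slow_down",
    json={"forward_sleep_time": 0.5},  # assumed field name, for illustration only
)
print(resp.status_code, resp.text)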
sglang/srt/entrypoints/verl_engine.py (+3 -2)

@@ -37,6 +37,7 @@ class VerlEngine:
         monkey_patch_torch_reductions()
         self._device_mesh_cpu = device_mesh_cpu
         self._tp_rank = device_mesh_cpu.get_local_rank()
+        self._rank = device_mesh_cpu.get_rank()
         self._tp_size = device_mesh_cpu.size()
         tp_size_per_node = self._tp_size // nnodes
         node_rank = self._tp_rank // tp_size_per_node
@@ -114,7 +115,7 @@ class VerlEngine:
         # Most naive implementation, can extract tensor and send via gloo if too slow
         [output] = broadcast_pyobj(
             data=[output],
-            rank=self._tp_rank,
+            rank=self._rank,
             dist_group=self._device_mesh_cpu.get_group(),
             src=self._device_mesh_cpu.mesh[0].item(),
             force_cpu_device=False,
@@ -157,7 +158,7 @@ class VerlEngine:
         )
 
         if self._tp_rank == 0:
-            self._engine.tokenizer_manager.flush_cache()
+            self._engine.flush_cache()
 
     def release_memory_occupation(self):
         if self._tp_rank == 0:
sglang/srt/function_call_parser.py (+97 -0)

@@ -1,3 +1,4 @@
+import ast
 import json
 import logging
 import re
@@ -664,6 +665,101 @@ class MultiFormatParser:
         return final_normal_text, final_calls
 
 
+class PythonicDetector(BaseFormatDetector):
+    """
+    Detector for Llama-3.2 and Llama-4 models with pythonic tool call format.
+    Assumes function call format:
+        [tool1(arg1=val1, arg2=val2), tool2(arg1=val3)]
+    Arguments are Python literals (not JSON).
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.tool_call_regex = re.compile(
+            r"\[([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s)?\),\s*)*([a-zA-Z]+\w*\(([a-zA-Z]+\w*=.*,\s*)*([a-zA-Z]+\w*=.*\s*)?\)\s*)+\]",
+            re.DOTALL,
+        )
+
+    def has_tool_call(self, text: str) -> bool:
+        return bool(self.tool_call_regex.match(text.strip()))
+
+    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+        # Try parsing the text as a Python list of function calls
+        text = text.strip()
+        if not (text.startswith("[") and text.endswith("]")):
+            # Not a pythonic tool call format
+            return StreamingParseResult(normal_text=text, calls=[])
+        try:
+            module = ast.parse(text)
+            parsed = getattr(module.body[0], "value", None)
+            if not (
+                isinstance(parsed, ast.List)
+                and all(isinstance(e, ast.Call) for e in parsed.elts)
+            ):
+                return StreamingParseResult(normal_text=text, calls=[])
+            calls = []
+            tool_indices = {
+                tool.function.name: i
+                for i, tool in enumerate(tools)
+                if tool.function.name
+            }
+            for call in parsed.elts:
+                if not isinstance(call.func, ast.Name):
+                    continue
+                function_name = call.func.id
+                arguments = {}
+                for keyword in call.keywords:
+                    arguments[keyword.arg] = self._get_parameter_value(keyword.value)
+                calls.append(
+                    ToolCallItem(
+                        tool_index=tool_indices.get(function_name, -1),
+                        name=function_name,
+                        parameters=json.dumps(arguments, ensure_ascii=False),
+                    )
+                )
+            return StreamingParseResult(normal_text="", calls=calls)
+        except Exception:
+            logger.exception("Error in pythonic tool call parsing.")
+            return StreamingParseResult(normal_text=text, calls=[])
+
+    def parse_streaming_increment(
+        self, new_text: str, tools: List[Tool]
+    ) -> StreamingParseResult:
+        """
+        Streaming incremental parsing for pythonic tool calls.
+        Buffers input until a complete pythonic tool call (from [ to ]) is found,
+        then parses and emits any detected calls.
+        """
+        self._buffer += new_text
+        start = self._buffer.find("[")
+        end = self._buffer.find("]", start)
+        if start != -1 and end != -1:
+            call_text = self._buffer[start : end + 1]
+            result = self.detect_and_parse(call_text, tools)
+            self._buffer = self._buffer[end + 1 :]
+            return result
+        return StreamingParseResult(normal_text="")
+
+    def _get_parameter_value(self, val):
+        if isinstance(val, ast.Constant):
+            return val.value
+        elif isinstance(val, ast.Dict):
+            return {
+                k.value: self._get_parameter_value(v)
+                for k, v in zip(val.keys, val.values)
+            }
+        elif isinstance(val, ast.List):
+            return [self._get_parameter_value(v) for v in val.elts]
+        else:
+            raise ValueError("Tool call arguments must be literals")
+
+    def structure_info(self) -> _GetInfoFunc:
+        def info(name: str):
+            return StructureInfo(begin="[", end="]", trigger="")
+
+        return info
+
+
 
 class FunctionCallParser:
     """
@@ -675,6 +771,7 @@ class FunctionCallParser:
         "qwen25": Qwen25Detector,
         "mistral": MistralDetector,
         "deepseekv3": DeepSeekV3Detector,
+        "pythonic": PythonicDetector,
     }
 
     def __init__(self, tools: List[Tool], tool_call_parser: str):
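
Editor's note: PythonicDetector uses the standard-library ast module rather than JSON to read tool calls emitted as Python literals. A minimal, self-contained sketch of the same idea (independent of sglang's Tool and ToolCallItem types):

import ast
import json


def parse_pythonic_calls(text: str):
    """Parse '[tool(a=1, b="x")]'-style output into (name, json_args) pairs."""
    module = ast.parse(text.strip())
    call_list = module.body[0].value
    assert isinstance(call_list, ast.List)
    results = []
    for call in call_list.elts:
        assert isinstance(call, ast.Call) and isinstance(call.func, ast.Name)
        # ast.literal_eval accepts AST nodes, so keyword values stay restricted to literals.
        args = {kw.arg: ast.literal_eval(kw.value) for kw in call.keywords}
        results.append((call.func.id, json.dumps(args)))
    return results


print(parse_pythonic_calls('[get_weather(city="Paris", days=3)]'))
# [('get_weather', '{"city": "Paris", "days": 3}')]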
sglang/srt/hf_transformers_utils.py (+22 -1)

@@ -19,6 +19,7 @@ import warnings
 from pathlib import Path
 from typing import Dict, Optional, Type, Union
 
+import transformers
 from huggingface_hub import snapshot_download
 from transformers import (
     AutoConfig,
@@ -26,6 +27,7 @@ from transformers import (
     AutoTokenizer,
     PretrainedConfig,
     PreTrainedTokenizer,
+    PreTrainedTokenizerBase,
     PreTrainedTokenizerFast,
 )
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
@@ -35,8 +37,10 @@ from sglang.srt.configs import (
     DbrxConfig,
     DeepseekVL2Config,
     ExaoneConfig,
+    KimiVLConfig,
     MultiModalityConfig,
 )
+from sglang.srt.configs.internvl import InternVLChatConfig
 from sglang.srt.connector import create_remote_connector
 from sglang.srt.utils import is_remote_url
 
@@ -46,6 +50,8 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ExaoneConfig.model_type: ExaoneConfig,
     DeepseekVL2Config.model_type: DeepseekVL2Config,
     MultiModalityConfig.model_type: MultiModalityConfig,
+    KimiVLConfig.model_type: KimiVLConfig,
+    InternVLChatConfig.model_type: InternVLChatConfig,
 }
 
 for name, cls in _CONFIG_REGISTRY.items():
@@ -88,6 +94,12 @@ def get_config(
         config = config_class.from_pretrained(model, revision=revision)
         # NOTE(HandH1998): Qwen2VL requires `_name_or_path` attribute in `config`.
         setattr(config, "_name_or_path", model)
+
+    if isinstance(model, str) and config.model_type == "internvl_chat":
+        for key, val in config.llm_config.__dict__.items():
+            if not hasattr(config, key):
+                setattr(config, key, val)
+
     if model_override_args:
         config.update(model_override_args)
 
@@ -209,6 +221,13 @@ def get_tokenizer(
     return tokenizer
 
 
+# Some models doesn't have an available processor, e.g.: InternVL
+def get_tokenizer_from_processor(processor):
+    if isinstance(processor, PreTrainedTokenizerBase):
+        return processor
+    return processor.tokenizer
+
+
 def get_processor(
     tokenizer_name: str,
     *args,
@@ -244,7 +263,9 @@ def get_processor(
         **kwargs,
     )
 
-    attach_additional_stop_token_ids(processor.tokenizer)
+    tokenizer = get_tokenizer_from_processor(processor)
+
+    attach_additional_stop_token_ids(tokenizer)
     return processor
 
 
sglang/srt/layers/attention/cutlass_mla_backend.py (+1 -1)

@@ -268,7 +268,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
         reshape_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
 
         o = cutlass_mla_decode(
-            q_nope_and_q_pe=reshape_q,
+            q_nope_and_q_pe=reshape_q.to(self.q_data_type),
             kv_c_and_k_pe_cache=k_cache.view(-1, PAGE_SIZE, self.kv_cache_dim),
             seq_lens=forward_batch.seq_lens.to(torch.int32),
             page_table=self.forward_metadata.block_kv_indices,
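
Editor's note: this change casts the reshaped query to the backend's configured q_data_type before calling cutlass_mla_decode, so the kernel input dtype matches what the backend expects. A minimal illustration with assumed shapes and dtype (not the backend's actual values):

import torch

# Assumed values for illustration: MLA head_dim 576 (512 nope + 64 rope), bf16 kernel dtype.
q_data_type = torch.bfloat16
reshape_q = torch.randn(4, 16, 576, dtype=torch.float32)  # (tokens, q_heads, head_dim)
q_kernel_input = reshape_q.to(q_data_type)  # no-op when the dtypes already match
assert q_kernel_input.dtype == q_data_type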