sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/entrypoints/grpc_server.py

@@ -16,11 +16,13 @@ from typing import AsyncIterator, Dict, Optional, Tuple
 import grpc
 from grpc_reflection.v1alpha import reflection
 
+from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode
 from sglang.srt.entrypoints.grpc_request_manager import GrpcRequestManager
 from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc
 from sglang.srt.managers.data_parallel_controller import (
     run_data_parallel_controller_process,
 )
+from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -36,6 +38,20 @@ logger = logging.getLogger(__name__)
 HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
 
 
+def _run_scheduler_with_signal_handling(*args, **kwargs):
+    """
+    Wrapper for run_scheduler_process that ignores SIGINT.
+
+    The scheduler process should not handle Ctrl+C - it should only terminate
+    when the parent gRPC server exits (via kill_itself_when_parent_died).
+    """
+    # Ignore SIGINT in this subprocess - let the parent handle it
+    signal.signal(signal.SIGINT, signal.SIG_IGN)
+
+    # Now run the actual scheduler process
+    run_scheduler_process(*args, **kwargs)
+
+
 def _launch_scheduler_process_only(
     server_args: ServerArgs,
     port_args: Optional[PortArgs] = None,
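
The wrapper added here follows the standard pattern for shielding a multiprocessing child from Ctrl+C: the child installs SIG_IGN for SIGINT as its first action, so an interrupt delivered to the terminal's process group only reaches the parent, which then tears the child down explicitly. A minimal self-contained sketch of that pattern (the worker below is illustrative, not an sglang API):

    import multiprocessing as mp
    import signal
    import time

    def worker():
        # First action in the child: ignore SIGINT so Ctrl+C only
        # interrupts the parent, which owns the shutdown sequence.
        signal.signal(signal.SIGINT, signal.SIG_IGN)
        while True:
            time.sleep(1)

    if __name__ == "__main__":
        mp.set_start_method("spawn", force=True)
        proc = mp.Process(target=worker)
        proc.start()
        try:
            proc.join()  # Ctrl+C raises KeyboardInterrupt only here
        except KeyboardInterrupt:
            proc.terminate()  # the parent decides when the child dies
            proc.join()
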
@@ -88,7 +104,7 @@ def _launch_scheduler_process_only(
         )
         moe_ep_rank = tp_rank // (server_args.tp_size // server_args.ep_size)
         proc = mp.Process(
-            target=run_scheduler_process,
+            target=_run_scheduler_with_signal_handling,
             args=(
                 server_args,
                 port_args,
@@ -98,7 +114,6 @@ def _launch_scheduler_process_only(
                 pp_rank,
                 None,
                 writer,
-                None,
             ),
         )
 
@@ -181,20 +196,34 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
             # Convert gRPC request to internal format
             tokenized_req = self._convert_generate_request(request)
 
-            # Submit to request manager
-            output_queue = await self.request_manager.generate_request(
+            # Submit to request manager (automatically handles n>1)
+            response_generator = self.request_manager.generate_request(
                 obj=tokenized_req,
                 request_id=request.request_id,
                 grpc_context=context,
             )
 
-            # Stream outputs
-            while True:
-                try:
-                    # Get output with timeout
-                    output = await asyncio.wait_for(output_queue.get(), timeout=4)
-
-                    # Check for errors
+            async for output in response_generator:
+                # Handle batch responses (for n>1 non-streaming)
+                if isinstance(output, list):
+                    for batch_output in output:
+                        if "error" in batch_output:
+                            yield sglang_scheduler_pb2.GenerateResponse(
+                                request_id=request.request_id,
+                                error=sglang_scheduler_pb2.GenerateError(
+                                    message=batch_output["error"],
+                                    http_status_code=(
+                                        "500" if "abort" not in batch_output else "499"
+                                    ),
+                                ),
+                            )
+                        else:
+                            # All non-error batch outputs are final responses
+                            yield self._create_completion_response(
+                                request.request_id, batch_output
+                            )
+                else:
+                    # Handle single response (for streaming or n=1 non-streaming)
                     if "error" in output:
                         yield sglang_scheduler_pb2.GenerateResponse(
                             request_id=request.request_id,
@@ -205,27 +234,13 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
                                 ),
                             ),
                         )
-                        break
-
-                    # Check if finished
-                    if output.get("finished", False):
-                        # Send completion
+                    elif output.get("finished", False):
                         yield self._create_completion_response(
                             request.request_id, output
                         )
-                        break
                     else:
-                        # Send chunk
                         yield self._create_chunk_response(request.request_id, output)
 
-                except asyncio.TimeoutError:
-                    # Check if context is still active
-                    if context.cancelled():
-                        # Abort the request
-                        await self.request_manager.abort_request(request.request_id)
-                        break
-                    continue
-
         except Exception as e:
             logger.error(f"Generate failed: {e}\n{get_exception_traceback()}")
             yield sglang_scheduler_pb2.GenerateResponse(
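
The rewrite replaces the queue-plus-timeout polling loop with direct iteration over an async generator that yields either a single output dict (streaming, or n=1) or a list of final dicts (n>1 non-streaming). A runnable sketch of that consumption pattern, with a stub generator standing in for GrpcRequestManager.generate_request:

    import asyncio
    from typing import AsyncIterator, Dict, List, Union

    async def fake_generate() -> AsyncIterator[Union[Dict, List[Dict]]]:
        yield {"token_ids": [1], "finished": False}      # streaming chunk
        yield {"token_ids": [1, 2], "finished": True}    # final output
        yield [{"token_ids": [7], "finished": True},     # n>1 batch: all final
               {"error": "demo failure"}]

    async def consume() -> None:
        async for output in fake_generate():
            if isinstance(output, list):  # n>1 non-streaming: every item is final
                for item in output:
                    print("error" if "error" in item else "complete", item)
            elif "error" in output:
                print("error", output)
            elif output.get("finished", False):
                print("complete", output)
            else:
                print("chunk", output)

    asyncio.run(consume())
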
@@ -266,7 +281,6 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
                 prompt_tokens=result.get("prompt_tokens", 0),
                 cached_tokens=0,
                 embedding_dim=len(result["embedding"]),
-                generation_time=time.time() - self.start_time,
             ),
         )
 
@@ -319,17 +333,21 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
             token_ids_logprob=None,
         )
 
+        if self.server_args.disaggregation_mode != DisaggregationMode.NULL:
+            health_request.bootstrap_host = FAKE_BOOTSTRAP_HOST
+            health_request.bootstrap_room = 0
+
         logger.info(f"Sending health check request to request manager...")
 
         # Submit and wait for response
-        output_queue = await self.request_manager.generate_request(
+        output_generator = self.request_manager.generate_request(
             health_request, request_id=rid
         )
 
         try:
-            # Wait for response with configurable timeout
+            # Get first response with timeout
             response = await asyncio.wait_for(
-                output_queue.get(), timeout=HEALTH_CHECK_TIMEOUT
+                output_generator.__anext__(), timeout=HEALTH_CHECK_TIMEOUT
            )
 
             # Clean up
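
Because generate_request now returns an async generator rather than a queue, the health check pulls exactly one item by awaiting __anext__() under asyncio.wait_for. A minimal sketch of that idiom (stub generator and timeout value are illustrative):

    import asyncio

    async def slow_results():
        await asyncio.sleep(0.1)  # stand-in for scheduler latency
        yield {"status": "healthy"}

    async def health_check(timeout: float = 5.0) -> bool:
        gen = slow_results()
        try:
            first = await asyncio.wait_for(gen.__anext__(), timeout=timeout)
            return bool(first)
        except asyncio.TimeoutError:
            return False
        finally:
            await gen.aclose()  # always release the generator

    print(asyncio.run(health_check()))
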
@@ -394,6 +412,15 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
         # Convert sampling params
         sampling_params = self._convert_sampling_params(grpc_req.sampling_params)
 
+        # Extract disaggregated params if present
+        bootstrap_host = None
+        bootstrap_port = None
+        bootstrap_room = None
+        if grpc_req.HasField("disaggregated_params"):
+            bootstrap_host = grpc_req.disaggregated_params.bootstrap_host or None
+            bootstrap_port = grpc_req.disaggregated_params.bootstrap_port or None
+            bootstrap_room = grpc_req.disaggregated_params.bootstrap_room or None
+
         # Create request
         return TokenizedGenerateReqInput(
             rid=grpc_req.request_id,
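
The `or None` coercion matters because proto3 scalar fields have no notion of "unset": a missing string reads back as "" and a missing integer as 0. Folding those defaults to None (which also folds an explicitly-set empty value, acceptable for these fields) looks like this:

    # proto3 scalars have no "unset" state: a missing string reads back as ""
    # and a missing int as 0. `or None` maps those defaults to None.
    bootstrap_host = "" or None       # unset  -> None
    bootstrap_port = 0 or None        # unset  -> None
    bootstrap_room = 12345 or None    # set    -> 12345
    print(bootstrap_host, bootstrap_port, bootstrap_room)  # None None 12345
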
@@ -402,13 +429,20 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
             mm_inputs=None,  # TODO: implement mm support
             sampling_params=sampling_params,
             return_logprob=grpc_req.return_logprob,
-            logprob_start_len=grpc_req.logprob_start_len or -1,
+            logprob_start_len=(
+                grpc_req.logprob_start_len
+                if grpc_req.logprob_start_len is not None
+                else -1
+            ),
             top_logprobs_num=grpc_req.top_logprobs_num or 0,
-            stream=True,  # Always stream for gRPC
-            lora_path=grpc_req.lora_id if grpc_req.lora_id else None,
+            stream=grpc_req.stream or False,
+            lora_id=grpc_req.lora_id if grpc_req.lora_id else None,
             token_ids_logprob=(
                 list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None
             ),
+            bootstrap_host=bootstrap_host,
+            bootstrap_port=bootstrap_port,
+            bootstrap_room=bootstrap_room,
         )
 
     def _convert_embed_request(
@@ -438,6 +472,7 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
         regex = None
         json_schema = None
         ebnf_grammar = None
+        structural_tag = None
 
         if grpc_params.HasField("regex"):
             regex = grpc_params.regex
@@ -445,6 +480,8 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
             json_schema = grpc_params.json_schema
         elif grpc_params.HasField("ebnf_grammar"):
             ebnf_grammar = grpc_params.ebnf_grammar
+        elif grpc_params.HasField("structural_tag"):
+            structural_tag = grpc_params.structural_tag
 
         return SGLSamplingParams(
             temperature=grpc_params.temperature or 1.0,
@@ -456,33 +493,114 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
             repetition_penalty=grpc_params.repetition_penalty or 1.0,
             max_new_tokens=grpc_params.max_new_tokens or 128,
             min_new_tokens=grpc_params.min_new_tokens or 0,
-            stop=list(grpc_params.stop) if grpc_params.stop else None,
+            stop=list(grpc_params.stop) if grpc_params.stop else [],
             stop_token_ids=(
-                list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else None
+                list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else []
             ),
             skip_special_tokens=grpc_params.skip_special_tokens,
             spaces_between_special_tokens=grpc_params.spaces_between_special_tokens,
             regex=regex,
             json_schema=json_schema,
             ebnf=ebnf_grammar,
+            structural_tag=structural_tag,
             n=grpc_params.n or 1,
             ignore_eos=grpc_params.ignore_eos,
         )
 
+    def _convert_output_logprobs_to_proto(
+        self, logprobs_data: Dict
+    ) -> Optional[sglang_scheduler_pb2.OutputLogProbs]:
+        """Convert output logprobs dict to proto (no None values, plain floats)."""
+        if not logprobs_data:
+            return None
+
+        token_logprobs_val = logprobs_data.get("token_logprobs_val", [])
+        token_logprobs_idx = logprobs_data.get("token_logprobs_idx", [])
+        top_logprobs_val = logprobs_data.get("top_logprobs_val", [])
+        top_logprobs_idx = logprobs_data.get("top_logprobs_idx", [])
+
+        # Build TopLogProbs entries
+        top_logprobs_proto = []
+        if top_logprobs_val and top_logprobs_idx:
+            for val_list, idx_list in zip(top_logprobs_val, top_logprobs_idx):
+                top_logprobs_proto.append(
+                    sglang_scheduler_pb2.TopLogProbs(
+                        values=val_list,
+                        token_ids=idx_list,
+                    )
+                )
+
+        return sglang_scheduler_pb2.OutputLogProbs(
+            token_logprobs=token_logprobs_val,  # Plain float array
+            token_ids=token_logprobs_idx,
+            top_logprobs=top_logprobs_proto,
+        )
+
+    def _convert_input_logprobs_to_proto(
+        self, logprobs_data: Dict
+    ) -> Optional[sglang_scheduler_pb2.InputLogProbs]:
+        """Convert input logprobs dict to proto (first token is None, wrapped in InputTokenLogProb)."""
+        if not logprobs_data:
+            return None
+
+        token_logprobs_val = logprobs_data.get("token_logprobs_val", [])
+        token_logprobs_idx = logprobs_data.get("token_logprobs_idx", [])
+        top_logprobs_val = logprobs_data.get("top_logprobs_val", [])
+        top_logprobs_idx = logprobs_data.get("top_logprobs_idx", [])
+
+        # Wrap values in InputTokenLogProb (None for first token, value for others)
+        token_logprobs_wrapped = [
+            (
+                sglang_scheduler_pb2.InputTokenLogProb()
+                if x is None
+                else sglang_scheduler_pb2.InputTokenLogProb(value=x)
+            )
+            for x in token_logprobs_val
+        ]
+
+        # Build TopLogProbs entries
+        top_logprobs_proto = []
+        if top_logprobs_val and top_logprobs_idx:
+            for val_list, idx_list in zip(top_logprobs_val, top_logprobs_idx):
+                top_logprobs_proto.append(
+                    sglang_scheduler_pb2.TopLogProbs(
+                        values=val_list,
+                        token_ids=idx_list,
+                    )
+                )
+
+        return sglang_scheduler_pb2.InputLogProbs(
+            token_logprobs=token_logprobs_wrapped,
+            token_ids=token_logprobs_idx,
+            top_logprobs=top_logprobs_proto,
+        )
+
     def _create_chunk_response(
         self, request_id: str, output: Dict
     ) -> sglang_scheduler_pb2.GenerateResponse:
         """Create a streaming chunk response."""
+        meta_info = output.get("meta_info", {})
+
+        # Convert output logprobs if present
+        output_logprobs_proto = self._convert_output_logprobs_to_proto(
+            output.get("output_logprobs")
+        )
+
+        # Convert input logprobs if present (only in first chunk)
+        input_logprobs_proto = self._convert_input_logprobs_to_proto(
+            output.get("input_logprobs")
+        )
+
         return sglang_scheduler_pb2.GenerateResponse(
             request_id=request_id,
             chunk=sglang_scheduler_pb2.GenerateStreamChunk(
-                token_id=output["token_ids"][-1] if output.get("token_ids") else 0,
-                text=output.get("text", ""),
-                prompt_tokens=0,
-                completion_tokens=len(output.get("token_ids", [])),
-                cached_tokens=0,
-                generation_time=time.time() - self.start_time,
-                queue_time=0.0,
+                token_ids=output.get("token_ids", []),
+                prompt_tokens=meta_info.get("prompt_tokens", 0),
+                completion_tokens=meta_info.get("completion_tokens", 0),
+                cached_tokens=meta_info.get("cached_tokens", 0),
+                output_logprobs=output_logprobs_proto,
+                input_logprobs=input_logprobs_proto,
+                index=output.get("index", 0),
             ),
         )
 
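Input logprobs need the wrapper message because the first prompt token has no logprob: a bare repeated float cannot carry None, but a repeated message with an optional value field can simply leave it unset. A dependency-free sketch of the wrapping step (the dataclass stands in for the generated InputTokenLogProb proto message):

    from dataclasses import dataclass
    from typing import List, Optional

    @dataclass
    class InputTokenLogProb:           # stand-in for the generated proto message
        value: Optional[float] = None  # left unset to model "no logprob"

    def wrap(vals: List[Optional[float]]) -> List[InputTokenLogProb]:
        # The first prompt token has no logprob, so it arrives as None;
        # subsequent tokens carry plain floats.
        return [InputTokenLogProb() if v is None else InputTokenLogProb(value=v)
                for v in vals]

    print(wrap([None, -0.11, -2.5]))
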
@@ -491,20 +609,57 @@ class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer)
     ) -> sglang_scheduler_pb2.GenerateResponse:
         """Create a completion response."""
 
-        # Determine finish reason
-        finish_reason = sglang_scheduler_pb2.GenerateComplete.STOP
+        # Extract meta info and finish reason details
         meta_info = output.get("meta_info", {})
-        if meta_info.get("finish_reason") == "length":
-            finish_reason = sglang_scheduler_pb2.GenerateComplete.LENGTH
-        elif meta_info.get("finish_reason") == "eos_token":
-            finish_reason = sglang_scheduler_pb2.GenerateComplete.EOS_TOKEN
+        finish_reason_data = meta_info.get("finish_reason")
+
+        # Determine finish reason, default is stop
+        finish_reason = "stop"
+        if finish_reason_data:
+            if isinstance(finish_reason_data, dict):
+                finish_reason_type = finish_reason_data.get("type")
+            else:
+                # Handle legacy string format
+                finish_reason_type = finish_reason_data
+
+            if finish_reason_type == "length":
+                finish_reason = "length"
+            elif finish_reason_type == "abort":
+                finish_reason = "abort"
+
+        # Extract matched_stop information
+        matched_stop_kwargs = {}
+        if isinstance(finish_reason_data, dict) and "matched" in finish_reason_data:
+            matched = finish_reason_data["matched"]
+            if isinstance(matched, int):
+                matched_stop_kwargs["matched_token_id"] = matched
+            elif isinstance(matched, str):
+                matched_stop_kwargs["matched_stop_str"] = matched
+
+        # Convert output logprobs if present
+        output_logprobs_proto = self._convert_output_logprobs_to_proto(
+            output.get("output_logprobs")
+        )
+
+        # Convert input logprobs if present
+        input_logprobs_proto = self._convert_input_logprobs_to_proto(
+            output.get("input_logprobs")
+        )
 
         return sglang_scheduler_pb2.GenerateResponse(
             request_id=request_id,
             complete=sglang_scheduler_pb2.GenerateComplete(
                 output_ids=output.get("token_ids", []),
-                output_text=output.get("text", ""),
                 finish_reason=finish_reason,
+                prompt_tokens=meta_info.get("prompt_tokens", 0),
+                completion_tokens=meta_info.get(
+                    "completion_tokens", len(output.get("token_ids", []))
+                ),
+                cached_tokens=meta_info.get("cached_tokens", 0),
+                output_logprobs=output_logprobs_proto,
+                input_logprobs=input_logprobs_proto,
+                index=output.get("index", 0),
+                **matched_stop_kwargs,
             ),
         )
 
@@ -522,6 +677,16 @@ async def serve_grpc(
 ):
     """Start the standalone gRPC server with integrated scheduler."""
 
+    # Start bootstrap server BEFORE launching scheduler processes (only in PREFILL mode)
+    # This ensures the bootstrap server is ready when prefill schedulers try to register
+    bootstrap_server = None
+    if server_args.disaggregation_mode == "prefill":
+        bootstrap_server = start_disagg_service(server_args)
+        if bootstrap_server:
+            logger.info(
+                f"Bootstrap server started for disaggregation mode on {server_args.host}:{server_args.disaggregation_bootstrap_port}"
+            )
+
     # Launch only the scheduler process(es) (no tokenizer/detokenizer needed for gRPC)
     logger.info("Launching scheduler process(es)...")
     scheduler_info, port_args, scheduler_procs = _launch_scheduler_process_only(
@@ -545,9 +710,11 @@ async def serve_grpc(
     }
 
     # Create request manager with the correct port args
+    # Note: We pass None for bootstrap_server since it's already started above
     request_manager = GrpcRequestManager(
         server_args=server_args,
         port_args=port_args,
+        bootstrap_server=bootstrap_server,
     )
 
     # Create gRPC server
@@ -597,19 +764,28 @@ async def serve_grpc(
         await stop_event.wait()
     finally:
         logger.info("Shutting down gRPC server")
+
+        # Shutdown request manager first - this closes ZMQ sockets and stops background tasks
         await servicer.shutdown()
+
+        # Stop the gRPC server
         await server.stop(5.0)
 
-        # Terminate scheduler processes
+        # Terminate scheduler processes before exiting to avoid atexit hang
+        # The scheduler processes have SIGINT ignored, so they won't get KeyboardInterrupt
        for i, proc in enumerate(scheduler_procs):
-            if proc and proc.is_alive():
+            if proc.is_alive():
                 logger.info(f"Terminating scheduler process {i}...")
                 proc.terminate()
-                proc.join(timeout=5.0)
+                proc.join(timeout=2.0)
                 if proc.is_alive():
-                    logger.warning(f"Force killing scheduler process {i}...")
+                    logger.warning(
+                        f"Scheduler process {i} did not terminate, killing..."
+                    )
                     proc.kill()
-                    proc.join()
+                    proc.join(timeout=1.0)
+
+        logger.info("All scheduler processes terminated")
 
 
 def main():
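
The shutdown path uses the usual terminate/join/kill escalation with bounded joins, so a wedged child cannot hang interpreter exit. A self-contained sketch of that escalation:

    import multiprocessing as mp
    import time

    def stop_all(procs, term_timeout=2.0, kill_timeout=1.0):
        # SIGTERM first with a short join, then SIGKILL for stragglers,
        # always with bounded joins so shutdown itself cannot hang.
        for proc in procs:
            if proc.is_alive():
                proc.terminate()
                proc.join(timeout=term_timeout)
                if proc.is_alive():
                    proc.kill()
                    proc.join(timeout=kill_timeout)

    if __name__ == "__main__":
        workers = [mp.Process(target=time.sleep, args=(60,)) for _ in range(2)]
        for w in workers:
            w.start()
        stop_all(workers)
        print([w.exitcode for w in workers])  # negative: killed by signal
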
@@ -618,55 +794,9 @@ def main():
     mp.set_start_method("spawn", force=True)
 
     parser = argparse.ArgumentParser(description="SGLang Standalone gRPC Server")
-
-    # Server arguments
-    parser.add_argument("--host", type=str, default="0.0.0.0", help="Host to bind to")
-    parser.add_argument("--port", type=int, default=30000, help="gRPC server port")
-
-    # Model arguments
-    parser.add_argument("--model-path", type=str, required=True, help="Model path")
-    parser.add_argument("--tokenizer-path", type=str, help="Tokenizer path")
-    parser.add_argument("--context-length", type=int, help="Context length")
-    parser.add_argument("--tp-size", type=int, default=1, help="Tensor parallel size")
-    parser.add_argument("--dp-size", type=int, default=1, help="Data parallel size")
-
-    # Runtime arguments
-    parser.add_argument(
-        "--max-running-requests", type=int, default=2048, help="Max concurrent requests"
-    )
-    parser.add_argument(
-        "--max-total-tokens", type=int, default=1000000, help="Max total tokens"
-    )
-    parser.add_argument(
-        "--max-prefill-tokens", type=int, default=16384, help="Max prefill tokens"
-    )
-    parser.add_argument(
-        "--attention-backend", type=str, default="flashinfer", help="Attention backend"
-    )
-    parser.add_argument("--lora-paths", type=str, help="LoRA adapter paths")
-
-    # Logging
-    parser.add_argument("--log-level", type=str, default="INFO", help="Logging level")
-
+    ServerArgs.add_cli_args(parser)
     args = parser.parse_args()
-
-    # Convert to ServerArgs with gRPC host and port
-    server_args = ServerArgs(
-        model_path=args.model_path,
-        tokenizer_path=args.tokenizer_path or args.model_path,
-        context_length=args.context_length,
-        tp_size=args.tp_size,
-        dp_size=args.dp_size,
-        max_running_requests=args.max_running_requests,
-        max_total_tokens=args.max_total_tokens,
-        max_prefill_tokens=args.max_prefill_tokens,
-        attention_backend=args.attention_backend,
-        lora_paths=args.lora_paths.split(",") if args.lora_paths else None,
-        log_level=args.log_level,
-        # Override with gRPC server host and port
-        host=args.host,
-        port=args.port,
-    )
+    server_args = ServerArgs.from_cli_args(args)
 
     # Run server
     asyncio.run(
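
main() now delegates the whole CLI surface to ServerArgs.add_cli_args/from_cli_args instead of hand-maintaining a subset of flags. The pattern, sketched with a small stand-in dataclass (DemoArgs is illustrative; the real ServerArgs has far more fields):

    import argparse
    from dataclasses import dataclass, fields

    @dataclass
    class DemoArgs:  # stand-in for ServerArgs
        model_path: str = ""
        tp_size: int = 1

        @staticmethod
        def add_cli_args(parser: argparse.ArgumentParser) -> None:
            # One flag per field, defaults kept in a single place.
            parser.add_argument("--model-path", type=str, default="")
            parser.add_argument("--tp-size", type=int, default=1)

        @classmethod
        def from_cli_args(cls, args: argparse.Namespace) -> "DemoArgs":
            return cls(**{f.name: getattr(args, f.name) for f in fields(cls)})

    parser = argparse.ArgumentParser()
    DemoArgs.add_cli_args(parser)
    ns = parser.parse_args(["--model-path", "demo/model", "--tp-size", "2"])
    print(DemoArgs.from_cli_args(ns))  # DemoArgs(model_path='demo/model', tp_size=2)
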
sglang/srt/entrypoints/http_server.py

@@ -29,8 +29,6 @@ import time
 from http import HTTPStatus
 from typing import Any, AsyncIterator, Callable, Dict, List, Optional, Union
 
-import setproctitle
-
 from sglang.srt.tracing.trace import process_tracing_init, trace_set_thread_info
 
 # Fix a bug of Python threading
@@ -72,6 +70,7 @@ from sglang.srt.managers.io_struct import (
     AbortReq,
     CloseSessionReqInput,
     ConfigureLoggingReq,
+    DestroyWeightsUpdateGroupReqInput,
     EmbeddingReqInput,
     GenerateReqInput,
     GetWeightsByNameReqInput,
@@ -95,8 +94,8 @@ from sglang.srt.managers.io_struct import (
     VertexGenerateReqInput,
 )
 from sglang.srt.managers.multi_tokenizer_mixin import (
-    MultiTokenizerManager,
     MultiTokenizerRouter,
+    TokenizerWorker,
     get_main_process_id,
     monkey_patch_uvicorn_multiprocessing,
     read_from_shared_memory,
@@ -128,9 +127,7 @@ HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
 # Store global states
 @dataclasses.dataclass
 class _GlobalState:
-    tokenizer_manager: Union[
-        TokenizerManager, MultiTokenizerRouter, MultiTokenizerManager
-    ]
+    tokenizer_manager: Union[TokenizerManager, MultiTokenizerRouter, TokenizerWorker]
     template_manager: TemplateManager
     scheduler_info: Dict
 
@@ -165,7 +162,7 @@ async def init_multi_tokenizer() -> ServerArgs:
     )
 
     # Launch multi-tokenizer manager process
-    tokenizer_manager = MultiTokenizerManager(server_args, port_args)
+    tokenizer_manager = TokenizerWorker(server_args, port_args)
     template_manager = TemplateManager()
     template_manager.initialize_templates(
         tokenizer_manager=tokenizer_manager,
@@ -302,7 +299,23 @@ app.add_middleware(
 
 @app.exception_handler(HTTPException)
 async def validation_exception_handler(request: Request, exc: HTTPException):
-    """Enrich HTTP exception with status code and other details"""
+    """Enrich HTTP exception with status code and other details.
+
+    For /v1/responses, emit OpenAI-style nested error envelope:
+    {"error": {"message": "...", "type": "...", "param": null, "code": <status>}}
+    """
+    # adjust fmt for responses api
+    if request.url.path.startswith("/v1/responses"):
+        nested_error = {
+            "message": exc.detail,
+            "type": HTTPStatus(exc.status_code).phrase,
+            "param": None,
+            "code": exc.status_code,
+        }
+        return ORJSONResponse(
+            content={"error": nested_error}, status_code=exc.status_code
+        )
+
     error = ErrorResponse(
         object="error",
         message=exc.detail,
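
Per the docstring, a /v1/responses client now receives the OpenAI-style nested envelope instead of the flat legacy error. For example (message text illustrative):

    from http import HTTPStatus

    status = HTTPStatus.NOT_FOUND
    body = {
        "error": {
            "message": "Response with id 'resp_123' not found.",  # illustrative
            "type": status.phrase,  # "Not Found"
            "param": None,
            "code": status.value,   # 404
        }
    }
    print(body)
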
@@ -315,7 +328,10 @@ async def validation_exception_handler(request: Request, exc: HTTPException):
 # Custom exception handlers to change validation error status codes
 @app.exception_handler(RequestValidationError)
 async def validation_exception_handler(request: Request, exc: RequestValidationError):
-    """Override FastAPI's default 422 validation error with 400"""
+    """Override FastAPI's default 422 validation error with 400.
+
+    For /v1/responses, emit OpenAI-style nested error envelope; for other endpoints keep legacy format.
+    """
     exc_str = str(exc)
     errors_str = str(exc.errors())
 
@@ -324,6 +340,16 @@ async def validation_exception_handler(request: Request, exc: RequestValidationE
     else:
         message = exc_str
 
+    if request.url.path.startswith("/v1/responses"):
+        # adapt specially, for v1/responses API only (notice the error key is different)
+        nested_error = {
+            "message": message,
+            "type": HTTPStatus.BAD_REQUEST.phrase,
+            "param": None,
+            "code": HTTPStatus.BAD_REQUEST.value,
+        }
+        return ORJSONResponse(status_code=400, content={"error": nested_error})
+
     err = ErrorResponse(
         message=message,
         type=HTTPStatus.BAD_REQUEST.phrase,
@@ -731,6 +757,20 @@ async def init_weights_update_group(
         return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
 
 
+@app.post("/destroy_weights_update_group")
+async def destroy_weights_update_group(
+    obj: DestroyWeightsUpdateGroupReqInput, request: Request
+):
+    """Destroy the parameter update group."""
+    success, message = (
+        await _global_state.tokenizer_manager.destroy_weights_update_group(obj, request)
+    )
+    content = {"success": success, "message": message}
+    return ORJSONResponse(
+        content, status_code=200 if success else HTTPStatus.BAD_REQUEST
+    )
+
+
 @app.post("/update_weights_from_tensor")
 async def update_weights_from_tensor(
     obj: UpdateWeightsFromTensorReqInput, request: Request
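
Calling the new endpoint mirrors the existing weight-update routes; the handler returns {"success": ..., "message": ...} with 200 on success and 400 otherwise. A hypothetical client call (the fields of DestroyWeightsUpdateGroupReqInput are not shown in this diff, so the JSON body below is a guess for illustration only):

    import requests  # assumes a running server on localhost:30000

    # Hypothetical payload: "group_name" is illustrative, not a confirmed field.
    resp = requests.post(
        "http://localhost:30000/destroy_weights_update_group",
        json={"group_name": "weight_sync_group"},
    )
    print(resp.status_code, resp.json())  # e.g. 200 {"success": true, "message": ...}
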