sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported registry, and is provided for informational purposes only.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -4,6 +4,7 @@ Mimics TokenizerManager's state management and ZMQ communication patterns.
  """
 
  import asyncio
+ import copy
  import dataclasses
  import logging
  import os
@@ -11,7 +12,8 @@ import signal
  import sys
  import threading
  import time
- from typing import Any, Dict, List, Optional, Union
+ import uuid
+ from typing import Any, AsyncGenerator, Dict, List, Optional, Union
 
  import grpc
  import zmq
@@ -19,8 +21,8 @@ import zmq.asyncio
 
  from sglang.srt.managers.io_struct import (
      AbortReq,
-     BatchEmbeddingOut,
-     BatchTokenIDOut,
+     BatchEmbeddingOutput,
+     BatchTokenIDOutput,
      HealthCheckOutput,
      TokenizedEmbeddingReqInput,
      TokenizedGenerateReqInput,
@@ -79,11 +81,10 @@ class GrpcReqState:
      last_completion_tokens: int = 1
 
      # Streaming state
-     last_output_offset: int = 0
      stream_finished: bool = False
+     input_logprobs_sent: bool = False  # Track if input logprobs were sent in streaming
 
-     # Output accumulation
-     text: str = ""
+     # Token accumulation (for non-streaming)
      output_ids: List[int] = dataclasses.field(default_factory=list)
      input_token_logprobs_val: List[float] = dataclasses.field(default_factory=list)
      input_token_logprobs_idx: List[int] = dataclasses.field(default_factory=list)
@@ -109,22 +110,23 @@ class GrpcRequestManager:
          self,
          server_args: ServerArgs,
          port_args: PortArgs,
+         bootstrap_server=None,
      ):
          """Initialize the gRPC request manager."""
          self.server_args = server_args
          self.port_args = port_args
 
          # ZMQ Communication Setup (same pattern as TokenizerManager)
-         context = zmq.asyncio.Context(2)
+         self.context = zmq.asyncio.Context(2)
 
          # Socket for receiving outputs from scheduler
          self.recv_from_scheduler = get_zmq_socket(
-             context, zmq.PULL, port_args.detokenizer_ipc_name, bind=True
+             self.context, zmq.PULL, port_args.detokenizer_ipc_name, bind=True
          )
 
          # Socket for sending requests to scheduler
          self.send_to_scheduler = get_zmq_socket(
-             context, zmq.PUSH, port_args.scheduler_input_ipc_name, bind=True
+             self.context, zmq.PUSH, port_args.scheduler_input_ipc_name, bind=True
          )
 
          # State Management (from TokenizerManager)
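
The context now lives on `self.context` so that `shutdown()` can terminate it later. For orientation, a minimal standalone sketch of the same bind-side PUSH/PULL pattern in plain pyzmq (`get_zmq_socket` is sglang's wrapper; the `ipc://` endpoints below are invented for the demo and assume a POSIX host):

    import asyncio

    import zmq
    import zmq.asyncio


    async def main():
        ctx = zmq.asyncio.Context(io_threads=2)

        # Manager side: bind both sockets, as in GrpcRequestManager.__init__.
        recv_sock = ctx.socket(zmq.PULL)  # receives scheduler outputs
        recv_sock.bind("ipc:///tmp/demo_detokenizer")
        send_sock = ctx.socket(zmq.PUSH)  # sends requests to the scheduler
        send_sock.bind("ipc:///tmp/demo_scheduler_input")

        # A peer connect()s to the bound endpoint; send/recv are awaitable.
        peer = ctx.socket(zmq.PULL)
        peer.connect("ipc:///tmp/demo_scheduler_input")
        await send_sock.send_pyobj({"rid": "demo"})
        print(await peer.recv_pyobj())

        for sock in (recv_sock, send_sock, peer):
            sock.close()
        ctx.term()  # close sockets first, then term(), mirroring shutdown()


    asyncio.run(main())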
@@ -139,41 +141,158 @@ class GrpcRequestManager:
          self.is_pause_cond = asyncio.Condition()
 
          # Metrics
-         self.request_counter = 0
-         self.request_counter_lock = asyncio.Lock()
          self.last_receive_tstamp = time.time()
 
          # Crash dump for debugging
          self.crash_dump_request_list = []
          self.crash_dump_performed = False
 
+         # Bootstrap server (passed from serve_grpc, not started here)
+         self.bootstrap_server = bootstrap_server
+
          logger.info(
              f"GrpcRequestManager initialized with ZMQ IPC: "
              f"recv={port_args.detokenizer_ipc_name}, "
              f"send={port_args.scheduler_input_ipc_name}"
          )
+         if self.bootstrap_server:
+             logger.info(
+                 f"Bootstrap server initialized for disaggregation mode: "
+                 f"{server_args.disaggregation_mode}"
+             )
 
      async def generate_request(
          self,
          obj: TokenizedGenerateReqInput,
          request_id: Optional[str] = None,
          grpc_context: Optional[grpc.aio.ServicerContext] = None,
-     ) -> asyncio.Queue:
+     ) -> AsyncGenerator[Union[Dict, List[Dict]], None]:
          """
-         Submit a generation request to the scheduler.
-         Returns a queue for streaming outputs.
+         Submit a generation request to the scheduler with n>1 parallel sampling support.
+
+         This method implements the same two-phase approach as tokenizer_manager.py:
+         1. Phase 1: Send a prefix caching request (max_new_tokens=0)
+         2. Phase 2: Send n generation requests that reuse the cached prefix
+
+         Yields individual responses for streaming, or aggregated responses for non-streaming.
          """
+         n = getattr(obj.sampling_params, "n", 1)
+
+         if n <= 1:
+             async for response in self._handle_single_request(
+                 obj, request_id, grpc_context
+             ):
+                 yield response
+             return
+
+         # n>1 handling - two-phase approach
+         logger.debug(f"Multiple sampling request (n={n}), using two-phase approach")
+
+         # Generate base request ID if not provided
+         if request_id is None:
+             base_request_id = f"grpc-{uuid.uuid4().hex}"
+         else:
+             base_request_id = request_id
+
+         # Phase 1: Cache the common prefix
+         logger.debug(f"Phase 1: Caching prefix for request {base_request_id}")
+         prefix_obj = copy.copy(obj)
+         prefix_obj.sampling_params = copy.copy(obj.sampling_params)
+         prefix_obj.sampling_params.max_new_tokens = 0  # Prefill-only
+         prefix_obj.sampling_params.n = 1  # Don't replicate the prefix request
+
+         # Send the prefix caching request and consume its response
+         async for _ in self._handle_single_request(
+             prefix_obj, f"{base_request_id}-prefix", grpc_context
+         ):
+             # Consume the prefix response (usually just one chunk with finish_reason)
+             pass
+
+         logger.debug(f"Phase 1 completed: Prefix cached for {base_request_id}")
+
+         # Phase 2: Generate n parallel requests
+         logger.debug(f"Phase 2: Generating {n} parallel requests")
+         generators = []
+         request_ids = []
+
+         for i in range(n):
+             # Create an individual generation request
+             gen_obj = copy.copy(obj)
+             gen_obj.sampling_params = copy.copy(obj.sampling_params)
+             gen_obj.sampling_params.n = 1  # Each request generates 1 response
+
+             gen_request_id = f"{base_request_id}-{i}"
+             request_ids.append(gen_request_id)
+
+             # Start the generation request
+             generators.append(
+                 self._handle_single_request(gen_obj, gen_request_id, grpc_context)
+             )
+
+         # Handle response aggregation
+         is_stream = getattr(obj, "stream", False)
+
+         if not is_stream:
+             # Non-streaming: collect all responses and return them as a batch
+             logger.debug(f"Non-streaming mode: collecting {n} responses")
+             responses = []
+             for generator in generators:
+                 async for response in generator:
+                     responses.append(response)
+             yield responses  # Return all responses as a batch
+         else:
+             # Streaming mode: multiplex responses with an index for ordering
+             logger.debug(f"Streaming mode: multiplexing {n} streams")
+             rid_to_index = {rid: i for i, rid in enumerate(request_ids)}
+
+             # Create async tasks for all generators
+             task_map = {}
+             for generator in generators:
+                 task = asyncio.create_task(generator.__anext__())
+                 task_map[task] = generator
+
+             # Process responses as they arrive
+             while task_map:
+                 done, _ = await asyncio.wait(
+                     task_map.keys(), return_when=asyncio.FIRST_COMPLETED
+                 )
+
+                 for task in done:
+                     generator = task_map.pop(task)
+                     try:
+                         response = await task
+
+                         # Add an index for client-side ordering
+                         if isinstance(response, dict) and "meta_info" in response:
+                             response_rid = response["meta_info"].get("id", "")
+                             if response_rid in rid_to_index:
+                                 response["index"] = rid_to_index[response_rid]
+
+                         yield response
+
+                         # Create the next task for this generator
+                         next_task = asyncio.create_task(generator.__anext__())
+                         task_map[next_task] = generator
+
+                     except StopAsyncIteration:
+                         # This generator is finished
+                         pass
+
+     async def _handle_single_request(
+         self,
+         obj: TokenizedGenerateReqInput,
+         request_id: Optional[str] = None,
+         grpc_context: Optional[grpc.aio.ServicerContext] = None,
+     ):
+         """Handle a single request - core implementation without the n>1 logic."""
          # Generate request ID if not provided
          if request_id is None:
-             async with self.request_counter_lock:
-                 request_id = f"grpc-{self.request_counter}"
-                 self.request_counter += 1
+             request_id = f"grpc-{uuid.uuid4().hex}"
 
          obj.rid = request_id
 
+         # Create and register request state
          # TODO: support log_request
-
-         # Create request state
          state = GrpcReqState(
              request_id=request_id,
              grpc_context=grpc_context,
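
The streaming branch above drains the n per-request generators concurrently by keeping exactly one pending `__anext__` task per generator and re-arming it after each chunk. A self-contained sketch of that multiplexing idiom, with toy generators standing in for `_handle_single_request`:

    import asyncio


    async def fake_stream(tag: str, n_chunks: int):
        # Stand-in for _handle_single_request: yield a few chunks, then finish.
        for j in range(n_chunks):
            await asyncio.sleep(0.01)
            yield f"{tag}-chunk{j}"


    async def multiplex(generators):
        # One pending __anext__ task per generator, as in generate_request.
        task_map = {asyncio.create_task(g.__anext__()): g for g in generators}
        while task_map:
            done, _ = await asyncio.wait(
                task_map.keys(), return_when=asyncio.FIRST_COMPLETED
            )
            for task in done:
                gen = task_map.pop(task)
                try:
                    item = task.result()
                except StopAsyncIteration:
                    continue  # this generator is exhausted
                yield item
                # Re-arm the generator so its next chunk can race the others.
                task_map[asyncio.create_task(gen.__anext__())] = gen


    async def main():
        streams = [fake_stream(f"req{i}", 3) for i in range(2)]
        async for item in multiplex(streams):
            print(item)  # chunks from req0 and req1, interleaved


    asyncio.run(main())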
@@ -189,19 +308,51 @@
              state.session_id = obj.session_params.session_id
              state.is_session_request = True
 
-         # Register state
          self.rid_to_state[request_id] = state
          self.record_request_for_crash_dump(obj)
 
-         # Send to scheduler via ZMQ
          try:
+             # Send to scheduler - let exceptions bubble up to grpc_server.py
              await self._send_to_scheduler(obj)
-         except Exception as e:
-             # Clean up on failure
-             del self.rid_to_state[request_id]
-             raise RuntimeError(f"Failed to send request to scheduler: {e}")
 
-         return state.out_queue
+             is_stream = getattr(obj, "stream", False)
+
+             while True:
+                 # Client cancelled - notify scheduler and exit
+                 if grpc_context and grpc_context.cancelled():
+                     await self.abort_request(request_id)
+                     return
+
+                 try:
+                     response = await asyncio.wait_for(state.out_queue.get(), timeout=4)
+
+                     if is_stream:
+                         yield response
+
+                     # Non-streaming: yield the final response with accumulated tokens from state
+                     if isinstance(response, dict) and response.get("finished", False):
+                         if not is_stream:
+                             final_response = response.copy()
+                             final_response["token_ids"] = state.output_ids
+                             yield final_response
+                         break
+
+                 except asyncio.TimeoutError:
+                     # Timeout waiting for response - abort and clean up
+                     logger.warning(
+                         f"Timeout waiting for response for request {request_id}"
+                     )
+                     await self.abort_request(request_id)
+                     return
+
+         finally:
+             # Always clean up request state when exiting
+             self._cleanup_request_state(request_id)
+
+     def _cleanup_request_state(self, request_id: str):
+         """Clean up local request state (does not notify the scheduler)."""
+         if request_id in self.rid_to_state:
+             del self.rid_to_state[request_id]
 
      async def embedding_request(
          self,
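
`_handle_single_request` above turns the per-request `asyncio.Queue` into an async generator: poll with a timeout, yield chunks, and guarantee cleanup in `finally`. A minimal sketch of that consume loop in isolation (the 4-second timeout matches the diff; the prints stand in for `abort_request` and `_cleanup_request_state`):

    import asyncio


    async def consume(queue: asyncio.Queue, timeout: float = 4.0):
        try:
            while True:
                try:
                    response = await asyncio.wait_for(queue.get(), timeout=timeout)
                except asyncio.TimeoutError:
                    print("timed out; would call abort_request() here")
                    return
                yield response
                if isinstance(response, dict) and response.get("finished", False):
                    break
        finally:
            # Runs on normal exit, timeout, or client cancellation.
            print("cleanup: would drop the rid_to_state entry here")


    async def main():
        q = asyncio.Queue()
        await q.put({"token_ids": [1, 2], "finished": False})
        await q.put({"token_ids": [3], "finished": True})
        async for chunk in consume(q):
            print(chunk)


    asyncio.run(main())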
@@ -214,9 +365,7 @@
          """
          # Generate request ID if not provided
          if request_id is None:
-             async with self.request_counter_lock:
-                 request_id = f"grpc-embed-{self.request_counter}"
-                 self.request_counter += 1
+             request_id = f"grpc-embed-{uuid.uuid4().hex}"
 
          obj.rid = request_id
 
@@ -318,9 +467,9 @@
                      await self.is_pause_cond.wait()
 
                  # Handle different output types
-                 if isinstance(recv_obj, BatchTokenIDOut):
+                 if isinstance(recv_obj, BatchTokenIDOutput):
                      await self._handle_batch_output(recv_obj)
-                 elif isinstance(recv_obj, BatchEmbeddingOut):
+                 elif isinstance(recv_obj, BatchEmbeddingOutput):
                      await self._handle_embedding_output(recv_obj)
                  elif isinstance(recv_obj, HealthCheckOutput):
                      await self._handle_health_check_output(recv_obj)
@@ -332,12 +481,71 @@
                  if self.gracefully_exit:
                      break
                  continue
+             except zmq.error.ZMQError as e:
+                 # Socket closed or other ZMQ error - exit cleanly if shutting down
+                 if self.gracefully_exit:
+                     logger.debug(f"ZMQ recv interrupted during shutdown: {e}")
+                     break
+                 logger.error(
+                     f"ZMQ error in handle loop: {e}\n{get_exception_traceback()}"
+                 )
+                 break
              except Exception as e:
                  logger.error(f"Handle loop error: {e}\n{get_exception_traceback()}")
                  if self.gracefully_exit:
                      break
 
-     async def _handle_batch_output(self, batch_out: BatchTokenIDOut):
+     def _convert_logprob_style(
+         self,
+         state: GrpcReqState,
+         batch_out: BatchTokenIDOutput,
+         batch_index: int,
+     ):
+         """
+         Convert and accumulate logprobs from batch output to state.
+         Follows the same logic as tokenizer_manager.convert_logprob_style.
+         """
+         # Early exit if no input logprobs at all
+         if batch_out.input_token_logprobs_val is None:
+             return
+
+         # Accumulate input token logprobs (only if the list is non-empty)
+         if len(batch_out.input_token_logprobs_val) > 0:
+             state.input_token_logprobs_val.extend(
+                 batch_out.input_token_logprobs_val[batch_index]
+             )
+             state.input_token_logprobs_idx.extend(
+                 batch_out.input_token_logprobs_idx[batch_index]
+             )
+
+         # Always accumulate output token logprobs
+         state.output_token_logprobs_val.extend(
+             batch_out.output_token_logprobs_val[batch_index]
+         )
+         state.output_token_logprobs_idx.extend(
+             batch_out.output_token_logprobs_idx[batch_index]
+         )
+
+         # Handle top logprobs if requested
+         if state.obj.top_logprobs_num > 0:
+             # Accumulate input top logprobs (only if the list is non-empty)
+             if len(batch_out.input_top_logprobs_val) > 0:
+                 state.input_top_logprobs_val.extend(
+                     batch_out.input_top_logprobs_val[batch_index]
+                 )
+                 state.input_top_logprobs_idx.extend(
+                     batch_out.input_top_logprobs_idx[batch_index]
+                 )
+
+             # Always accumulate output top logprobs
+             state.output_top_logprobs_val.extend(
+                 batch_out.output_top_logprobs_val[batch_index]
+             )
+             state.output_top_logprobs_idx.extend(
+                 batch_out.output_top_logprobs_idx[batch_index]
+             )
+
+     async def _handle_batch_output(self, batch_out: BatchTokenIDOutput):
          """Handle batch generation output from scheduler."""
          # Process each request in the batch
          for i, rid in enumerate(batch_out.rids):
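
The accumulation above works on parallel lists: element i of each `batch_out.*` list belongs to request `batch_out.rids[i]`. A toy illustration of the extend-by-batch_index move (field names follow the diff; the values are invented):

    from dataclasses import dataclass, field
    from typing import List


    @dataclass
    class ToyBatch:
        # Parallel lists: element i belongs to request rids[i].
        rids: List[str]
        output_token_logprobs_val: List[List[float]]
        output_token_logprobs_idx: List[List[int]]


    @dataclass
    class ToyState:
        output_token_logprobs_val: List[float] = field(default_factory=list)
        output_token_logprobs_idx: List[int] = field(default_factory=list)


    batch = ToyBatch(
        rids=["req-a", "req-b"],
        output_token_logprobs_val=[[-0.1, -0.5], [-0.9]],
        output_token_logprobs_idx=[[11, 12], [42]],
    )
    state_for_b = ToyState()
    batch_index = batch.rids.index("req-b")  # the i of the loop above

    # The same extend-by-index moves as _convert_logprob_style.
    state_for_b.output_token_logprobs_val.extend(
        batch.output_token_logprobs_val[batch_index]
    )
    state_for_b.output_token_logprobs_idx.extend(
        batch.output_token_logprobs_idx[batch_index]
    )
    print(state_for_b)  # ToyState(...val=[-0.9], ...idx=[42])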
@@ -355,7 +563,6 @@
              # Extract output for this request
              output_data = {
                  "request_id": rid,
-                 "text": batch_out.decoded_texts[i] if batch_out.decoded_texts else "",
                  "token_ids": batch_out.output_ids[i] if batch_out.output_ids else [],
                  "finished": batch_out.finished_reasons[i] is not None,
                  "meta_info": {
@@ -367,37 +574,81 @@
                          if batch_out.completion_tokens
                          else 0
                      ),
+                     "cached_tokens": (
+                         batch_out.cached_tokens[i] if batch_out.cached_tokens else 0
+                     ),
                      "finish_reason": (
-                         str(batch_out.finished_reasons[i])
+                         batch_out.finished_reasons[i]
                          if batch_out.finished_reasons[i]
                          else None
                      ),
                  },
              }
 
-             # Add logprobs if available
-             if batch_out.output_token_logprobs_val and i < len(
-                 batch_out.output_token_logprobs_val
-             ):
-                 output_data["logprobs"] = {
-                     "tokens": batch_out.output_token_logprobs_val[i],
-                     "top_logprobs": (
-                         batch_out.output_top_logprobs_val[i]
-                         if batch_out.output_top_logprobs_val
-                         and i < len(batch_out.output_top_logprobs_val)
-                         else None
-                     ),
-                 }
-
-             # Update state
-             if output_data["text"]:
-                 state.text += output_data["text"][state.last_output_offset :]
-                 state.last_output_offset = len(output_data["text"])
+             # Accumulate logprobs (following the tokenizer_manager pattern)
+             if state.obj.return_logprob:
+                 self._convert_logprob_style(state, batch_out, i)
 
+             # Send input logprobs if available
+             if (
+                 state.obj.return_logprob
+                 and state.obj.logprob_start_len >= 0
+                 and state.input_token_logprobs_val
+             ):
+                 if state.obj.stream and not state.input_logprobs_sent:
+                     # Streaming: send input logprobs once, in the first chunk that has them
+                     output_data["input_logprobs"] = {
+                         "token_logprobs_val": state.input_token_logprobs_val,
+                         "token_logprobs_idx": state.input_token_logprobs_idx,
+                         "top_logprobs_val": state.input_top_logprobs_val,
+                         "top_logprobs_idx": state.input_top_logprobs_idx,
+                     }
+                     state.input_logprobs_sent = True
+                 elif not state.obj.stream and output_data["finished"]:
+                     # Non-streaming: send input logprobs in the final chunk
+                     output_data["input_logprobs"] = {
+                         "token_logprobs_val": state.input_token_logprobs_val,
+                         "token_logprobs_idx": state.input_token_logprobs_idx,
+                         "top_logprobs_val": state.input_top_logprobs_val,
+                         "top_logprobs_idx": state.input_top_logprobs_idx,
+                     }
+
+             # Send output logprobs if available
+             if (
+                 state.obj.return_logprob
+                 and batch_out.output_token_logprobs_val
+                 and i < len(batch_out.output_token_logprobs_val)
+             ):
+                 if state.obj.stream:
+                     # For streaming: send incremental logprobs (only the new tokens in this chunk)
+                     # NOTE: this is different from TokenizerManager, which always accumulates
+                     def get_part(attr_name):
+                         source_list = getattr(batch_out, attr_name, None)
+                         return (
+                             source_list[i]
+                             if source_list and i < len(source_list)
+                             else []
+                         )
+
+                     output_data["output_logprobs"] = {
+                         "token_logprobs_val": batch_out.output_token_logprobs_val[i],
+                         "token_logprobs_idx": get_part("output_token_logprobs_idx"),
+                         "top_logprobs_val": get_part("output_top_logprobs_val"),
+                         "top_logprobs_idx": get_part("output_top_logprobs_idx"),
+                     }
+                 elif output_data["finished"]:
+                     # Non-streaming: send cumulative output logprobs in the final chunk
+                     output_data["output_logprobs"] = {
+                         "token_logprobs_val": state.output_token_logprobs_val,
+                         "token_logprobs_idx": state.output_token_logprobs_idx,
+                         "top_logprobs_val": state.output_top_logprobs_val,
+                         "top_logprobs_idx": state.output_top_logprobs_idx,
+                     }
+
+             # Update state for accumulation
              if output_data["token_ids"]:
                  state.output_ids.extend(output_data["token_ids"])
 
-             # Send to output queue
              await state.out_queue.put(output_data)
 
              # Handle completion
@@ -415,7 +666,7 @@
 
          asyncio.create_task(cleanup())
 
-     async def _handle_embedding_output(self, batch_out: BatchEmbeddingOut):
+     async def _handle_embedding_output(self, batch_out: BatchEmbeddingOutput):
          """Handle batch embedding output from scheduler."""
          for i, rid in enumerate(batch_out.rids):
              if rid not in self.rid_to_state:
@@ -499,8 +750,17 @@
          logger.info("Shutting down GrpcRequestManager")
          self.gracefully_exit = True
 
+         # Cancel all asyncio tasks FIRST - this will interrupt blocked recv() calls
+         for task in list(self.asyncio_tasks):
+             if not task.done():
+                 task.cancel()
+
+         # Give tasks a moment to process cancellation
+         if self.asyncio_tasks:
+             await asyncio.gather(*list(self.asyncio_tasks), return_exceptions=True)
+
          # Cancel all pending requests
-         for rid, state in self.rid_to_state.items():
+         for rid, state in list(self.rid_to_state.items()):
              if not state.finished:
                  await state.out_queue.put(
                      {"error": "Server shutting down", "shutdown": True}
@@ -512,10 +772,25 @@
                  )
          if self.asyncio_tasks:
              await asyncio.gather(*list(self.asyncio_tasks), return_exceptions=True)
+         # Shut down the bootstrap server if running
+         if self.bootstrap_server:
+             logger.info("Shutting down bootstrap server")
+             try:
+                 if hasattr(self.bootstrap_server, "shutdown"):
+                     if asyncio.iscoroutinefunction(self.bootstrap_server.shutdown):
+                         await self.bootstrap_server.shutdown()
+                     else:
+                         self.bootstrap_server.shutdown()
+             except Exception as e:
+                 logger.warning(f"Error shutting down bootstrap server: {e}")
+
          # Close ZMQ sockets
          self.recv_from_scheduler.close()
          self.send_to_scheduler.close()
 
+         # Terminate the ZMQ context - this is critical for the asyncio loop to exit cleanly
+         self.context.term()
+
          logger.info("GrpcRequestManager shutdown complete")
 
      def get_server_info(self) -> Dict[str, Any]:
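
The reworked shutdown path is strictly ordered: cancel the handle-loop tasks first so blocked `recv()` calls unwind, gather them with `return_exceptions=True`, close the sockets, and only then `term()` the context (which blocks until all sockets are closed). A compact sketch of the cancel-then-gather idiom:

    import asyncio


    async def worker(name: str):
        try:
            while True:
                await asyncio.sleep(3600)  # stands in for a blocked ZMQ recv()
        except asyncio.CancelledError:
            print(f"{name} cancelled cleanly")
            raise


    async def shutdown_demo():
        tasks = [asyncio.create_task(worker(f"loop{i}")) for i in range(2)]
        await asyncio.sleep(0)  # let the workers start

        # Same order as GrpcRequestManager.shutdown(): cancel first...
        for task in tasks:
            if not task.done():
                task.cancel()
        # ...then gather with return_exceptions=True so CancelledError
        # does not propagate out of shutdown itself.
        await asyncio.gather(*tasks, return_exceptions=True)
        print("all handle loops stopped; sockets could now be closed")


    asyncio.run(shutdown_demo())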