sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -507,6 +507,7 @@ def embed_mm_inputs(
507
507
  Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
508
508
  ] = None,
509
509
  placeholder_tokens: dict[Modality, List[int]] = None,
510
+ use_deepstack: bool = False,
510
511
  ) -> Optional[torch.Tensor]:
511
512
  """
512
513
  Embed multimodal inputs and integrate them with text token embeddings.
@@ -522,7 +523,7 @@ def embed_mm_inputs(
522
523
  Returns:
523
524
  Combined embedding tensor with multimodal content integrated
524
525
  """
525
-
526
+ other_info = {}
526
527
  if mm_inputs_list is None:
527
528
  return None
528
529
 
@@ -532,7 +533,7 @@ def embed_mm_inputs(
532
533
  for mm_inputs in mm_inputs_list:
533
534
  item_flatten_list += [item for item in mm_inputs.mm_items if item is not None]
534
535
 
535
- embeddings, masks = [], []
536
+ embeddings, masks, deepstack_embeddings = [], [], []
536
537
  # 2. Get multimodal embedding separately
537
538
  # Try get mm embedding if any
538
539
  for modality in Modality.all():
@@ -578,6 +579,12 @@ def embed_mm_inputs(
578
579
  extend_length=extend_seq_lens,
579
580
  items_offset_list=items_offsets,
580
581
  )
582
+
583
+ if use_deepstack and embedding is not None:
584
+ embedding, deepstack_embedding = (
585
+ multimodal_model.separate_deepstack_embeds(embedding)
586
+ )
587
+ deepstack_embeddings += [deepstack_embedding]
581
588
  embeddings += [embedding]
582
589
  masks += [mask]
583
590
 
@@ -591,13 +598,37 @@ def embed_mm_inputs(
591
598
  inputs_embeds = input_embedding(input_ids)
592
599
 
593
600
  # 4. scatter embeddings into input embedding
594
- for embedding, mask in zip(embeddings, masks):
601
+
602
+ # deepstack embedding
603
+ if use_deepstack:
604
+ num_deepstack_embeddings = (
605
+ len(multimodal_model.deepstack_visual_indexes) if use_deepstack else 0
606
+ )
607
+ deepstack_embedding_shape = inputs_embeds.shape[:-1] + (
608
+ inputs_embeds.shape[-1] * num_deepstack_embeddings,
609
+ )
610
+
611
+ input_deepstack_embeds = torch.zeros(
612
+ deepstack_embedding_shape,
613
+ device=inputs_embeds.device,
614
+ dtype=inputs_embeds.dtype,
615
+ )
616
+
617
+ other_info["input_deepstack_embeds"] = input_deepstack_embeds
618
+
619
+ for i, embedding, mask in zip(range(len(embeddings)), embeddings, masks):
595
620
  if embedding is None or mask is None:
596
621
  continue
597
622
  # in-place update
598
623
  indices = torch.where(mask.squeeze(dim=-1))[0]
599
624
  inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
600
- return inputs_embeds
625
+
626
+ if use_deepstack:
627
+ input_deepstack_embeds[indices] = deepstack_embeddings[i].to(
628
+ inputs_embeds.device, inputs_embeds.dtype
629
+ )
630
+
631
+ return inputs_embeds, other_info
601
632
 
602
633
 
603
634
  def general_mm_embed_routine(
@@ -609,6 +640,7 @@ def general_mm_embed_routine(
609
640
  Modality, Callable[[List[MultimodalDataItem]], torch.Tensor]
610
641
  ] = None,
611
642
  placeholder_tokens: Optional[dict[Modality, List[int]]] = None,
643
+ use_deepstack: bool = False,
612
644
  **kwargs,
613
645
  ) -> torch.Tensor:
614
646
  """
@@ -620,6 +652,7 @@ def general_mm_embed_routine(
620
652
  language_model: Base language model to use
621
653
  data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
622
654
  placeholder_tokens: Token IDs for multimodal placeholders
655
+ use_deepstack: Whether to use deepstack embeddings
623
656
  **kwargs: Additional arguments passed to language model
624
657
 
625
658
  Returns:
@@ -645,16 +678,20 @@ def general_mm_embed_routine(
645
678
  for i, seq_len in enumerate(forward_batch.extend_seq_lens_cpu)
646
679
  if forward_batch.mm_inputs[i] is not None
647
680
  ]
648
- inputs_embeds = embed_mm_inputs(
681
+ inputs_embeds, other_info = embed_mm_inputs(
649
682
  mm_inputs_list=mm_inputs_list,
650
683
  extend_prefix_lens=extend_prefix_lens,
651
684
  extend_seq_lens=extend_seq_lens,
652
685
  input_ids=input_ids,
653
- input_embedding=embed_tokens,
654
686
  multimodal_model=multimodal_model,
687
+ input_embedding=embed_tokens,
655
688
  data_embedding_func_mapping=data_embedding_funcs,
656
689
  placeholder_tokens=placeholder_tokens,
690
+ use_deepstack=use_deepstack,
657
691
  )
692
+ # add for qwen3_vl deepstack
693
+ if use_deepstack:
694
+ kwargs["input_deepstack_embeds"] = other_info["input_deepstack_embeds"]
658
695
  # once used, mm_inputs is useless, considering chunked-prefill is disabled for multimodal models
659
696
  # just being defensive here
660
697
  forward_batch.mm_inputs = None
@@ -11,7 +11,7 @@
11
11
  # See the License for the specific language governing permissions and
12
12
  # limitations under the License.
13
13
  # ==============================================================================
14
- """MultiTokenizerMixin is a class that provides nesscary methods for MultiTokenizerManager and DetokenizerManager."""
14
+ """Mixin class and utils for multi-http-worker mode"""
15
15
  import asyncio
16
16
  import logging
17
17
  import multiprocessing as multiprocessing
@@ -30,10 +30,10 @@ import zmq.asyncio
30
30
  from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
31
31
  from sglang.srt.managers.disagg_service import start_disagg_service
32
32
  from sglang.srt.managers.io_struct import (
33
- BatchEmbeddingOut,
34
- BatchMultimodalOut,
35
- BatchStrOut,
36
- BatchTokenIDOut,
33
+ BatchEmbeddingOutput,
34
+ BatchMultimodalOutput,
35
+ BatchStrOutput,
36
+ BatchTokenIDOutput,
37
37
  MultiTokenizerRegisterReq,
38
38
  MultiTokenizerWrapper,
39
39
  )
@@ -83,8 +83,8 @@ class SocketMapping:
83
83
 
84
84
  def _handle_output_by_index(output, i):
85
85
  """NOTE: A maintainable method is better here."""
86
- if isinstance(output, BatchTokenIDOut):
87
- new_output = BatchTokenIDOut(
86
+ if isinstance(output, BatchTokenIDOutput):
87
+ new_output = BatchTokenIDOutput(
88
88
  rids=[output.rids[i]],
89
89
  finished_reasons=(
90
90
  [output.finished_reasons[i]]
@@ -198,8 +198,8 @@ def _handle_output_by_index(output, i):
198
198
  placeholder_tokens_idx=None,
199
199
  placeholder_tokens_val=None,
200
200
  )
201
- elif isinstance(output, BatchEmbeddingOut):
202
- new_output = BatchEmbeddingOut(
201
+ elif isinstance(output, BatchEmbeddingOutput):
202
+ new_output = BatchEmbeddingOutput(
203
203
  rids=[output.rids[i]],
204
204
  finished_reasons=(
205
205
  [output.finished_reasons[i]]
@@ -216,8 +216,8 @@ def _handle_output_by_index(output, i):
216
216
  placeholder_tokens_idx=None,
217
217
  placeholder_tokens_val=None,
218
218
  )
219
- elif isinstance(output, BatchStrOut):
220
- new_output = BatchStrOut(
219
+ elif isinstance(output, BatchStrOutput):
220
+ new_output = BatchStrOutput(
221
221
  rids=[output.rids[i]],
222
222
  finished_reasons=(
223
223
  [output.finished_reasons[i]]
@@ -314,8 +314,8 @@ def _handle_output_by_index(output, i):
314
314
  placeholder_tokens_idx=None,
315
315
  placeholder_tokens_val=None,
316
316
  )
317
- elif isinstance(output, BatchMultimodalOut):
318
- new_output = BatchMultimodalOut(
317
+ elif isinstance(output, BatchMultimodalOutput):
318
+ new_output = BatchMultimodalOutput(
319
319
  rids=[output.rids[i]],
320
320
  finished_reasons=(
321
321
  [output.finished_reasons[i]]
@@ -343,7 +343,7 @@ def _handle_output_by_index(output, i):
343
343
 
344
344
 
345
345
  class MultiHttpWorkerDetokenizerMixin:
346
- """Mixin class for MultiTokenizerManager and DetokenizerManager"""
346
+ """Mixin class for DetokenizerManager"""
347
347
 
348
348
  def get_worker_ids_from_req_rids(self, rids):
349
349
  if isinstance(rids, list):
@@ -386,7 +386,7 @@ class MultiHttpWorkerDetokenizerMixin:
386
386
 
387
387
 
388
388
  class MultiTokenizerRouter:
389
- """A router to receive requests from MultiTokenizerManager"""
389
+ """A router to receive requests from TokenizerWorker"""
390
390
 
391
391
  def __init__(
392
392
  self,
@@ -454,8 +454,8 @@ class MultiTokenizerRouter:
454
454
  self.socket_mapping.send_output(worker_id, new_recv_obj)
455
455
 
456
456
 
457
- class MultiTokenizerManager(TokenizerManager):
458
- """Multi Process Tokenizer Manager that tokenizes the text."""
457
+ class TokenizerWorker(TokenizerManager):
458
+ """Tokenizer Worker in multi-http-worker mode"""
459
459
 
460
460
  def __init__(
461
461
  self,
@@ -12,8 +12,7 @@ logger = logging.getLogger(__name__)
12
12
  PROCESSOR_MAPPING = {}
13
13
 
14
14
 
15
- def import_processors():
16
- package_name = "sglang.srt.multimodal.processors"
15
+ def import_processors(package_name: str):
17
16
  package = importlib.import_module(package_name)
18
17
  for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
19
18
  if not ispkg:
@@ -0,0 +1,53 @@
1
+ import torch
2
+
3
+ from sglang.srt.managers.schedule_batch import ModelWorkerBatch
4
+ from sglang.srt.utils import get_compiler_backend
5
+
6
+
7
+ @torch.compile(dynamic=True, backend=get_compiler_backend())
8
+ def _resolve_future_token_ids(input_ids, future_token_ids_map):
9
+ input_ids[:] = torch.where(
10
+ input_ids < 0,
11
+ future_token_ids_map[torch.clamp(-input_ids, min=0)],
12
+ input_ids,
13
+ )
14
+
15
+
16
+ class FutureMap:
17
+ def __init__(
18
+ self,
19
+ max_running_requests: int,
20
+ device: torch.device,
21
+ ):
22
+ self.future_ct = 0
23
+ # A factor of 3 is used to avoid collision in the circular buffer.
24
+ self.future_limit = max_running_requests * 3
25
+ # A factor of 5 is used to ensure the buffer is large enough.
26
+ self.future_buffer_len = max_running_requests * 5
27
+ self.device = device
28
+
29
+ self.token_ids_buf = torch.empty(
30
+ (self.future_buffer_len,), dtype=torch.int64, device=self.device
31
+ )
32
+
33
+ def update_ct(self, bs: int) -> int:
34
+ """Update the circular buffer pointer and return the current pointer."""
35
+ cur_future_ct = self.future_ct
36
+ self.future_ct = (cur_future_ct + bs) % self.future_limit
37
+ return cur_future_ct
38
+
39
+ def resolve_future(self, model_worker_batch: ModelWorkerBatch):
40
+ input_ids = model_worker_batch.input_ids
41
+ _resolve_future_token_ids(input_ids, self.token_ids_buf)
42
+
43
+ def update_next_future(self, future_ct: int, bs: int):
44
+ return torch.arange(
45
+ -(future_ct + 1),
46
+ -(future_ct + 1 + bs),
47
+ -1,
48
+ dtype=torch.int64,
49
+ device=self.device,
50
+ )
51
+
52
+ def store_to_map(self, future_ct: int, bs: int, next_token_ids: torch.Tensor):
53
+ self.token_ids_buf[future_ct + 1 : future_ct + bs + 1] = next_token_ids