sglang 0.5.3rc2__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (419)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +378 -160
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +10 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +105 -10
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +136 -25
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +63 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +83 -80
  69. sglang/srt/entrypoints/grpc_server.py +430 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +195 -102
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +58 -6
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +33 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +20 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/minimax_m2.py +367 -0
  93. sglang/srt/function_call/utils.py +2 -2
  94. sglang/srt/grpc/compile_proto.py +3 -3
  95. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  96. sglang/srt/grpc/health_servicer.py +189 -0
  97. sglang/srt/grpc/scheduler_launcher.py +181 -0
  98. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  99. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  100. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  101. sglang/srt/layers/activation.py +10 -1
  102. sglang/srt/layers/attention/aiter_backend.py +3 -3
  103. sglang/srt/layers/attention/ascend_backend.py +17 -1
  104. sglang/srt/layers/attention/attention_registry.py +43 -23
  105. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  106. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  107. sglang/srt/layers/attention/fla/chunk.py +0 -1
  108. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  109. sglang/srt/layers/attention/fla/index.py +0 -2
  110. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  111. sglang/srt/layers/attention/fla/utils.py +0 -3
  112. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  113. sglang/srt/layers/attention/flashattention_backend.py +24 -10
  114. sglang/srt/layers/attention/flashinfer_backend.py +258 -22
  115. sglang/srt/layers/attention/flashinfer_mla_backend.py +38 -28
  116. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  117. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  118. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  119. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  121. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  122. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  123. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  124. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  125. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  127. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  128. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  129. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  130. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  131. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  132. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  133. sglang/srt/layers/attention/nsa/utils.py +0 -1
  134. sglang/srt/layers/attention/nsa_backend.py +404 -90
  135. sglang/srt/layers/attention/triton_backend.py +208 -34
  136. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  137. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  138. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  139. sglang/srt/layers/attention/trtllm_mla_backend.py +362 -43
  140. sglang/srt/layers/attention/utils.py +89 -7
  141. sglang/srt/layers/attention/vision.py +3 -3
  142. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  143. sglang/srt/layers/communicator.py +12 -7
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +5 -9
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  146. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  147. sglang/srt/layers/dp_attention.py +17 -0
  148. sglang/srt/layers/layernorm.py +64 -19
  149. sglang/srt/layers/linear.py +9 -1
  150. sglang/srt/layers/logits_processor.py +152 -17
  151. sglang/srt/layers/modelopt_utils.py +11 -0
  152. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  153. sglang/srt/layers/moe/cutlass_w4a8_moe.py +351 -21
  154. sglang/srt/layers/moe/ep_moe/kernels.py +229 -457
  155. sglang/srt/layers/moe/ep_moe/layer.py +154 -625
  156. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  160. sglang/srt/layers/moe/fused_moe_triton/layer.py +79 -73
  161. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +25 -46
  162. sglang/srt/layers/moe/moe_runner/deep_gemm.py +569 -0
  163. sglang/srt/layers/moe/moe_runner/runner.py +6 -0
  164. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  165. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  166. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  167. sglang/srt/layers/moe/router.py +51 -15
  168. sglang/srt/layers/moe/token_dispatcher/__init__.py +14 -4
  169. sglang/srt/layers/moe/token_dispatcher/base.py +12 -6
  170. sglang/srt/layers/moe/token_dispatcher/deepep.py +127 -110
  171. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  172. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  173. sglang/srt/layers/moe/topk.py +7 -6
  174. sglang/srt/layers/moe/utils.py +20 -5
  175. sglang/srt/layers/quantization/__init__.py +5 -58
  176. sglang/srt/layers/quantization/awq.py +183 -9
  177. sglang/srt/layers/quantization/awq_triton.py +29 -0
  178. sglang/srt/layers/quantization/base_config.py +27 -1
  179. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  180. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  181. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  183. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  184. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  185. sglang/srt/layers/quantization/fp8.py +152 -81
  186. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  187. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  188. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  189. sglang/srt/layers/quantization/gguf.py +566 -0
  190. sglang/srt/layers/quantization/gptq.py +0 -1
  191. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  192. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  193. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  194. sglang/srt/layers/quantization/mxfp4.py +35 -68
  195. sglang/srt/layers/quantization/petit.py +1 -1
  196. sglang/srt/layers/quantization/quark/quark.py +3 -1
  197. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  198. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  199. sglang/srt/layers/quantization/unquant.py +23 -48
  200. sglang/srt/layers/quantization/utils.py +0 -1
  201. sglang/srt/layers/quantization/w4afp8.py +87 -20
  202. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  203. sglang/srt/layers/radix_attention.py +62 -9
  204. sglang/srt/layers/rotary_embedding.py +686 -17
  205. sglang/srt/layers/sampler.py +47 -16
  206. sglang/srt/layers/sparse_pooler.py +98 -0
  207. sglang/srt/layers/utils.py +0 -1
  208. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  209. sglang/srt/lora/backend/triton_backend.py +0 -1
  210. sglang/srt/lora/eviction_policy.py +139 -0
  211. sglang/srt/lora/lora_manager.py +24 -9
  212. sglang/srt/lora/lora_registry.py +1 -1
  213. sglang/srt/lora/mem_pool.py +40 -16
  214. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  215. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  216. sglang/srt/managers/cache_controller.py +48 -17
  217. sglang/srt/managers/data_parallel_controller.py +146 -42
  218. sglang/srt/managers/detokenizer_manager.py +40 -13
  219. sglang/srt/managers/io_struct.py +69 -16
  220. sglang/srt/managers/mm_utils.py +20 -18
  221. sglang/srt/managers/multi_tokenizer_mixin.py +83 -82
  222. sglang/srt/managers/overlap_utils.py +96 -19
  223. sglang/srt/managers/schedule_batch.py +241 -511
  224. sglang/srt/managers/schedule_policy.py +15 -2
  225. sglang/srt/managers/scheduler.py +420 -514
  226. sglang/srt/managers/scheduler_metrics_mixin.py +73 -18
  227. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  228. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  229. sglang/srt/managers/scheduler_profiler_mixin.py +60 -14
  230. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  231. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  232. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  233. sglang/srt/managers/tokenizer_manager.py +375 -95
  234. sglang/srt/managers/tp_worker.py +212 -161
  235. sglang/srt/managers/utils.py +78 -2
  236. sglang/srt/mem_cache/allocator.py +7 -2
  237. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  238. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  239. sglang/srt/mem_cache/chunk_cache.py +13 -2
  240. sglang/srt/mem_cache/common.py +480 -0
  241. sglang/srt/mem_cache/evict_policy.py +16 -1
  242. sglang/srt/mem_cache/hicache_storage.py +11 -2
  243. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  244. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  245. sglang/srt/mem_cache/memory_pool.py +517 -219
  246. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  247. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  248. sglang/srt/mem_cache/radix_cache.py +53 -19
  249. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  250. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  251. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  252. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  253. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  254. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  255. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  256. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  257. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  258. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  259. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  260. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  261. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  262. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  263. sglang/srt/metrics/collector.py +31 -0
  264. sglang/srt/metrics/func_timer.py +1 -1
  265. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  266. sglang/srt/model_executor/forward_batch_info.py +71 -25
  267. sglang/srt/model_executor/model_runner.py +362 -270
  268. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  269. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +549 -0
  270. sglang/srt/model_loader/__init__.py +1 -1
  271. sglang/srt/model_loader/loader.py +424 -27
  272. sglang/srt/model_loader/utils.py +0 -1
  273. sglang/srt/model_loader/weight_utils.py +47 -28
  274. sglang/srt/models/apertus.py +2 -3
  275. sglang/srt/models/arcee.py +2 -2
  276. sglang/srt/models/bailing_moe.py +13 -52
  277. sglang/srt/models/bailing_moe_nextn.py +3 -4
  278. sglang/srt/models/bert.py +1 -1
  279. sglang/srt/models/deepseek_nextn.py +19 -3
  280. sglang/srt/models/deepseek_ocr.py +1516 -0
  281. sglang/srt/models/deepseek_v2.py +418 -140
  282. sglang/srt/models/dots_ocr.py +0 -2
  283. sglang/srt/models/dots_vlm.py +0 -1
  284. sglang/srt/models/dots_vlm_vit.py +1 -1
  285. sglang/srt/models/falcon_h1.py +13 -19
  286. sglang/srt/models/gemma3_mm.py +16 -0
  287. sglang/srt/models/gemma3n_mm.py +1 -2
  288. sglang/srt/models/glm4_moe.py +327 -382
  289. sglang/srt/models/glm4_moe_nextn.py +6 -16
  290. sglang/srt/models/glm4v.py +2 -1
  291. sglang/srt/models/glm4v_moe.py +32 -199
  292. sglang/srt/models/gpt_oss.py +5 -5
  293. sglang/srt/models/grok.py +10 -23
  294. sglang/srt/models/hunyuan.py +2 -7
  295. sglang/srt/models/interns1.py +0 -1
  296. sglang/srt/models/kimi_vl.py +1 -7
  297. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  298. sglang/srt/models/llama.py +2 -2
  299. sglang/srt/models/llama_eagle3.py +1 -1
  300. sglang/srt/models/longcat_flash.py +5 -22
  301. sglang/srt/models/longcat_flash_nextn.py +3 -14
  302. sglang/srt/models/mimo.py +2 -13
  303. sglang/srt/models/mimo_mtp.py +1 -2
  304. sglang/srt/models/minicpmo.py +7 -5
  305. sglang/srt/models/minimax_m2.py +922 -0
  306. sglang/srt/models/mixtral.py +1 -4
  307. sglang/srt/models/mllama.py +1 -1
  308. sglang/srt/models/mllama4.py +13 -3
  309. sglang/srt/models/nemotron_h.py +511 -0
  310. sglang/srt/models/nvila.py +355 -0
  311. sglang/srt/models/nvila_lite.py +184 -0
  312. sglang/srt/models/olmo2.py +31 -4
  313. sglang/srt/models/opt.py +5 -5
  314. sglang/srt/models/phi.py +1 -1
  315. sglang/srt/models/phi4mm.py +1 -1
  316. sglang/srt/models/phimoe.py +0 -1
  317. sglang/srt/models/pixtral.py +0 -3
  318. sglang/srt/models/points_v15_chat.py +186 -0
  319. sglang/srt/models/qwen.py +0 -1
  320. sglang/srt/models/qwen2.py +22 -1
  321. sglang/srt/models/qwen2_5_vl.py +3 -3
  322. sglang/srt/models/qwen2_audio.py +2 -15
  323. sglang/srt/models/qwen2_moe.py +15 -12
  324. sglang/srt/models/qwen2_vl.py +5 -2
  325. sglang/srt/models/qwen3.py +34 -4
  326. sglang/srt/models/qwen3_moe.py +19 -37
  327. sglang/srt/models/qwen3_next.py +7 -12
  328. sglang/srt/models/qwen3_next_mtp.py +3 -4
  329. sglang/srt/models/qwen3_omni_moe.py +661 -0
  330. sglang/srt/models/qwen3_vl.py +37 -33
  331. sglang/srt/models/qwen3_vl_moe.py +57 -185
  332. sglang/srt/models/roberta.py +55 -3
  333. sglang/srt/models/sarashina2_vision.py +0 -1
  334. sglang/srt/models/step3_vl.py +3 -5
  335. sglang/srt/models/utils.py +11 -1
  336. sglang/srt/multimodal/processors/base_processor.py +7 -2
  337. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  338. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  339. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  340. sglang/srt/multimodal/processors/glm4v.py +2 -6
  341. sglang/srt/multimodal/processors/internvl.py +0 -2
  342. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  343. sglang/srt/multimodal/processors/mllama4.py +0 -8
  344. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  345. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  346. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  347. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  348. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  349. sglang/srt/parser/conversation.py +41 -0
  350. sglang/srt/parser/reasoning_parser.py +28 -2
  351. sglang/srt/sampling/custom_logit_processor.py +77 -2
  352. sglang/srt/sampling/sampling_batch_info.py +17 -22
  353. sglang/srt/sampling/sampling_params.py +70 -2
  354. sglang/srt/server_args.py +846 -163
  355. sglang/srt/server_args_config_parser.py +1 -1
  356. sglang/srt/single_batch_overlap.py +36 -31
  357. sglang/srt/speculative/base_spec_worker.py +34 -0
  358. sglang/srt/speculative/draft_utils.py +226 -0
  359. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  360. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  361. sglang/srt/speculative/eagle_info.py +57 -18
  362. sglang/srt/speculative/eagle_info_v2.py +458 -0
  363. sglang/srt/speculative/eagle_utils.py +138 -0
  364. sglang/srt/speculative/eagle_worker.py +83 -280
  365. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  366. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  367. sglang/srt/speculative/ngram_worker.py +12 -11
  368. sglang/srt/speculative/spec_info.py +2 -0
  369. sglang/srt/speculative/spec_utils.py +38 -3
  370. sglang/srt/speculative/standalone_worker.py +4 -14
  371. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  372. sglang/srt/two_batch_overlap.py +28 -14
  373. sglang/srt/utils/__init__.py +1 -1
  374. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  375. sglang/srt/utils/common.py +272 -82
  376. sglang/srt/utils/hf_transformers_utils.py +44 -17
  377. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  378. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  379. sglang/srt/utils/profile_merger.py +199 -0
  380. sglang/test/attention/test_flashattn_backend.py +1 -1
  381. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  382. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  383. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  384. sglang/test/few_shot_gsm8k_engine.py +2 -4
  385. sglang/test/kit_matched_stop.py +157 -0
  386. sglang/test/longbench_v2/__init__.py +1 -0
  387. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  388. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  389. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  390. sglang/test/run_eval.py +41 -0
  391. sglang/test/runners.py +2 -0
  392. sglang/test/send_one.py +42 -7
  393. sglang/test/simple_eval_common.py +3 -0
  394. sglang/test/simple_eval_gpqa.py +0 -1
  395. sglang/test/simple_eval_humaneval.py +0 -3
  396. sglang/test/simple_eval_longbench_v2.py +344 -0
  397. sglang/test/test_block_fp8.py +1 -2
  398. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  399. sglang/test/test_cutlass_moe.py +1 -2
  400. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  401. sglang/test/test_deterministic.py +463 -107
  402. sglang/test/test_deterministic_utils.py +74 -0
  403. sglang/test/test_disaggregation_utils.py +81 -0
  404. sglang/test/test_marlin_moe.py +0 -1
  405. sglang/test/test_utils.py +85 -20
  406. sglang/version.py +1 -1
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +48 -35
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +414 -350
  409. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  410. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  411. sglang/srt/models/vila.py +0 -306
  412. sglang/srt/speculative/build_eagle_tree.py +0 -427
  413. sglang/test/test_block_fp8_ep.py +0 -358
  414. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  415. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  416. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  417. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  418. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  419. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/mamba/mamba_utils.py DELETED
@@ -1,81 +0,0 @@
- # Adapted from: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/mamba_utils.py
- from sglang.srt.distributed.utils import divide
-
-
- class MambaStateShapeCalculator:
-
-     @classmethod
-     def linear_attention_state_shape(
-         cls,
-         num_heads: int,
-         tp_size: int,
-         head_dim: int,
-     ) -> tuple[tuple[int, int, int], ...]:
-
-         state_shape = (num_heads // tp_size, head_dim, head_dim)
-         return (state_shape,)
-
-     @classmethod
-     def mamba1_state_shape(
-         cls,
-         tp_world_size: int,
-         intermediate_size: int,
-         state_size: int,
-         conv_kernel: int,
-     ) -> tuple[tuple[int, int], tuple[int, int]]:
-         conv_state_shape = (divide(intermediate_size, tp_world_size), conv_kernel - 1)
-
-         temporal_state_shape = (divide(intermediate_size, tp_world_size), state_size)
-
-         conv_state_shape = conv_state_shape[1], conv_state_shape[0]
-
-         return conv_state_shape, temporal_state_shape
-
-     @classmethod
-     def mamba2_state_shape(
-         cls,
-         tp_world_size: int,
-         intermediate_size: int,
-         n_groups: int,
-         num_heads: int,
-         head_dim: int,
-         state_size: int,
-         conv_kernel: int,
-     ) -> tuple[tuple[int, int], tuple[int, int, int]]:
-         # if n_groups is not divisible by world_size, need to extend the shards
-         # to ensure all groups needed by a head is sharded along with it
-         n_groups = n_groups + cls.extra_groups_for_head_shards(n_groups, tp_world_size)
-         # heads and n_groups are TP-ed
-         conv_dim = intermediate_size + 2 * n_groups * state_size
-
-         # contiguous along 'dim' axis
-         conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size))
-
-         # These are not TP-ed as they depend on A, dt_bias, D
-         # - they are typically small
-         # e.g., (h_heads, head_dim, state_size) = (128, 64, 128)
-         temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size)
-         return conv_state_shape, temporal_state_shape
-
-     @classmethod
-     def short_conv_state_shape(
-         cls,
-         tp_world_size: int,
-         intermediate_size: int,
-         conv_kernel: int,
-     ) -> tuple[tuple[int, int]]:
-         conv_dim = divide(intermediate_size, tp_world_size)
-         conv_state_shape = (conv_kernel - 1, conv_dim)
-         return (conv_state_shape,)
-
-     @classmethod
-     def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int):
-         """Compute the increase in group numbers to account for
-         replication in order to accompany the head shards."""
-
-         # in the case ngoups % tp_size == 0, this will be zero
-         if ngroups % tp_size == 0:
-             return 0
-
-         # for n_groups == 1, this is exactly tp_size - n_groups
-         return tp_size - ngroups
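Since the deleted `MambaStateShapeCalculator` is pure shape arithmetic, its behavior is easy to check in isolation. Below is a minimal standalone sketch of the `mamba2_state_shape` path, with a local `divide` assumed to behave like `sglang.srt.distributed.utils.divide` (exact integer division); the config numbers are hypothetical, chosen to match the `(128, 64, 128)` example in the code comment, sharded over 4 TP ranks.

```python
# Standalone sketch of the removed mamba2_state_shape logic.
# `divide` is an assumption: exact integer division, mirroring
# sglang.srt.distributed.utils.divide.
def divide(numerator: int, denominator: int) -> int:
    assert numerator % denominator == 0, "not evenly divisible"
    return numerator // denominator


def extra_groups_for_head_shards(ngroups: int, tp_size: int) -> int:
    # Zero when the groups already shard evenly across TP ranks.
    if ngroups % tp_size == 0:
        return 0
    # e.g. for ngroups == 1, replicate up to one group per rank.
    return tp_size - ngroups


def mamba2_state_shape(tp, intermediate, n_groups, heads, head_dim, state, conv_k):
    n_groups += extra_groups_for_head_shards(n_groups, tp)
    conv_dim = intermediate + 2 * n_groups * state   # heads and groups are TP-ed
    conv_state = (conv_k - 1, divide(conv_dim, tp))  # contiguous along 'dim'
    temporal_state = (divide(heads, tp), head_dim, state)  # not TP-ed per head
    return conv_state, temporal_state


# Hypothetical Mamba2 config on 4 TP ranks.
print(mamba2_state_shape(4, 8192, 8, 128, 64, 128, 4))
# -> ((3, 2560), (32, 64, 128))
```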
sglang/srt/managers/tp_worker_overlap_thread.py DELETED
@@ -1,311 +0,0 @@
- # Copyright 2023-2024 SGLang Team
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """A tensor parallel worker."""
- from __future__ import annotations
-
- import dataclasses
- import logging
- import signal
- import threading
- from queue import Queue
- from typing import TYPE_CHECKING, List, Optional, Tuple
-
- import psutil
- import torch
-
- from sglang.srt.managers.io_struct import (
-     DestroyWeightsUpdateGroupReqInput,
-     GetWeightsByNameReqInput,
-     InitWeightsSendGroupForRemoteInstanceReqInput,
-     InitWeightsUpdateGroupReqInput,
-     LoadLoRAAdapterReqInput,
-     SendWeightsToRemoteInstanceReqInput,
-     UnloadLoRAAdapterReqInput,
-     UpdateWeightFromDiskReqInput,
-     UpdateWeightsFromDistributedReqInput,
-     UpdateWeightsFromTensorReqInput,
- )
- from sglang.srt.managers.overlap_utils import FutureMap
- from sglang.srt.managers.schedule_batch import ModelWorkerBatch
- from sglang.srt.managers.tp_worker import TpModelWorker
- from sglang.srt.model_executor.forward_batch_info import ForwardBatchOutput
- from sglang.srt.server_args import ServerArgs
- from sglang.srt.utils import DynamicGradMode
- from sglang.utils import get_exception_traceback
-
- if TYPE_CHECKING:
-     from sglang.srt.managers.cache_controller import LayerDoneCounter
-
- logger = logging.getLogger(__name__)
-
-
- class TpModelWorkerClient:
-     """A tensor parallel model worker."""
-
-     def __init__(
-         self,
-         server_args: ServerArgs,
-         gpu_id: int,
-         tp_rank: int,
-         moe_ep_rank: int,
-         pp_rank: int,
-         dp_rank: Optional[int],
-         nccl_port: int,
-     ):
-         # Load the model
-         self.worker = TpModelWorker(
-             server_args, gpu_id, tp_rank, moe_ep_rank, pp_rank, dp_rank, nccl_port
-         )
-         self.max_running_requests = self.worker.max_running_requests
-         self.device = self.worker.device
-         self.gpu_id = gpu_id
-
-         # Init future mappings
-         self.future_map = FutureMap(self.max_running_requests, self.device)
-
-         # Launch threads
-         self.input_queue = Queue[Tuple[ModelWorkerBatch, int, torch.Event]]()
-         self.output_queue = Queue()
-         self.forward_stream = torch.get_device_module(self.device).Stream()
-         self.forward_thread = threading.Thread(
-             target=self.forward_thread_func,
-         )
-         self.forward_thread.start()
-         self.parent_process = psutil.Process().parent()
-         self.scheduler_stream = torch.get_device_module(self.device).current_stream()
-         if self.device == "cpu":
-             self.scheduler_stream.synchronize = lambda: None  # No-op for CPU
-
-         self.hicache_layer_transfer_counter = None
-
-     def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
-         self.hicache_layer_transfer_counter = counter
-
-     def get_worker_info(self):
-         return self.worker.get_worker_info()
-
-     def get_tokens_per_layer_info(self):
-         return self.worker.get_tokens_per_layer_info()
-
-     @property
-     def sliding_window_size(self) -> Optional[int]:
-         return self.worker.sliding_window_size
-
-     @property
-     def is_hybrid(self) -> bool:
-         return self.worker.is_hybrid
-
-     def get_pad_input_ids_func(self):
-         return self.worker.get_pad_input_ids_func()
-
-     def get_tp_group(self):
-         return self.worker.get_tp_group()
-
-     def get_attention_tp_group(self):
-         return self.worker.get_attention_tp_group()
-
-     def get_attention_tp_cpu_group(self):
-         return self.worker.get_attention_tp_cpu_group()
-
-     def get_memory_pool(self):
-         return (
-             self.worker.model_runner.req_to_token_pool,
-             self.worker.model_runner.token_to_kv_pool_allocator,
-         )
-
-     def get_kv_cache(self):
-         return self.worker.model_runner.token_to_kv_pool
-
-     def forward_thread_func(self):
-         try:
-             with torch.get_device_module(self.device).stream(self.forward_stream):
-                 self.forward_thread_func_()
-         except Exception:
-             traceback = get_exception_traceback()
-             logger.error(f"TpModelWorkerClient hit an exception: {traceback}")
-             self.parent_process.send_signal(signal.SIGQUIT)
-
-     @DynamicGradMode()
-     def forward_thread_func_(self):
-         batch_pt = 0
-         batch_lists: List = [None] * 2
-
-         while True:
-             model_worker_batch, future_map_ct, sync_event = self.input_queue.get()
-             if not model_worker_batch:
-                 break
-
-             sync_event.wait()
-
-             # Keep a reference of model_worker_batch by storing it into a list.
-             # Otherwise, the tensor members of model_worker_batch will be released
-             # by pytorch and cause CUDA illegal memory access errors.
-             batch_lists[batch_pt % 2] = model_worker_batch
-             batch_pt += 1
-
-             # Create event
-             copy_done = torch.get_device_module(self.device).Event()
-
-             # Resolve future tokens in the input
-             self.future_map.resolve_future(model_worker_batch)
-
-             # Run forward
-             forward_batch_output = self.worker.forward_batch_generation(
-                 model_worker_batch,
-                 model_worker_batch.launch_done,
-             )
-
-             logits_output, next_token_ids, can_run_cuda_graph = (
-                 forward_batch_output.logits_output,
-                 forward_batch_output.next_token_ids,
-                 forward_batch_output.can_run_cuda_graph,
-             )
-
-             # Update the future token ids map
-             bs = len(model_worker_batch.seq_lens)
-             if model_worker_batch.is_prefill_only:
-                 # For prefill-only requests, create dummy token IDs on CPU
-                 next_token_ids = torch.zeros(bs, dtype=torch.long)
-
-             # store the future indices into future map
-             self.future_map.store_to_map(future_map_ct, bs, next_token_ids)
-
-             # Copy results to the CPU
-             if model_worker_batch.return_logprob:
-                 if logits_output.next_token_logprobs is not None:
-                     logits_output.next_token_logprobs = (
-                         logits_output.next_token_logprobs.to("cpu", non_blocking=True)
-                     )
-                 if logits_output.input_token_logprobs is not None:
-                     logits_output.input_token_logprobs = (
-                         logits_output.input_token_logprobs.to("cpu", non_blocking=True)
-                     )
-             if logits_output.hidden_states is not None:
-                 logits_output.hidden_states = logits_output.hidden_states.to(
-                     "cpu", non_blocking=True
-                 )
-             # Only copy to CPU if not already on CPU
-             if next_token_ids.device.type != "cpu":
-                 next_token_ids = next_token_ids.to("cpu", non_blocking=True)
-             copy_done.record()
-
-             self.output_queue.put(
-                 (copy_done, logits_output, next_token_ids, can_run_cuda_graph)
-             )
-
-     def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None):
-         """
-         This function is called to resolve the last batch result and
-         wait for the current batch to be launched. Used in overlap mode.
-         """
-         copy_done, logits_output, next_token_ids, can_run_cuda_graph = (
-             self.output_queue.get()
-         )
-
-         if launch_done is not None:
-             launch_done.wait()
-         copy_done.synchronize()
-
-         if logits_output.next_token_logprobs is not None:
-             logits_output.next_token_logprobs = (
-                 logits_output.next_token_logprobs.tolist()
-             )
-         if logits_output.input_token_logprobs is not None:
-             logits_output.input_token_logprobs = tuple(
-                 logits_output.input_token_logprobs.tolist()
-             )
-         next_token_ids = next_token_ids.tolist()
-         return logits_output, next_token_ids, can_run_cuda_graph
-
-     def forward_batch_generation(
-         self, model_worker_batch: ModelWorkerBatch
-     ) -> ForwardBatchOutput:
-         # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
-         sampling_info = model_worker_batch.sampling_info
-         sampling_info.update_penalties()
-         model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace(
-             sampling_info,
-             sampling_info_done=threading.Event(),
-             penalizer_orchestrator=None,
-         )
-
-         # A cuda stream sync here to avoid the cuda illegal memory access error.
-         sync_event = torch.get_device_module(self.device).Event()
-         sync_event.record(self.scheduler_stream)
-
-         # Push a new batch to the queue
-         bs = len(model_worker_batch.seq_lens)
-         cur_future_map_ct = self.future_map.update_ct(bs)
-         self.input_queue.put((model_worker_batch, cur_future_map_ct, sync_event))
-
-         # get this forward batch's future token ids
-         future_next_token_ids = self.future_map.update_next_future(
-             cur_future_map_ct, bs
-         )
-         return ForwardBatchOutput(
-             next_token_ids=future_next_token_ids,
-             can_run_cuda_graph=False,
-         )
-
-     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
-         success, message = self.worker.update_weights_from_disk(recv_req)
-         return success, message
-
-     def init_weights_update_group(self, recv_req: InitWeightsUpdateGroupReqInput):
-         success, message = self.worker.init_weights_update_group(recv_req)
-         return success, message
-
-     def destroy_weights_update_group(self, recv_req: DestroyWeightsUpdateGroupReqInput):
-         success, message = self.worker.destroy_weights_update_group(recv_req)
-         return success, message
-
-     def init_weights_send_group_for_remote_instance(
-         self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
-     ):
-         success, message = self.worker.init_weights_send_group_for_remote_instance(
-             recv_req
-         )
-         return success, message
-
-     def send_weights_to_remote_instance(
-         self, recv_req: SendWeightsToRemoteInstanceReqInput
-     ):
-         success, message = self.worker.send_weights_to_remote_instance(recv_req)
-         return success, message
-
-     def update_weights_from_distributed(
-         self, recv_req: UpdateWeightsFromDistributedReqInput
-     ):
-         success, message = self.worker.update_weights_from_distributed(recv_req)
-         return success, message
-
-     def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
-         success, message = self.worker.update_weights_from_tensor(recv_req)
-         return success, message
-
-     def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
-         return self.worker.get_weights_by_name(recv_req)
-
-     def load_lora_adapter(self, recv_req: LoadLoRAAdapterReqInput):
-         return self.worker.load_lora_adapter(recv_req)
-
-     def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
-         return self.worker.unload_lora_adapter(recv_req)
-
-     def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
-         return self.worker.can_run_lora_batch(lora_ids)
-
-     def __delete__(self):
-         self.input_queue.put((None, None))
-         self.copy_queue.put((None, None, None))
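The removed `TpModelWorkerClient` implemented overlap scheduling: the scheduler enqueues a batch and immediately receives placeholder "future" token IDs from `FutureMap`, while a dedicated forward thread resolves those futures, runs the model on its own stream, and publishes finished results through an output queue. A minimal sketch of that two-queue pattern, using only the standard library and illustrative names (no CUDA events, streams, or future maps), might look like:

```python
# Sketch of the two-queue overlap pattern from the removed worker client:
# submit() returns immediately so the caller can prepare the next batch
# while the background thread runs the current one. All names illustrative.
import threading
from queue import Queue


class OverlapWorker:
    def __init__(self, forward_fn):
        self.forward_fn = forward_fn      # stands in for the real model forward
        self.input_queue: Queue = Queue()
        self.output_queue: Queue = Queue()
        threading.Thread(target=self._loop, daemon=True).start()

    def _loop(self):
        while True:
            batch = self.input_queue.get()
            if batch is None:             # shutdown sentinel
                break
            self.output_queue.put(self.forward_fn(batch))

    def submit(self, batch):
        # Non-blocking: the result is resolved later via resolve_last().
        self.input_queue.put(batch)

    def resolve_last(self):
        return self.output_queue.get()    # blocks until the batch finishes


worker = OverlapWorker(lambda xs: [x + 1 for x in xs])
worker.submit([1, 2, 3])                  # overlap: schedule next batch here
print(worker.resolve_last())              # -> [2, 3, 4]
```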
sglang/srt/models/vila.py DELETED
@@ -1,306 +0,0 @@
- import logging
- from typing import Any, Dict, Iterable, List, Optional, Tuple, cast
-
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from torch import Tensor
- from transformers.configuration_utils import PretrainedConfig
- from transformers.modeling_outputs import BaseModelOutputWithPooling
- from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
- from transformers.models.siglip import SiglipVisionConfig, SiglipVisionModel
-
- import sglang.srt.managers.mm_utils as mm_utils
- import sglang.srt.model_loader.weight_utils as weight_utils
- import sglang.srt.utils as utils
- from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
- from sglang.srt.layers.pooler import Pooler, PoolingType
- from sglang.srt.layers.quantization.base_config import QuantizationConfig
- from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
- from sglang.srt.managers.schedule_batch import (
-     Modality,
-     MultimodalDataItem,
-     MultimodalInputs,
- )
- from sglang.srt.model_executor.forward_batch_info import ForwardBatch
- from sglang.srt.models.qwen2 import Qwen2ForCausalLM
-
- logger = logging.getLogger(__name__)
-
-
- ##### BEGIN COPY configuration.py #####
-
-
- class VILAConfig(PretrainedConfig):
-     # Class attributes.
-     model_type: str = "vila"
-     sub_configs: Dict[str, PretrainedConfig] = {
-         "text_config": Qwen2Config(),
-         "vision_config": SiglipVisionConfig(),
-     }
-     _auto_class: Optional[str] = "AutoConfig"
-
-     # Configuration for sub-modules.
-     text_config: Qwen2Config = Qwen2Config()
-     vision_config: SiglipVisionConfig = SiglipVisionConfig()
-
-     # Model configuration.
-     hidden_size: int
-     image_token_id: int
-     mm_hidden_size: int
-     mm_projector_type: str
-     mm_vision_select_feature: str
-     mm_vision_select_layer: int
-     video_token_id: int
-
-     def __init__(
-         self,
-         text_config: Optional[Dict[str, Any]] = None,
-         vision_config: Optional[Dict[str, Any]] = None,
-         *,
-         hidden_size: int = 1536,
-         image_token_id: int = 151649,
-         mm_hidden_size: int = 1152,
-         mm_projector_type: str = "mlp_downsample_3x3_fix",
-         mm_vision_select_feature: str = "cls_patch",
-         mm_vision_select_layer: int = -2,
-         video_token_id: int = 151650,
-         **kwargs,
-     ):
-         super().__init__(**kwargs)
-
-         self.text_config = Qwen2Config(**text_config) if text_config else Qwen2Config()
-         self.vision_config = (
-             SiglipVisionConfig(**vision_config)
-             if vision_config
-             else SiglipVisionConfig()
-         )
-
-         self.hidden_size = hidden_size
-         self.image_token_id = image_token_id
-         self.mm_hidden_size = mm_hidden_size
-         self.mm_projector_type = mm_projector_type
-         self.mm_vision_select_feature = mm_vision_select_feature
-         self.mm_vision_select_layer = mm_vision_select_layer
-         self.video_token_id = video_token_id
-
-
- ##### END COPY configuration.py #####
-
- ##### BEGIN COPY modeling_vila.py #####
-
-
- class DownSample3x3BlockFix(nn.Module):
-     def forward(self, x: Tensor) -> Tensor:
-         """
-         Args:
-             x: The input tensor of shape (batch_size, sequence_length, mm_hidden_size).
-
-         Returns:
-             The output tensor of shape (batch_size, image_pad_len, mm_hidden_size * 9).
-         """
-
-         batch_size, sequence_length, hidden_size = x.shape
-
-         feat_size = int(sequence_length**0.5)
-         if feat_size**2 != sequence_length:
-             raise ValueError(
-                 f"Cannot take square root: sequence_length {sequence_length} is not a perfect square"
-             )
-
-         features = x.reshape(batch_size, feat_size, feat_size, hidden_size)
-
-         pad_after = (3 - feat_size % 3) % 3
-         if pad_after > 0:
-             features = F.pad(features, (0, 0, 0, pad_after, 0, pad_after))
-             feat_size = feat_size + pad_after
-
-         features = features.reshape(
-             batch_size, feat_size // 3, 3, feat_size // 3, 3, hidden_size
-         )
-         features = features.permute(0, 1, 3, 2, 4, 5).contiguous()
-         features = features.reshape(batch_size, -1, 9 * hidden_size)
-
-         return features
-
-
- class MultimodalProjector(nn.Module):
-     layers: nn.Sequential
-
-     def __init__(
-         self,
-         config: VILAConfig,
-         *args,
-         **kwargs,
-     ):
-         super().__init__(*args, **kwargs)
-
-         if config.mm_projector_type == "mlp_downsample_3x3_fix":
-             self.layers = nn.Sequential(
-                 DownSample3x3BlockFix(),
-                 nn.LayerNorm(config.mm_hidden_size * 9),
-                 nn.Linear(
-                     config.mm_hidden_size * 9,
-                     config.mm_hidden_size * 3,
-                 ),
-                 nn.GELU(),
-                 nn.LayerNorm(config.vision_config.hidden_size * 3),
-                 nn.Linear(config.vision_config.hidden_size * 3, config.hidden_size),
-                 nn.GELU(),
-                 nn.Linear(config.hidden_size, config.hidden_size),
-             )
-         else:
-             raise NotImplementedError(
-                 f"Unsupported mm_projector_type: {config.mm_projector_type}"
-             )
-
-         self.layers.type(config.torch_dtype)
-
-     @property
-     def device(self) -> torch.device:
-         return next(self.parameters()).device
-
-     @property
-     def dtype(self) -> torch.dtype:
-         return next(self.parameters()).dtype
-
-     def forward(self, x: Tensor) -> Tensor:
-         """
-         Args:
-             x: The input tensor of shape (batch_size, sequence_length, mm_hidden_size).
-
-         Returns:
-             The output tensor of shape (batch_size, image_pad_len, hidden_size).
-         """
-
-         return self.layers(x.to(device=self.device, dtype=self.dtype))
-
-
- ##### END COPY modeling_vila.py #####
-
-
- class VILAForConditionalGeneration(nn.Module):
-     config: VILAConfig
-     quant_config: Optional[QuantizationConfig]
-
-     logits_processor: LogitsProcessor
-     pooler: Pooler
-
-     llm: Qwen2ForCausalLM
-     mm_projector: MultimodalProjector
-     vision_tower: SiglipVisionModel
-
-     def __init__(
-         self,
-         config: VILAConfig,
-         quant_config: Optional[QuantizationConfig] = None,
-         prefix: str = "",
-     ) -> None:
-         super().__init__()
-
-         self.config = config
-         self.quant_config = quant_config
-
-         self.logits_processor = LogitsProcessor(config)
-         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
-
-         self.llm = Qwen2ForCausalLM(
-             config=config.text_config,
-             quant_config=quant_config,
-             prefix=utils.add_prefix("llm", prefix),
-         )
-         self.mm_projector = MultimodalProjector(config)
-         self.vision_tower = SiglipVisionModel(config.vision_config)
-
-     @property
-     def dtype(self) -> torch.dtype:
-         return self.config.torch_dtype
-
-     def forward(
-         self,
-         input_ids: Tensor,
-         positions: Tensor,
-         forward_batch: ForwardBatch,
-         get_embedding: bool = False,
-     ) -> LogitsProcessorOutput:
-         output = mm_utils.general_mm_embed_routine(
-             input_ids=input_ids,
-             forward_batch=forward_batch,
-             language_model=self.llm,
-             data_embedding_funcs={
-                 Modality.IMAGE: self.get_image_feature,
-             },
-             get_embedding=get_embedding,
-             positions=positions,
-         )
-
-         return cast(LogitsProcessorOutput, output)
-
-     def get_image_feature(self, mm_input: List[MultimodalDataItem]) -> Tensor:
-         pixel_values = cast(Tensor, mm_input[0].feature)
-
-         ##### BEGIN COPY modeling_vila.py #####
-
-         vision_tower_output: BaseModelOutputWithPooling = self.vision_tower.__call__(
-             pixel_values.to(
-                 device=self.vision_tower.device, dtype=self.vision_tower.dtype
-             ),
-             output_hidden_states=True,
-         )
-
-         mm_projector_input = self._vision_tower_output_to_mm_projector_input(
-             vision_tower_output
-         )
-
-         image_embedding: Tensor = self.mm_projector.__call__(
-             mm_projector_input.to(
-                 device=self.mm_projector.device, dtype=self.mm_projector.dtype
-             )
-         )
-
-         ##### END COPY modeling_vila.py #####
-
-         return image_embedding
-
-     def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> None:
-         params_dict = dict(self.named_parameters())
-
-         for name, loaded_weight in weights:
-             if name.startswith("llm."):
-                 self.llm.load_weights([(name[len("llm.") :], loaded_weight)])
-             else:
-                 param = params_dict[name]
-                 weight_loader = getattr(
-                     param, "weight_loader", weight_utils.default_weight_loader
-                 )
-                 weight_loader(param, loaded_weight)
-
-     def pad_input_ids(
-         self, input_ids: List[int], mm_inputs: MultimodalInputs
-     ) -> List[int]:
-         pattern = MultiModalityDataPaddingPatternMultimodalTokens()
-         return pattern.pad_input_tokens(input_ids, mm_inputs)
-
-     ##### BEGIN COPY modeling_vila.py #####
-
-     def _vision_tower_output_to_mm_projector_input(
-         self,
-         vision_tower_output: BaseModelOutputWithPooling,
-     ) -> Tensor:
-         assert vision_tower_output.hidden_states is not None
-
-         selected_layer_hidden_states = vision_tower_output.hidden_states[
-             self.config.mm_vision_select_layer
-         ]
-
-         if self.config.mm_vision_select_feature == "cls_patch":
-             return selected_layer_hidden_states
-         else:
-             raise NotImplementedError(
-                 f"Unsupported mm_vision_select_feature: {self.config.mm_vision_select_feature}"
-             )
-
-     ##### END COPY modeling_vila.py #####
-
-
- EntryClass = [VILAForConditionalGeneration]
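For reference, the `mlp_downsample_3x3_fix` projector above hinges on `DownSample3x3BlockFix`, which folds a square grid of vision tokens into 3x3 spatial patches and concatenates the 9 neighboring features channel-wise, shrinking the sequence roughly 9x before the MLP. A standalone PyTorch sketch of that reshape follows; the shapes are illustrative (e.g. a 27x27 SigLIP token grid with hidden size 1152, matching the config defaults above).

```python
# Shape check for the 3x3 downsample used by the removed VILA projector:
# a (B, S, H) token grid with S a perfect square is folded into 3x3 spatial
# patches, stacking the 9 neighbors on the channel dimension.
import torch
import torch.nn.functional as F


def downsample_3x3(x: torch.Tensor) -> torch.Tensor:
    b, s, h = x.shape
    f = int(s ** 0.5)
    assert f * f == s, "sequence length must be a perfect square"
    feats = x.reshape(b, f, f, h)
    pad = (3 - f % 3) % 3                      # pad grid up to a multiple of 3
    if pad:
        feats = F.pad(feats, (0, 0, 0, pad, 0, pad))
        f += pad
    feats = feats.reshape(b, f // 3, 3, f // 3, 3, h)
    feats = feats.permute(0, 1, 3, 2, 4, 5).contiguous()
    return feats.reshape(b, -1, 9 * h)         # 9 neighbors stacked on channels


x = torch.randn(2, 27 * 27, 1152)              # e.g. a 27x27 vision token grid
print(downsample_3x3(x).shape)                 # -> torch.Size([2, 81, 10368])
```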