sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/glm4_moe_nextn.py CHANGED
@@ -12,7 +12,7 @@
  # limitations under the License.
  # ==============================================================================

- """Inference-only GLM-4.5 NextN Speculative Decoding."""
+ """Inference-only GLM-4.5, GLM-4.6 NextN Speculative Decoding."""
  import logging
  from typing import Iterable, Optional, Tuple

@@ -48,7 +48,7 @@ class Glm4MoeModelNextN(nn.Module):
  super().__init__()
  if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
  logger.warning(
- "Overriding Glm4MoeForCausalLMNextN quant config for modelopt_fp4 GLM-4.5 model."
+ "Overriding Glm4MoeForCausalLMNextN quant config for modelopt_fp4 GLM-4.5 / GLM-4.6 model."
  )
  quant_config = None

sglang/srt/models/glm4v.py CHANGED
@@ -7,7 +7,6 @@ import torch.nn as nn
  import torch.nn.functional as F
  from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisionConfig

- from sglang.srt.hf_transformers_utils import get_processor
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.attention import vision_utils
  from sglang.srt.layers.layernorm import RMSNorm
@@ -28,6 +27,7 @@ from sglang.srt.models.qwen2_5_vl import (
  Qwen2_5_VLForConditionalGeneration,
  )
  from sglang.srt.utils import add_prefix
+ from sglang.srt.utils.hf_transformers_utils import get_processor

  logger = logging.getLogger(__name__)

@@ -93,9 +93,8 @@ class Glm4vVisionBlock(Qwen2_5_VisionBlock):
  quant_config=quant_config,
  prefix=prefix,
  num_dummy_heads=config.num_dummy_heads,
+ rms_norm_eps=config.rms_norm_eps,
  )
- self.norm1 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.norm2 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

  self.mlp = Glm4vVisionMLP(
  config.hidden_size,
@@ -498,6 +497,9 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
  self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
  self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling

+ # For EAGLE3 support
+ self.capture_aux_hidden_states = False
+
  def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
  pixel_values = torch.cat(
  [item.feature.squeeze(0) for item in items], dim=0
sglang/srt/models/glm4v_moe.py CHANGED
@@ -10,7 +10,6 @@ from sglang.srt.distributed import (
  get_moe_expert_parallel_world_size,
  get_tensor_model_parallel_world_size,
  )
- from sglang.srt.hf_transformers_utils import get_processor
  from sglang.srt.layers.attention import vision_utils
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
@@ -22,6 +21,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.models.glm4_moe import Glm4MoeModel
  from sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel
  from sglang.srt.utils import add_prefix, is_cuda, log_info_on_rank0
+ from sglang.srt.utils.hf_transformers_utils import get_processor

  _is_cuda = is_cuda()

@@ -74,6 +74,9 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration):
  self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
  self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling

+ # For EAGLE3 support
+ self.capture_aux_hidden_states = False
+
  def determine_num_fused_shared_experts(
  self, architecture: str = "Glm4MoeForCausalLM"
  ):
sglang/srt/models/gpt_oss.py CHANGED
@@ -66,6 +66,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
  from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.models.utils import (
+ create_fused_set_kv_buffer_arg,
+ enable_fused_set_kv_buffer,
+ )
  from sglang.srt.utils import (
  LazyValue,
  add_prefix,
@@ -121,7 +125,7 @@ class GptOssSparseMoeBlock(nn.Module):
  )

  self.top_k = config.num_experts_per_tok
- experts_type = get_moe_impl_class()
+ experts_type = get_moe_impl_class(quant_config)
  extra_kwargs = {}
  if experts_type.__name__ == "FusedMoE":
  quant_config_name = (
@@ -193,33 +197,6 @@
  return ans


- def _enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
- """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
- return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
-
-
- # TODO maybe move to a model-common utils
- def _create_fused_set_kv_buffer_arg(
- value: torch.Tensor,
- layer: RadixAttention,
- forward_batch: ForwardBatch,
- ):
- layer_id = layer.layer_id
- token_to_kv_pool = forward_batch.token_to_kv_pool
-
- k_buffer = token_to_kv_pool.get_key_buffer(layer_id)
- v_buffer = token_to_kv_pool.get_value_buffer(layer_id)
-
- return FusedSetKVBufferArg(
- value=value,
- k_buffer=k_buffer.view(k_buffer.shape[0], -1),
- v_buffer=v_buffer.view(v_buffer.shape[0], -1),
- k_scale=layer.k_scale,
- v_scale=layer.v_scale,
- cache_loc=forward_batch.out_cache_loc,
- )
-
-
  class GptOssAttention(nn.Module):
  def __init__(
  self,
@@ -337,12 +314,12 @@
  q,
  k,
  fused_set_kv_buffer_arg=(
- _create_fused_set_kv_buffer_arg(
+ create_fused_set_kv_buffer_arg(
  value=v,
  layer=self.attn,
  forward_batch=forward_batch,
  )
- if _enable_fused_set_kv_buffer(forward_batch)
+ if enable_fused_set_kv_buffer(forward_batch)
  else None
  ),
  )
@@ -356,7 +333,7 @@
  attn_output = self.attn(
  *inner_state,
  sinks=self.sinks,
- save_kv_cache=not _enable_fused_set_kv_buffer(forward_batch),
+ save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
  )
  output, _ = self.o_proj(attn_output)
  return output
sglang/srt/models/grok.py CHANGED
@@ -49,7 +49,6 @@ from sglang.srt.layers.linear import (
  RowParallelLinear,
  )
  from sglang.srt.layers.logits_processor import LogitsProcessor
- from sglang.srt.layers.moe.ep_moe.layer import EPMoE
  from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
  from sglang.srt.layers.moe.router import fused_moe_router_shim
  from sglang.srt.layers.moe.topk import TopK
@@ -176,17 +175,7 @@ class Grok1MoE(nn.Module):
  custom_routing_function=custom_routing_function,
  )

- kwargs = {}
- if get_moe_expert_parallel_world_size() > 1:
- MoEImpl = EPMoE
- else:
- MoEImpl = FusedMoE
- kwargs["reduce_results"] = reduce_results
- kwargs["use_presharded_weights"] = use_presharded_weights
- kwargs["inplace"] = inplace
- kwargs["no_combine"] = no_combine
-
- self.experts = MoEImpl(
+ self.experts = FusedMoE(
  num_experts=num_experts,
  top_k=top_k,
  layer_id=layer_id,
@@ -195,7 +184,10 @@
  params_dtype=params_dtype,
  quant_config=quant_config,
  activation="gelu",
- **kwargs,
+ reduce_results=reduce_results,
+ use_presharded_weights=use_presharded_weights,
+ inplace=inplace,
+ no_combine=no_combine,
  )

  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
sglang/srt/models/kimi_vl_moonvit.py CHANGED
@@ -49,7 +49,7 @@ from typing import List, Optional, Sequence, Tuple, Union
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
- from transformers.activations import ACT2FN, PytorchGELUTanh
+ from transformers.activations import ACT2FN, GELUTanh
  from transformers.modeling_utils import PreTrainedModel

  try:
@@ -614,7 +614,7 @@ class MoonVitPretrainedModel(PreTrainedModel):
  "num_heads": config.num_attention_heads,
  "hidden_dim": config.hidden_size,
  "mlp_dim": config.intermediate_size,
- "activation": PytorchGELUTanh(),
+ "activation": GELUTanh(),
  "attn_bias": True,
  "attn_implementation": config._attn_implementation,
  },
sglang/srt/models/llama.py CHANGED
@@ -385,6 +385,10 @@ class LlamaModel(nn.Module):
  "Self attention has no KV cache scaling " "factor attribute!"
  )

+ def get_input_embeddings(self) -> nn.Embedding:
+ """Get input embeddings from the model."""
+ return self.embed_tokens
+

  class LlamaForCausalLM(nn.Module):
  # BitandBytes specific attributes
sglang/srt/models/llama4.py CHANGED
@@ -423,6 +423,12 @@ class Llama4DecoderLayer(nn.Module):
  return self.config.num_local_experts > 0
  return (layer_id + 1) % self.config.interleave_moe_layer_step == 0

+ def get_intermediate_size(self) -> int:
+ if isinstance(self.feed_forward, Llama4MoE):
+ return self.config.intermediate_size
+ else:
+ return self.config.intermediate_size_mlp
+
  def forward(
  self,
  positions: torch.Tensor,
@@ -540,6 +546,9 @@ class Llama4ForCausalLM(LlamaForCausalLM):
  def get_input_embeddings(self):
  return self.model.embed_tokens

+ def get_layers(self):
+ return self.model.layers
+
  def _init_model(
  self,
  config: Llama4TextConfig,
sglang/srt/models/llama_eagle3.py CHANGED
@@ -109,6 +109,16 @@ class LlamaModel(nn.Module):
  ) -> None:
  super().__init__()
  self.config = config
+
+ self.is_mrope_enabled = (
+ hasattr(config, "rope_scaling")
+ and config.rope_scaling is not None
+ and "mrope_section" in config.rope_scaling
+ )
+ # fix rope_scaling for qwen2.5-vl
+ if self.is_mrope_enabled:
+ config.rope_scaling["rope_type"] = "default"
+
  self.vocab_size = config.vocab_size
  self.embed_tokens = VocabParallelEmbedding(
  config.vocab_size,
@@ -144,6 +154,9 @@
  else:
  embeds = input_embeds

+ if self.is_mrope_enabled:
+ positions = forward_batch.mrope_positions
+
  hidden_states = forward_batch.spec_info.hidden_states
  if hidden_states.shape[-1] != embeds.shape[-1]:
  hidden_states = self.fc(hidden_states)
sglang/srt/models/longcat_flash.py CHANGED
@@ -131,7 +131,7 @@ elif _is_hip:
  awq_dequantize_triton as awq_dequantize,
  )
  else:
- from vllm._custom_ops import awq_dequantize
+ pass

  logger = logging.getLogger(__name__)

@@ -260,7 +260,7 @@ class LongcatFlashMoE(nn.Module):
  )
  self.topk.forward = self.topk.forward_native

- self.experts = get_moe_impl_class()(
+ self.experts = get_moe_impl_class(quant_config)(
  num_experts=self.num_experts,
  top_k=self.top_k,
  layer_id=self.layer_id,
@@ -853,7 +853,7 @@ class LongcatFlashForCausalLM(nn.Module):

  # Params for weights, fp8 weight scales, fp8 activation scales
  # (param_name, weight_name, expert_id, shard_id)
- expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
+ expert_params_mapping = FusedMoE.make_expert_params_mapping(
  ckpt_gate_proj_name="gate_proj",
  ckpt_down_proj_name="down_proj",
  ckpt_up_proj_name="up_proj",
sglang/srt/models/longcat_flash_nextn.py CHANGED
@@ -111,7 +111,7 @@ elif _is_hip:
  awq_dequantize_triton as awq_dequantize,
  )
  else:
- from vllm._custom_ops import awq_dequantize
+ pass


  logger = logging.getLogger(__name__)
sglang/srt/models/mixtral.py CHANGED
@@ -36,7 +36,6 @@ from sglang.srt.layers.linear import (
  RowParallelLinear,
  )
  from sglang.srt.layers.logits_processor import LogitsProcessor
- from sglang.srt.layers.moe.ep_moe.layer import EPMoE
  from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
  from sglang.srt.layers.moe.topk import TopK
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -94,8 +93,7 @@ class MixtralMoE(nn.Module):
  renormalize=True,
  )

- MoEImpl = EPMoE if get_moe_expert_parallel_world_size() > 1 else FusedMoE
- self.experts = MoEImpl(
+ self.experts = FusedMoE(
  num_experts=num_experts,
  top_k=top_k,
  layer_id=layer_id,
sglang/srt/models/mllama4.py CHANGED
@@ -2,6 +2,7 @@ import json as json_lib
  import logging
  import math
  import os
+ import re
  from collections.abc import Iterable
  from typing import List, Optional, Set, Tuple

@@ -291,7 +292,7 @@ class Llama4UnfoldConvolution(nn.Module):

  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  hidden_states = self.unfold(hidden_states)
- hidden_states = hidden_states.permute(0, 2, 1)
+ hidden_states = hidden_states.permute(0, 2, 1).contiguous()
  hidden_states, _ = self.linear(hidden_states)
  return hidden_states

@@ -422,6 +423,11 @@ class Llama4ForConditionalGeneration(nn.Module):
  "gate_up_proj": ["gate_proj", "up_proj"],
  }

+ # Pattern to match language model layers only (skip vision_model and multi_modal_projector)
+ lora_pattern = re.compile(
+ r"^language_model\.model\.layers\.(\d+)\.(?:self_attn|mlp)\.(?:qkv_proj|o_proj|down_proj|gate_up_proj)"
+ )
+
  def __init__(
  self,
  config: Llama4Config,
@@ -446,9 +452,20 @@
  )

  if self.has_vision:
+ # TODO: make this more general
+ ignore_quant_layers = getattr(config, "quantization_config", {}).get(
+ "ignore", {}
+ )
+ if (
+ "model.layers.vision_model*" in ignore_quant_layers
+ and "model.layers.multi_modal_projector*" in ignore_quant_layers
+ ):
+ vision_quant_config = None
+ else:
+ vision_quant_config = quant_config
  self.vision_model = Llama4VisionModel(
  config.vision_config,
- quant_config=quant_config,
+ quant_config=vision_quant_config,
  prefix=add_prefix("vision_model", prefix),
  )

@@ -544,6 +561,10 @@ class Llama4ForConditionalGeneration(nn.Module):

  return projected_vision_flat

+ def should_apply_lora(self, module_name: str) -> bool:
+ """Skip vision model and multi_modal_projector for LoRA."""
+ return bool(self.lora_pattern.match(module_name))
+
  def forward(
  self,
  input_ids: torch.Tensor,
@@ -560,7 +581,7 @@
  forward_batch=forward_batch,
  language_model=self.language_model,
  data_embedding_funcs={
- Modality.IMAGE: self.get_image_feature,
+ Modality.IMAGE: image_embedding_func,
  },
  positions=positions,
  )
@@ -689,7 +710,7 @@
  """Handle scale parameter remapping. Returns True if handled."""
  if "scale" in name and "expert" not in name:
  remapped_name = maybe_remap_kv_scale_name(name, params_dict)
- return remapped_name is None
+ return remapped_name != name
  return False

  def _handle_stacked_params(
@@ -961,5 +982,30 @@
  def set_embed(self, embed):
  return self.language_model.set_embed(embed)

+ def get_hidden_dim(self, module_name, layer_idx):
+ # return input_dim, output_dim
+ if module_name == "qkv_proj":
+ return (
+ self.config.hidden_size,
+ self.config.head_dim
+ * (
+ self.config.num_attention_heads
+ + self.config.num_key_value_heads * 2
+ ),
+ )
+ elif module_name == "o_proj":
+ return (
+ self.config.head_dim * self.config.num_attention_heads,
+ self.config.hidden_size,
+ )
+ elif module_name == "gate_up_proj":
+ return self.config.hidden_size, self.config.intermediate_size * 2
+ elif module_name == "down_proj":
+ decoder_layer = self.language_model.get_layers()[layer_idx]
+ intermediate_size = decoder_layer.get_intermediate_size()
+ return intermediate_size, self.config.hidden_size
+ else:
+ raise NotImplementedError()
+

  EntryClass = Llama4ForConditionalGeneration