sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (396)
  1. sglang/bench_one_batch.py +7 -11
  2. sglang/bench_one_batch_server.py +330 -31
  3. sglang/bench_serving.py +474 -142
  4. sglang/compile_deep_gemm.py +3 -0
  5. sglang/global_config.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +10 -0
  11. sglang/srt/configs/device_config.py +3 -1
  12. sglang/srt/configs/dots_ocr.py +64 -0
  13. sglang/srt/configs/dots_vlm.py +139 -0
  14. sglang/srt/configs/falcon_h1.py +314 -0
  15. sglang/srt/configs/load_config.py +9 -0
  16. sglang/srt/configs/mamba_utils.py +117 -0
  17. sglang/srt/configs/model_config.py +228 -92
  18. sglang/srt/configs/nemotron_h.py +286 -0
  19. sglang/srt/configs/qwen3_next.py +294 -0
  20. sglang/srt/configs/qwen3_vl.py +586 -0
  21. sglang/srt/connector/__init__.py +8 -1
  22. sglang/srt/connector/remote_instance.py +82 -0
  23. sglang/srt/constrained/base_grammar_backend.py +49 -12
  24. sglang/srt/constrained/llguidance_backend.py +0 -1
  25. sglang/srt/constrained/outlines_backend.py +0 -1
  26. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  27. sglang/srt/constrained/xgrammar_backend.py +30 -9
  28. sglang/srt/custom_op.py +11 -1
  29. sglang/srt/debug_utils/dump_comparator.py +81 -44
  30. sglang/srt/debug_utils/dump_loader.py +97 -0
  31. sglang/srt/debug_utils/dumper.py +21 -6
  32. sglang/srt/debug_utils/text_comparator.py +73 -11
  33. sglang/srt/disaggregation/ascend/conn.py +2 -2
  34. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  35. sglang/srt/disaggregation/base/conn.py +1 -1
  36. sglang/srt/disaggregation/common/conn.py +279 -108
  37. sglang/srt/disaggregation/decode.py +78 -37
  38. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  39. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
  40. sglang/srt/disaggregation/fake/conn.py +1 -1
  41. sglang/srt/disaggregation/mini_lb.py +6 -445
  42. sglang/srt/disaggregation/mooncake/conn.py +55 -537
  43. sglang/srt/disaggregation/nixl/conn.py +373 -68
  44. sglang/srt/disaggregation/prefill.py +53 -49
  45. sglang/srt/disaggregation/utils.py +40 -54
  46. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  47. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  48. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  49. sglang/srt/distributed/parallel_state.py +156 -80
  50. sglang/srt/entrypoints/engine.py +59 -18
  51. sglang/srt/entrypoints/grpc_request_manager.py +842 -0
  52. sglang/srt/entrypoints/grpc_server.py +950 -0
  53. sglang/srt/entrypoints/http_server.py +179 -60
  54. sglang/srt/entrypoints/openai/protocol.py +265 -29
  55. sglang/srt/entrypoints/openai/serving_base.py +65 -3
  56. sglang/srt/entrypoints/openai/serving_chat.py +213 -122
  57. sglang/srt/entrypoints/openai/serving_completions.py +14 -3
  58. sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
  59. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  60. sglang/srt/entrypoints/openai/serving_responses.py +48 -3
  61. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  62. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  63. sglang/srt/environ.py +289 -0
  64. sglang/srt/eplb/eplb_manager.py +2 -2
  65. sglang/srt/eplb/expert_distribution.py +26 -13
  66. sglang/srt/eplb/expert_location.py +38 -8
  67. sglang/srt/eplb/expert_location_updater.py +1 -1
  68. sglang/srt/function_call/base_format_detector.py +3 -6
  69. sglang/srt/function_call/ebnf_composer.py +11 -9
  70. sglang/srt/function_call/function_call_parser.py +17 -8
  71. sglang/srt/function_call/glm4_moe_detector.py +4 -4
  72. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  73. sglang/srt/function_call/json_array_parser.py +63 -0
  74. sglang/srt/function_call/kimik2_detector.py +17 -4
  75. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  76. sglang/srt/function_call/utils.py +96 -5
  77. sglang/srt/grpc/__init__.py +1 -0
  78. sglang/srt/grpc/compile_proto.py +245 -0
  79. sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
  80. sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
  81. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
  82. sglang/srt/layers/activation.py +143 -9
  83. sglang/srt/layers/attention/aiter_backend.py +14 -15
  84. sglang/srt/layers/attention/ascend_backend.py +115 -9
  85. sglang/srt/layers/attention/attention_registry.py +215 -0
  86. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  87. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  88. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  89. sglang/srt/layers/attention/fla/chunk.py +242 -0
  90. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  91. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  92. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  93. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  94. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  95. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  96. sglang/srt/layers/attention/fla/index.py +37 -0
  97. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  98. sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
  99. sglang/srt/layers/attention/fla/op.py +66 -0
  100. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  101. sglang/srt/layers/attention/fla/utils.py +331 -0
  102. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  103. sglang/srt/layers/attention/flashattention_backend.py +40 -8
  104. sglang/srt/layers/attention/flashinfer_backend.py +341 -204
  105. sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
  106. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  107. sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
  108. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
  109. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  110. sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
  111. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
  112. sglang/srt/layers/attention/mamba/mamba.py +577 -0
  113. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  114. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  115. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  116. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  117. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  118. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
  119. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
  120. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
  121. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  122. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
  123. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  124. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  125. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  126. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  127. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  128. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  129. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  130. sglang/srt/layers/attention/nsa/utils.py +24 -0
  131. sglang/srt/layers/attention/nsa_backend.py +887 -0
  132. sglang/srt/layers/attention/tbo_backend.py +6 -6
  133. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  134. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  135. sglang/srt/layers/attention/triton_backend.py +57 -7
  136. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  137. sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
  138. sglang/srt/layers/attention/vision.py +58 -0
  139. sglang/srt/layers/attention/wave_backend.py +4 -4
  140. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  141. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  142. sglang/srt/layers/communicator.py +8 -0
  143. sglang/srt/layers/dp_attention.py +41 -2
  144. sglang/srt/layers/elementwise.py +3 -1
  145. sglang/srt/layers/layernorm.py +34 -15
  146. sglang/srt/layers/linear.py +55 -7
  147. sglang/srt/layers/logits_processor.py +180 -18
  148. sglang/srt/layers/modelopt_utils.py +11 -0
  149. sglang/srt/layers/moe/__init__.py +2 -1
  150. sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
  151. sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
  152. sglang/srt/layers/moe/ep_moe/layer.py +248 -333
  153. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
  154. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  155. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  159. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  160. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  161. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  162. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  163. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  164. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  165. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  166. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  167. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  168. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  169. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
  170. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  171. sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
  172. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  173. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  174. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  175. sglang/srt/layers/moe/moe_runner/runner.py +83 -0
  176. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  177. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  178. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  179. sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
  180. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  181. sglang/srt/layers/moe/topk.py +30 -9
  182. sglang/srt/layers/moe/utils.py +29 -7
  183. sglang/srt/layers/parameter.py +23 -6
  184. sglang/srt/layers/quantization/__init__.py +1 -1
  185. sglang/srt/layers/quantization/awq.py +19 -7
  186. sglang/srt/layers/quantization/base_config.py +11 -6
  187. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  188. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  189. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  190. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  191. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  192. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  193. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  194. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  195. sglang/srt/layers/quantization/fp8.py +155 -60
  196. sglang/srt/layers/quantization/fp8_utils.py +51 -32
  197. sglang/srt/layers/quantization/gptq.py +25 -17
  198. sglang/srt/layers/quantization/modelopt_quant.py +191 -56
  199. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  200. sglang/srt/layers/quantization/mxfp4.py +74 -42
  201. sglang/srt/layers/quantization/quark/quark.py +3 -1
  202. sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
  203. sglang/srt/layers/quantization/unquant.py +135 -47
  204. sglang/srt/layers/quantization/w4afp8.py +28 -33
  205. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  206. sglang/srt/layers/quantization/w8a8_int8.py +91 -41
  207. sglang/srt/layers/rotary_embedding.py +78 -31
  208. sglang/srt/layers/sampler.py +213 -21
  209. sglang/srt/layers/utils.py +23 -0
  210. sglang/srt/lora/backend/base_backend.py +50 -8
  211. sglang/srt/lora/backend/chunked_backend.py +348 -0
  212. sglang/srt/lora/backend/triton_backend.py +99 -5
  213. sglang/srt/lora/layers.py +32 -0
  214. sglang/srt/lora/lora.py +8 -3
  215. sglang/srt/lora/lora_manager.py +44 -118
  216. sglang/srt/lora/mem_pool.py +25 -11
  217. sglang/srt/lora/triton_ops/__init__.py +4 -0
  218. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  219. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  220. sglang/srt/lora/utils.py +22 -11
  221. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  222. sglang/srt/managers/cache_controller.py +199 -301
  223. sglang/srt/managers/data_parallel_controller.py +115 -80
  224. sglang/srt/managers/detokenizer_manager.py +19 -15
  225. sglang/srt/managers/disagg_service.py +46 -0
  226. sglang/srt/managers/io_struct.py +340 -109
  227. sglang/srt/managers/mm_utils.py +44 -6
  228. sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
  229. sglang/srt/managers/multimodal_processor.py +1 -2
  230. sglang/srt/managers/overlap_utils.py +55 -0
  231. sglang/srt/managers/schedule_batch.py +343 -212
  232. sglang/srt/managers/schedule_policy.py +145 -18
  233. sglang/srt/managers/scheduler.py +653 -273
  234. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  235. sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
  236. sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
  237. sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
  238. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  239. sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
  240. sglang/srt/managers/tokenizer_manager.py +579 -674
  241. sglang/srt/managers/tp_worker.py +96 -26
  242. sglang/srt/managers/utils.py +1 -45
  243. sglang/srt/mem_cache/allocator.py +21 -22
  244. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  245. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  246. sglang/srt/mem_cache/chunk_cache.py +9 -2
  247. sglang/srt/mem_cache/evict_policy.py +23 -0
  248. sglang/srt/mem_cache/hicache_storage.py +43 -24
  249. sglang/srt/mem_cache/hiradix_cache.py +222 -75
  250. sglang/srt/mem_cache/memory_pool.py +651 -80
  251. sglang/srt/mem_cache/memory_pool_host.py +239 -228
  252. sglang/srt/mem_cache/radix_cache.py +227 -73
  253. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  254. sglang/srt/mem_cache/storage/__init__.py +10 -0
  255. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  256. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  257. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  258. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  259. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  260. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  261. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  262. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
  263. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
  264. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  265. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
  266. sglang/srt/mem_cache/swa_radix_cache.py +93 -48
  267. sglang/srt/metrics/collector.py +511 -132
  268. sglang/srt/metrics/func_timer.py +2 -7
  269. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  270. sglang/srt/metrics/utils.py +8 -1
  271. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  272. sglang/srt/model_executor/cuda_graph_runner.py +52 -37
  273. sglang/srt/model_executor/forward_batch_info.py +74 -46
  274. sglang/srt/model_executor/model_runner.py +455 -176
  275. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  276. sglang/srt/model_loader/__init__.py +10 -4
  277. sglang/srt/model_loader/loader.py +319 -10
  278. sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
  279. sglang/srt/model_loader/weight_utils.py +161 -3
  280. sglang/srt/models/apertus.py +686 -0
  281. sglang/srt/models/bailing_moe.py +820 -217
  282. sglang/srt/models/bailing_moe_nextn.py +168 -0
  283. sglang/srt/models/deepseek_nextn.py +6 -1
  284. sglang/srt/models/deepseek_v2.py +607 -130
  285. sglang/srt/models/dots_ocr.py +173 -0
  286. sglang/srt/models/dots_vlm.py +174 -0
  287. sglang/srt/models/dots_vlm_vit.py +337 -0
  288. sglang/srt/models/ernie4.py +1 -1
  289. sglang/srt/models/falcon_h1.py +578 -0
  290. sglang/srt/models/gemma3_causal.py +0 -2
  291. sglang/srt/models/gemma3_mm.py +17 -1
  292. sglang/srt/models/gemma3n_mm.py +2 -2
  293. sglang/srt/models/glm4_moe.py +4 -4
  294. sglang/srt/models/glm4_moe_nextn.py +2 -2
  295. sglang/srt/models/glm4v.py +5 -3
  296. sglang/srt/models/glm4v_moe.py +4 -1
  297. sglang/srt/models/gpt_oss.py +8 -31
  298. sglang/srt/models/grok.py +5 -13
  299. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  300. sglang/srt/models/llama.py +4 -0
  301. sglang/srt/models/llama4.py +9 -0
  302. sglang/srt/models/llama_eagle3.py +13 -0
  303. sglang/srt/models/longcat_flash.py +3 -3
  304. sglang/srt/models/longcat_flash_nextn.py +1 -1
  305. sglang/srt/models/mixtral.py +1 -3
  306. sglang/srt/models/mllama4.py +50 -4
  307. sglang/srt/models/nemotron_h.py +514 -0
  308. sglang/srt/models/opt.py +637 -0
  309. sglang/srt/models/qwen2_5_vl.py +29 -5
  310. sglang/srt/models/qwen2_audio.py +1 -1
  311. sglang/srt/models/qwen2_moe.py +120 -13
  312. sglang/srt/models/qwen2_vl.py +1 -1
  313. sglang/srt/models/qwen3.py +18 -3
  314. sglang/srt/models/qwen3_moe.py +32 -4
  315. sglang/srt/models/qwen3_next.py +1069 -0
  316. sglang/srt/models/qwen3_next_mtp.py +112 -0
  317. sglang/srt/models/qwen3_vl.py +787 -0
  318. sglang/srt/models/qwen3_vl_moe.py +471 -0
  319. sglang/srt/models/registry.py +15 -3
  320. sglang/srt/models/sarashina2_vision.py +269 -0
  321. sglang/srt/models/solar.py +505 -0
  322. sglang/srt/models/starcoder2.py +357 -0
  323. sglang/srt/models/step3_vl.py +1 -1
  324. sglang/srt/models/torch_native_llama.py +9 -2
  325. sglang/srt/models/utils.py +55 -0
  326. sglang/srt/multimodal/processors/base_processor.py +15 -7
  327. sglang/srt/multimodal/processors/dots_vlm.py +98 -0
  328. sglang/srt/multimodal/processors/glm4v.py +9 -9
  329. sglang/srt/multimodal/processors/internvl.py +153 -129
  330. sglang/srt/multimodal/processors/qwen_vl.py +23 -6
  331. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  332. sglang/srt/offloader.py +27 -3
  333. sglang/srt/parser/jinja_template_utils.py +6 -0
  334. sglang/srt/sampling/sampling_batch_info.py +49 -26
  335. sglang/srt/sampling/sampling_params.py +7 -0
  336. sglang/srt/server_args.py +1051 -285
  337. sglang/srt/server_args_config_parser.py +146 -0
  338. sglang/srt/single_batch_overlap.py +151 -0
  339. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  340. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  341. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  342. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  343. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  344. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  345. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
  346. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
  347. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
  348. sglang/srt/speculative/eagle_worker.py +98 -29
  349. sglang/srt/speculative/ngram_info.py +428 -0
  350. sglang/srt/speculative/ngram_worker.py +246 -0
  351. sglang/srt/speculative/spec_info.py +52 -0
  352. sglang/srt/speculative/spec_utils.py +605 -0
  353. sglang/srt/speculative/standalone_worker.py +109 -0
  354. sglang/srt/torch_memory_saver_adapter.py +5 -7
  355. sglang/srt/tracing/trace.py +578 -0
  356. sglang/srt/two_batch_overlap.py +9 -5
  357. sglang/srt/utils/__init__.py +2 -0
  358. sglang/srt/{utils.py → utils/common.py} +451 -77
  359. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
  360. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  361. sglang/srt/utils/rpd_utils.py +452 -0
  362. sglang/srt/utils/slow_rank_detector.py +71 -0
  363. sglang/srt/warmup.py +8 -4
  364. sglang/srt/weight_sync/utils.py +2 -2
  365. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  366. sglang/test/get_logits_ut.py +57 -0
  367. sglang/test/longbench_v2/__init__.py +1 -0
  368. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  369. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  370. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  371. sglang/test/run_eval.py +119 -11
  372. sglang/test/runners.py +5 -1
  373. sglang/test/simple_eval_common.py +5 -2
  374. sglang/test/simple_eval_longbench_v2.py +332 -0
  375. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  376. sglang/test/test_block_fp8.py +2 -2
  377. sglang/test/test_cutlass_moe.py +24 -6
  378. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  379. sglang/test/test_deterministic.py +313 -0
  380. sglang/test/test_deterministic_utils.py +81 -0
  381. sglang/test/test_disaggregation_utils.py +140 -0
  382. sglang/test/test_fp4_moe.py +370 -1
  383. sglang/test/test_programs.py +1 -1
  384. sglang/test/test_utils.py +407 -8
  385. sglang/utils.py +21 -1
  386. sglang/version.py +1 -1
  387. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
  388. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
  389. sglang/srt/disaggregation/launch_lb.py +0 -118
  390. sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
  391. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  392. sglang/test/test_block_fp8_ep.py +0 -358
  393. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  394. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  395. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  396. {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/configs/nemotron_h.py (new file)
@@ -0,0 +1,286 @@
+ # Copyright 2025 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/nemotron_h.py
+
+ """NemotronH model configuration"""
+
+ import regex as re
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+
+ from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
+ from sglang.srt.layers.dp_attention import get_attention_tp_size
+
+ logger = logging.get_logger(__name__)
+
+ MAMBA = "M"
+ ATTENTION = "*"
+ MLP = "-"
+
+
+ class NemotronHConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a
+     [`NemotronHModel`]. It is used to instantiate a NemotronH model according
+     to the specified arguments, defining the model architecture. Instantiating
+     a configuration with the defaults will yield a similar configuration to
+     that of the NemotronH-v0.1 model.
+     Args:
+         vocab_size (`int`, *optional*, defaults to 131072):
+             Vocabulary size of the NemotronH model. Defines the number of
+             different tokens that can be represented by the `inputs_ids`
+             passed when calling [`NemotronHModel`]
+         tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+             Whether the model's input and output word embeddings should be
+             tied. Note that this is only relevant if the model has an output
+             word embedding layer.
+         hidden_size (`int`, *optional*, defaults to 4096):
+             Dimension of the hidden representations.
+         intermediate_size (`int`, *optional*, defaults to 21504):
+             Dimension of the MLP representations.
+         num_hidden_layers (`int`, *optional*, defaults to 52):
+             Number of hidden layers in the Transformer encoder.
+         hybrid_override_pattern (`str`, *optional*, defaults to
+             `"M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-"`):
+             The pattern of the hybrid model. The pattern is a string of
+             characters where each character represents
+             M: Mamba2, *: Attention, -: MLP
+         num_attention_heads (`int`, *optional*, defaults to 32):
+             Number of attention heads for each attention layer in the
+             Transformer encoder.
+         attention_head_dim (`int`, *optional*, defaults to 128):
+             Dimension of each attention head.
+         num_key_value_heads (`int`, *optional*, defaults to 8):
+             This is the number of key_value heads that should be used to
+             implement Grouped Query Attention. If
+             `num_key_value_heads=num_attention_heads`, the model will use
+             Multi Head Attention (MHA), if `num_key_value_heads=1` the model
+             will use Multi Query Attention (MQA) otherwise GQA is used.
+         mlp_hidden_act (`str`, *optional*, defaults to "relu2"):
+             The non-linear activation function in the MLP layers.
+         attention_bias (`bool`, *optional*, defaults to `False`):
+             Whether to use bias in attention layers.
+         mlp_bias (`bool`, *optional*, defaults to `False`):
+             Whether to use bias in MLP layers.
+         use_bias (`bool`, *optional*, defaults to `False`):
+             Whether to use bias in the model.
+         initializer_range (`float`, *optional*, defaults to 0.02):
+             The standard deviation of the truncated_normal_initializer for
+             initializing all weight matrices.
+         layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
+             The epsilon used by the layer normalization layers.
+         residual_in_fp32 (`bool`, *optional*, defaults to `False`):
+             Whether or not residuals should be in `float32`. If set to `False`
+             residuals will keep the same `dtype` as the rest of the model.
+         use_cache (`bool`, *optional*, defaults to `True`):
+             Whether or not the model should return the last key/values
+             attentions (not used by all models). Only relevant if
+             `config.is_decoder=True`.
+         num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
+             Number of prompt logits to calculate during generation. If `None`,
+             all logits will be calculated. If an integer value, only last
+             `num_logits_to_keep` logits will be calculated.
+         pad_token_id (`int`, *optional*, defaults to 0):
+             The id of the padding token.
+         bos_token_id (`int`, *optional*, defaults to 1):
+             The id of the "beginning-of-sequence" token.
+         eos_token_id (`int`, *optional*, defaults to 2):
+             The id of the "end-of-sequence" token.
+         sliding_window (`int`, *optional*, defaults to None):
+             Sliding window attention window size.
+         max_position_embeddings (`int`, *optional*, defaults to 4096):
+             The maximum sequence length that this model might ever be used
+             with.
+         attention_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout ratio for the attention probabilities.
+         hidden_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout ratio for the hidden states.
+         use_mamba_kernels (`bool`, *optional*, defaults to `True`):
+             Flag indicating whether or not to use the fast mamba kernels.
+             These are available only if `mamba-ssm` and `causal-conv1d`
+             are installed, and the mamba modules are running on a CUDA device.
+         ssm_state_size (`int`, *optional*, defaults to 128):
+             The dimension of the mamba state space latents.
+         mamba_num_heads (`int`, *optional*, defaults to 128):
+             Number of heads in Mamba layers.
+         mamba_n_groups (`int`, *optional*, defaults to 8):
+             Number of groups in Mamba layers.
+         mamba_head_dim (`int`, *optional*, defaults to 64):
+             Dimension of each Mamba head.
+         mamba_d_conv (`int`, *optional*, defaults to 4):
+             The size of the mamba convolution kernel.
+         mamba_expand (`int`, *optional*, defaults to 2):
+             Expanding factor used to determine the mamba intermediate size.
+         mamba_hidden_act (`str`, *optional*, defaults to "silu"):
+             The non-linear activation function in the Mamba layers.
+         mamba_dt_min (`float`, *optional*, defaults to 0.001):
+             Minimum value for the time step in Mamba.
+         mamba_dt_max (`float`, *optional*, defaults to 0.1):
+             Maximum value for the time step in Mamba.
+         mamba_dt_limit (`tuple`, *optional*, defaults to (0.0, float("inf"))):
+             Limits for the time step in Mamba.
+         mamba_dt_init_floor (`float`, *optional*, defaults to 1e-4):
+             Floor value for time step initialization in Mamba.
+         mamba_conv_bias (`bool`, *optional*, defaults to `True`):
+             Whether to use bias in the convolution layer of the mamba mixer
+             block.
+         mamba_proj_bias (`bool`, *optional*, defaults to `False`):
+             Whether to use bias in the input and output projections of the
+             mamba mixer block.
+         mamba_chunk_size (`int`, *optional*, defaults to 256):
+             Size of chunks for Mamba processing.
+         rescale_prenorm_residual (`bool`, *optional*, defaults to `True`):
+             Whether to rescale the pre-normalization residual connections.
+     """
+
+     model_type = "nemotron_h"
+     keys_to_ignore_at_inference = ["past_key_values"]
+
+     def __init__(
+         self,
+         vocab_size=131072,
+         tie_word_embeddings=False,
+         hidden_size=4096,
+         intermediate_size=21504,
+         num_hidden_layers=52,
+         hybrid_override_pattern="M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-",
+         num_attention_heads=32,
+         head_dim=128,
+         num_key_value_heads=8,  # nemo: num_query_groups
+         mlp_hidden_act="relu2",
+         attention_bias=False,
+         mlp_bias=False,
+         use_bias=False,
+         initializer_range=0.02,  # nemo: init_method_std
+         layer_norm_epsilon=1e-5,  # nemo: layernorm_epsilon
+         residual_in_fp32=False,  # Megatron Core default value
+         use_cache=True,
+         num_logits_to_keep=1,
+         pad_token_id=0,
+         bos_token_id=1,
+         eos_token_id=2,
+         sliding_window=None,
+         max_position_embeddings=4096,
+         attention_dropout=0.0,
+         hidden_dropout=0.0,  # * ADDED
+         use_mamba_kernels=True,
+         ssm_state_size=128,  # mamba_state_size
+         mamba_num_heads=128,
+         mamba_n_groups=8,  # nemo: mamba_ssm_ngroups = num_heads
+         mamba_head_dim=64,
+         mamba_d_conv=4,
+         mamba_expand=2,
+         mamba_hidden_act="silu",
+         mamba_dt_min=0.001,
+         mamba_dt_max=0.1,
+         mamba_dt_limit=(0.0, float("inf")),
+         mamba_dt_init_floor=1e-4,
+         mamba_conv_bias=True,
+         mamba_proj_bias=False,
+         mamba_chunk_size=256,
+         rescale_prenorm_residual=True,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         self.tie_word_embeddings = tie_word_embeddings
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.hybrid_override_pattern = hybrid_override_pattern
+         self.num_attention_heads = num_attention_heads
+         self.head_dim = head_dim
+         self.sliding_window = sliding_window
+         self.max_position_embeddings = max_position_embeddings
+         self.attention_dropout = attention_dropout
+         self.hidden_dropout = hidden_dropout
+
+         # Validate hybrid_override_pattern
+         # M: Mamba2, *: Attention, -: MLP
+         assert len(self.hybrid_override_pattern) == self.num_hidden_layers, (
+             "hybrid_override_pattern must have same length as " "num_hidden_layers"
+         )
+         assert re.match(r"^[*-M]+$", self.hybrid_override_pattern), (
+             "hybrid_override_pattern must only contain characters " "'M', '*', or '-'"
+         )
+
+         # for backward compatibility
+         if num_key_value_heads is None:
+             num_key_value_heads = num_attention_heads
+
+         self.num_key_value_heads = num_key_value_heads
+         self.mlp_hidden_act = mlp_hidden_act
+         self.attention_bias = attention_bias
+         self.mlp_bias = mlp_bias
+         self.use_bias = use_bias
+         self.initializer_range = initializer_range
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.residual_in_fp32 = residual_in_fp32
+
+         self.use_cache = use_cache
+         self.num_logits_to_keep = num_logits_to_keep
+
+         self.use_mamba_kernels = use_mamba_kernels
+         self.mamba_n_groups = mamba_n_groups
+         self.mamba_head_dim = mamba_head_dim
+         self.ssm_state_size = ssm_state_size
+         self.mamba_num_heads = mamba_num_heads
+         self.conv_kernel = mamba_d_conv
+         self.expand = mamba_expand
+         self.mamba_hidden_act = mamba_hidden_act
+         self.time_step_min = mamba_dt_min
+         self.time_step_max = mamba_dt_max
+         self.time_step_limit = mamba_dt_limit
+         self.time_step_floor = mamba_dt_init_floor
+         self.use_conv_bias = mamba_conv_bias
+         self.mamba_proj_bias = mamba_proj_bias
+         self.mamba_chunk_size = mamba_chunk_size
+         self.rescale_prenorm_residual = rescale_prenorm_residual
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
+
+     @property
+     def mamba_layer_ids(self):
+         return [
+             i
+             for i in range(self.num_hidden_layers)
+             if self.hybrid_override_pattern[i] == MAMBA
+         ]
+
+     @property
+     def full_attention_layer_ids(self):
+         return [
+             i
+             for i in range(self.num_hidden_layers)
+             if self.hybrid_override_pattern[i] == ATTENTION
+         ]
+
+     @property
+     def mamba2_cache_params(self) -> Mamba2CacheParams:
+         shape = Mamba2StateShape.create(
+             tp_world_size=get_attention_tp_size(),
+             intermediate_size=self.mamba_num_heads * self.mamba_head_dim,
+             n_groups=self.n_groups,
+             num_heads=self.mamba_num_heads,
+             head_dim=self.mamba_head_dim,
+             state_size=self.ssm_state_size,
+             conv_kernel=self.conv_kernel,
+         )
+
+         return Mamba2CacheParams(shape=shape, layers=self.mamba_layer_ids)
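The hybrid layout logic in `NemotronHConfig` reduces to a per-character lookup on `hybrid_override_pattern`, as the `mamba_layer_ids` and `full_attention_layer_ids` properties above show. A minimal standalone sketch (plain Python, no sglang import) that mirrors that lookup on the class's default pattern:

```python
# Mirror of the mamba_layer_ids / full_attention_layer_ids properties above,
# applied to the default 52-character pattern (M: Mamba2, *: attention, -: MLP).
MAMBA, ATTENTION, MLP = "M", "*", "-"
pattern = "M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-"

mamba_layer_ids = [i for i, c in enumerate(pattern) if c == MAMBA]
full_attention_layer_ids = [i for i, c in enumerate(pattern) if c == ATTENTION]
mlp_layer_ids = [i for i, c in enumerate(pattern) if c == MLP]

# 52 layers total with the default pattern: 24 Mamba2, 4 full-attention, 24 MLP.
print(len(pattern), len(mamba_layer_ids), len(full_attention_layer_ids), len(mlp_layer_ids))
```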
sglang/srt/configs/qwen3_next.py (new file)
@@ -0,0 +1,294 @@
+ # coding=utf-8
+ # Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Qwen3Hybrid model configuration"""
+
+ import enum
+
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.modeling_rope_utils import rope_config_validation
+ from transformers.utils import logging
+
+ from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
+ from sglang.srt.distributed.utils import divide
+ from sglang.srt.layers.dp_attention import get_attention_tp_size
+
+ logger = logging.get_logger(__name__)
+
+
+ # NOTE: HybridLayerType
+ class HybridLayerType(enum.Enum):
+     full_attention = "attention"
+     swa_attention = "swa_attention"
+     linear_attention = "linear_attention"
+     mamba2 = "mamba"
+
+
+ class Qwen3NextConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`Qwen3NextModel`]. It is used to instantiate a
+     Qwen3-Next model according to the specified arguments, defining the model architecture.
+     Instantiating a configuration with the defaults will yield a similar configuration to that of
+     Qwen3-Next-80B-A3B-Instruct [Qwen/Qwen3-Next-80B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct).
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+
+     Args:
+         vocab_size (`int`, *optional*, defaults to 151936):
+             Vocabulary size of the model. Defines the number of different tokens that can be represented by the
+             `inputs_ids`.
+         hidden_size (`int`, *optional*, defaults to 2048):
+             Dimension of the hidden representations.
+         intermediate_size (`int`, *optional*, defaults to 5632):
+             Dimension of the MLP representations.
+         num_hidden_layers (`int`, *optional*, defaults to 48):
+             Number of hidden layers in the Transformer encoder.
+         num_attention_heads (`int`, *optional*, defaults to 16):
+             Number of attention heads for each attention layer in the Transformer encoder.
+         num_key_value_heads (`int`, *optional*, defaults to 2):
+             This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+             `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+             `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+             converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+             by meanpooling all the original heads within that group. For more details checkout [this
+             paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
+         hidden_act (`str`, *optional*, defaults to `"silu"`):
+             The non-linear activation function in the decoder.
+         max_position_embeddings (`int`, *optional*, defaults to 32768):
+             The maximum sequence length that this model might ever be used with.
+         initializer_range (`float`, *optional*, defaults to 0.02):
+             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+         rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+             The epsilon used by the rms normalization layers.
+         use_cache (`bool`, *optional*, defaults to `True`):
+             Whether or not the model should return the last key/values attentions (not used by all models). Only
+             relevant if `config.is_decoder=True`.
+         tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+             Whether the model's input and output word embeddings should be tied.
+         rope_theta (`float`, *optional*, defaults to 10000.0):
+             The base period of the RoPE embeddings.
+         rope_scaling (`Dict`, *optional*):
+             Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+             and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+             accordingly.
+             Expected contents:
+                 `rope_type` (`str`):
+                     The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                     'llama3'], with 'default' being the original RoPE implementation.
+                 `factor` (`float`, *optional*):
+                     Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                     most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                     original maximum pre-trained length.
+                 `original_max_position_embeddings` (`int`, *optional*):
+                     Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                     pretraining.
+                 `attention_factor` (`float`, *optional*):
+                     Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                     computation. If unspecified, it defaults to value recommended by the implementation, using the
+                     `factor` field to infer the suggested value.
+                 `beta_fast` (`float`, *optional*):
+                     Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                     ramp function. If unspecified, it defaults to 32.
+                 `beta_slow` (`float`, *optional*):
+                     Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                     ramp function. If unspecified, it defaults to 1.
+                 `short_factor` (`List[float]`, *optional*):
+                     Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                     `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                     size divided by the number of attention heads divided by 2
+                 `long_factor` (`List[float]`, *optional*):
+                     Only used with 'longrope'. The scaling factor to be applied to long contexts (<
+                     `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                     size divided by the number of attention heads divided by 2
+                 `low_freq_factor` (`float`, *optional*):
+                     Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+                 `high_freq_factor` (`float`, *optional*):
+                     Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+         partial_rotary_factor (`float`, *optional*, defaults to 0.25):
+             Percentage of the query and keys which will have rotary embedding.
+         attention_bias (`bool`, *optional*, defaults to `False`):
+             Whether to use a bias in the query, key, value and output projection layers during self-attention.
+         attention_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout ratio for the attention probabilities.
+         head_dim (`int`, *optional*, defaults to 256):
+             Projection weights dimension in multi-head attention.
+         linear_conv_kernel_dim (`int`, *optional*, defaults to 4):
+             Kernel size of the convolution used in linear attention layers.
+         linear_key_head_dim (`int`, *optional*, defaults to 128):
+             Dimension of each key head in linear attention.
+         linear_value_head_dim (`int`, *optional*, defaults to 128):
+             Dimension of each value head in linear attention.
+         linear_num_key_heads (`int`, *optional*, defaults to 16):
+             Number of key heads used in linear attention layers.
+         linear_num_value_heads (`int`, *optional*, defaults to 32):
+             Number of value heads used in linear attention layers.
+         decoder_sparse_step (`int`, *optional*, defaults to 1):
+             The frequency of the MoE layer.
+         moe_intermediate_size (`int`, *optional*, defaults to 512):
+             Intermediate size of the routed expert.
+         shared_expert_intermediate_size (`int`, *optional*, defaults to 512):
+             Intermediate size of the shared expert.
+         num_experts_per_tok (`int`, *optional*, defaults to 10):
+             Number of selected experts.
+         num_experts (`int`, *optional*, defaults to 512):
+             Number of routed experts.
+         norm_topk_prob (`bool`, *optional*, defaults to `True`):
+             Whether to normalize the topk probabilities.
+         output_router_logits (`bool`, *optional*, defaults to `False`):
+             Whether or not the router logits should be returned by the model. Enabling this will also
+             allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
+         router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
+             The aux loss factor for the total loss.
+         mlp_only_layers (`list[int]`, *optional*, defaults to `[]`):
+             Indicate which layers use Qwen3NextMLP rather than Qwen3NextSparseMoeBlock
+             The list contains layer index, from 0 to num_layers-1 if we have num_layers layers
+             If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
+         layer_types (`list[str]`, *optional*, defaults to None):
+             Types of each layer (attention or linear).
+
+     ```python
+     >>> from transformers import Qwen3NextModel, Qwen3NextConfig
+
+     >>> # Initializing a Qwen3Next style configuration
+     >>> configuration = Qwen3NextConfig()
+
+     >>> # Initializing a model from the Qwen3-Next-80B-A3B style configuration
+     >>> model = Qwen3NextModel(configuration)
+
+     >>> # Accessing the model configuration
+     >>> configuration = model.config
+     ```
+     """
+
+     model_type = "qwen3_next"
+     keys_to_ignore_at_inference = ["past_key_values"]
+
+     def __init__(
+         self,
+         vocab_size=151936,
+         hidden_size=2048,
+         intermediate_size=5632,
+         num_hidden_layers=48,
+         num_attention_heads=16,
+         num_key_value_heads=2,
+         hidden_act="silu",
+         max_position_embeddings=32768,
+         initializer_range=0.02,
+         rms_norm_eps=1e-6,
+         use_cache=True,
+         tie_word_embeddings=False,
+         rope_theta=10000.0,
+         rope_scaling=None,
+         partial_rotary_factor=0.25,
+         attention_bias=False,
+         attention_dropout=0.0,
+         head_dim=256,
+         linear_conv_kernel_dim=4,
+         linear_key_head_dim=128,
+         linear_value_head_dim=128,
+         linear_num_key_heads=16,
+         linear_num_value_heads=32,
+         decoder_sparse_step=1,
+         moe_intermediate_size=512,
+         shared_expert_intermediate_size=512,
+         num_experts_per_tok=10,
+         num_experts=512,
+         norm_topk_prob=True,
+         output_router_logits=False,
+         router_aux_loss_coef=0.001,
+         mlp_only_layers=[],
+         layer_types=None,
+         **kwargs,
+     ):
+         super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+         self.vocab_size = vocab_size
+         self.max_position_embeddings = max_position_embeddings
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.num_attention_heads = num_attention_heads
+         self.num_key_value_heads = num_key_value_heads
+         self.hidden_act = hidden_act
+         self.initializer_range = initializer_range
+         self.rms_norm_eps = rms_norm_eps
+         self.use_cache = use_cache
+         self.rope_theta = rope_theta
+         self.rope_scaling = rope_scaling
+         self.partial_rotary_factor = partial_rotary_factor
+         self.attention_bias = attention_bias
+         self.attention_dropout = attention_dropout
+         self.head_dim = head_dim
+         rope_config_validation(self)
+
+         # linear attention (gdn now part)
+         self.linear_conv_kernel_dim = linear_conv_kernel_dim
+         self.linear_key_head_dim = linear_key_head_dim
+         self.linear_value_head_dim = linear_value_head_dim
+         self.linear_num_key_heads = linear_num_key_heads
+         self.linear_num_value_heads = linear_num_value_heads
+
+         # MoE arguments
+         self.decoder_sparse_step = decoder_sparse_step
+         self.moe_intermediate_size = moe_intermediate_size
+         self.shared_expert_intermediate_size = shared_expert_intermediate_size
+         self.num_experts_per_tok = num_experts_per_tok
+         self.num_experts = num_experts
+         self.norm_topk_prob = norm_topk_prob
+         self.output_router_logits = output_router_logits
+         self.router_aux_loss_coef = router_aux_loss_coef
+         self.mlp_only_layers = mlp_only_layers
+
+     @property
+     def layers_block_type(self):
+         layer_type_list = []
+
+         for l in range(self.num_hidden_layers):
+             if (l + 1) % self.full_attention_interval == 0:
+                 layer_type_list.append(HybridLayerType.full_attention.value)
+             else:
+                 layer_type_list.append(HybridLayerType.linear_attention.value)
+
+         return layer_type_list
+
+     @property
+     def linear_layer_ids(self):
+         return [
+             i
+             for i, type_value in enumerate(self.layers_block_type)
+             if type_value == HybridLayerType.linear_attention.value
+         ]
+
+     @property
+     def full_attention_layer_ids(self):
+         return [
+             i
+             for i, type_value in enumerate(self.layers_block_type)
+             if type_value == HybridLayerType.full_attention.value
+         ]
+
+     @property
+     def mamba2_cache_params(self) -> Mamba2CacheParams:
+         shape = Mamba2StateShape.create(
+             tp_world_size=get_attention_tp_size(),
+             intermediate_size=self.linear_value_head_dim * self.linear_num_value_heads,
+             n_groups=self.linear_num_key_heads,
+             num_heads=self.linear_num_value_heads,
+             head_dim=self.linear_value_head_dim,
+             state_size=self.linear_key_head_dim,
+             conv_kernel=self.linear_conv_kernel_dim,
+         )
+
+         return Mamba2CacheParams(shape=shape, layers=self.linear_layer_ids)
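`Qwen3NextConfig.layers_block_type` derives the per-layer schedule from `full_attention_interval`, which is not set in the `__init__` shown above and is expected to arrive via `**kwargs` from the checkpoint's `config.json`. A minimal standalone sketch of that property's logic; the interval value of 4 used below is an assumption for illustration only:

```python
# Mirror of the layers_block_type / linear_layer_ids / full_attention_layer_ids
# properties above. full_attention_interval is an assumed illustrative value;
# in practice it comes from the checkpoint's config.json via **kwargs.
num_hidden_layers = 48       # class default shown above
full_attention_interval = 4  # assumed for illustration

layers_block_type = [
    "attention" if (i + 1) % full_attention_interval == 0 else "linear_attention"
    for i in range(num_hidden_layers)
]
linear_layer_ids = [i for i, t in enumerate(layers_block_type) if t == "linear_attention"]
full_attention_layer_ids = [i for i, t in enumerate(layers_block_type) if t == "attention"]

print(full_attention_layer_ids)  # [3, 7, 11, ..., 47] with an interval of 4
```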