sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (408)
  1. sglang/bench_one_batch.py +47 -28
  2. sglang/bench_one_batch_server.py +41 -25
  3. sglang/bench_serving.py +330 -156
  4. sglang/check_env.py +1 -1
  5. sglang/compile_deep_gemm.py +6 -2
  6. sglang/global_config.py +1 -25
  7. sglang/lang/api.py +6 -0
  8. sglang/lang/interpreter.py +1 -0
  9. sglang/lang/ir.py +13 -0
  10. sglang/launch_server.py +8 -15
  11. sglang/profiler.py +18 -1
  12. sglang/srt/_custom_ops.py +1 -1
  13. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
  14. sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
  15. sglang/srt/compilation/backend.py +437 -0
  16. sglang/srt/compilation/compilation_config.py +20 -0
  17. sglang/srt/compilation/compilation_counter.py +47 -0
  18. sglang/srt/compilation/compile.py +210 -0
  19. sglang/srt/compilation/compiler_interface.py +503 -0
  20. sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
  21. sglang/srt/compilation/fix_functionalization.py +134 -0
  22. sglang/srt/compilation/fx_utils.py +83 -0
  23. sglang/srt/compilation/inductor_pass.py +140 -0
  24. sglang/srt/compilation/pass_manager.py +66 -0
  25. sglang/srt/compilation/piecewise_context_manager.py +40 -0
  26. sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
  27. sglang/srt/configs/__init__.py +4 -0
  28. sglang/srt/configs/deepseek_ocr.py +262 -0
  29. sglang/srt/configs/deepseekvl2.py +194 -96
  30. sglang/srt/configs/dots_vlm.py +2 -7
  31. sglang/srt/configs/falcon_h1.py +13 -64
  32. sglang/srt/configs/load_config.py +25 -2
  33. sglang/srt/configs/mamba_utils.py +117 -0
  34. sglang/srt/configs/model_config.py +134 -23
  35. sglang/srt/configs/modelopt_config.py +30 -0
  36. sglang/srt/configs/nemotron_h.py +286 -0
  37. sglang/srt/configs/olmo3.py +105 -0
  38. sglang/srt/configs/points_v15_chat.py +29 -0
  39. sglang/srt/configs/qwen3_next.py +11 -47
  40. sglang/srt/configs/qwen3_omni.py +613 -0
  41. sglang/srt/configs/qwen3_vl.py +0 -10
  42. sglang/srt/connector/remote_instance.py +1 -1
  43. sglang/srt/constrained/base_grammar_backend.py +5 -1
  44. sglang/srt/constrained/llguidance_backend.py +5 -0
  45. sglang/srt/constrained/outlines_backend.py +1 -1
  46. sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
  47. sglang/srt/constrained/utils.py +12 -0
  48. sglang/srt/constrained/xgrammar_backend.py +20 -11
  49. sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
  50. sglang/srt/disaggregation/base/conn.py +17 -4
  51. sglang/srt/disaggregation/common/conn.py +4 -2
  52. sglang/srt/disaggregation/decode.py +123 -31
  53. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  54. sglang/srt/disaggregation/fake/conn.py +11 -3
  55. sglang/srt/disaggregation/mooncake/conn.py +157 -19
  56. sglang/srt/disaggregation/nixl/conn.py +69 -24
  57. sglang/srt/disaggregation/prefill.py +96 -270
  58. sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
  59. sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
  60. sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
  61. sglang/srt/distributed/device_communicators/pynccl.py +24 -12
  62. sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
  63. sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
  64. sglang/srt/distributed/naive_distributed.py +5 -4
  65. sglang/srt/distributed/parallel_state.py +70 -19
  66. sglang/srt/elastic_ep/elastic_ep.py +74 -0
  67. sglang/srt/entrypoints/context.py +3 -2
  68. sglang/srt/entrypoints/engine.py +66 -66
  69. sglang/srt/entrypoints/grpc_server.py +431 -234
  70. sglang/srt/entrypoints/harmony_utils.py +2 -2
  71. sglang/srt/entrypoints/http_server.py +120 -8
  72. sglang/srt/entrypoints/http_server_engine.py +1 -7
  73. sglang/srt/entrypoints/openai/protocol.py +225 -37
  74. sglang/srt/entrypoints/openai/serving_base.py +49 -2
  75. sglang/srt/entrypoints/openai/serving_chat.py +29 -74
  76. sglang/srt/entrypoints/openai/serving_classify.py +204 -0
  77. sglang/srt/entrypoints/openai/serving_completions.py +15 -1
  78. sglang/srt/entrypoints/openai/serving_responses.py +5 -2
  79. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  80. sglang/srt/environ.py +42 -4
  81. sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
  82. sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
  83. sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
  84. sglang/srt/eplb/expert_distribution.py +3 -4
  85. sglang/srt/eplb/expert_location_dispatch.py +2 -2
  86. sglang/srt/eplb/expert_location_updater.py +2 -2
  87. sglang/srt/function_call/base_format_detector.py +17 -18
  88. sglang/srt/function_call/function_call_parser.py +18 -14
  89. sglang/srt/function_call/glm4_moe_detector.py +1 -5
  90. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  91. sglang/srt/function_call/json_array_parser.py +0 -2
  92. sglang/srt/function_call/utils.py +2 -2
  93. sglang/srt/grpc/compile_proto.py +3 -3
  94. sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
  95. sglang/srt/grpc/health_servicer.py +189 -0
  96. sglang/srt/grpc/scheduler_launcher.py +181 -0
  97. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  98. sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
  99. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
  100. sglang/srt/layers/activation.py +4 -1
  101. sglang/srt/layers/attention/aiter_backend.py +3 -3
  102. sglang/srt/layers/attention/ascend_backend.py +17 -1
  103. sglang/srt/layers/attention/attention_registry.py +43 -23
  104. sglang/srt/layers/attention/base_attn_backend.py +20 -1
  105. sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
  106. sglang/srt/layers/attention/fla/chunk.py +0 -1
  107. sglang/srt/layers/attention/fla/chunk_o.py +1 -1
  108. sglang/srt/layers/attention/fla/index.py +0 -2
  109. sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
  110. sglang/srt/layers/attention/fla/utils.py +0 -3
  111. sglang/srt/layers/attention/fla/wy_fast.py +0 -2
  112. sglang/srt/layers/attention/flashattention_backend.py +12 -8
  113. sglang/srt/layers/attention/flashinfer_backend.py +248 -21
  114. sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
  115. sglang/srt/layers/attention/flashmla_backend.py +2 -2
  116. sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
  117. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
  118. sglang/srt/layers/attention/intel_amx_backend.py +1 -1
  119. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  120. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
  121. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  122. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  123. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  124. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  125. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  126. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  127. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
  128. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  129. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
  130. sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
  131. sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
  132. sglang/srt/layers/attention/nsa/utils.py +0 -1
  133. sglang/srt/layers/attention/nsa_backend.py +404 -90
  134. sglang/srt/layers/attention/triton_backend.py +208 -34
  135. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
  136. sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
  137. sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
  138. sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
  139. sglang/srt/layers/attention/utils.py +11 -7
  140. sglang/srt/layers/attention/vision.py +3 -3
  141. sglang/srt/layers/attention/xpu_backend.py +1028 -0
  142. sglang/srt/layers/communicator.py +11 -7
  143. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
  144. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
  145. sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
  146. sglang/srt/layers/dp_attention.py +17 -0
  147. sglang/srt/layers/layernorm.py +45 -15
  148. sglang/srt/layers/linear.py +9 -1
  149. sglang/srt/layers/logits_processor.py +147 -17
  150. sglang/srt/layers/modelopt_utils.py +11 -0
  151. sglang/srt/layers/moe/cutlass_moe.py +0 -2
  152. sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
  153. sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
  154. sglang/srt/layers/moe/ep_moe/layer.py +119 -397
  155. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
  156. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  157. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
  158. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
  159. sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
  160. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
  161. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  162. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  163. sglang/srt/layers/moe/moe_runner/triton.py +3 -1
  164. sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
  165. sglang/srt/layers/moe/router.py +51 -15
  166. sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
  167. sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
  168. sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
  169. sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
  170. sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
  171. sglang/srt/layers/moe/topk.py +3 -2
  172. sglang/srt/layers/moe/utils.py +17 -1
  173. sglang/srt/layers/quantization/__init__.py +2 -53
  174. sglang/srt/layers/quantization/awq.py +183 -6
  175. sglang/srt/layers/quantization/awq_triton.py +29 -0
  176. sglang/srt/layers/quantization/base_config.py +20 -1
  177. sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
  178. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
  179. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
  180. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
  181. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
  182. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
  183. sglang/srt/layers/quantization/fp8.py +84 -18
  184. sglang/srt/layers/quantization/fp8_kernel.py +55 -10
  185. sglang/srt/layers/quantization/fp8_utils.py +42 -14
  186. sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
  187. sglang/srt/layers/quantization/gptq.py +0 -1
  188. sglang/srt/layers/quantization/int8_kernel.py +18 -2
  189. sglang/srt/layers/quantization/marlin_utils.py +12 -0
  190. sglang/srt/layers/quantization/modelopt_quant.py +125 -100
  191. sglang/srt/layers/quantization/mxfp4.py +5 -30
  192. sglang/srt/layers/quantization/petit.py +1 -1
  193. sglang/srt/layers/quantization/quark/quark.py +3 -1
  194. sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
  195. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
  196. sglang/srt/layers/quantization/unquant.py +1 -4
  197. sglang/srt/layers/quantization/utils.py +0 -1
  198. sglang/srt/layers/quantization/w4afp8.py +51 -20
  199. sglang/srt/layers/quantization/w8a8_int8.py +30 -24
  200. sglang/srt/layers/radix_attention.py +59 -9
  201. sglang/srt/layers/rotary_embedding.py +673 -16
  202. sglang/srt/layers/sampler.py +36 -16
  203. sglang/srt/layers/sparse_pooler.py +98 -0
  204. sglang/srt/layers/utils.py +0 -1
  205. sglang/srt/layers/vocab_parallel_embedding.py +4 -1
  206. sglang/srt/lora/backend/triton_backend.py +0 -1
  207. sglang/srt/lora/eviction_policy.py +139 -0
  208. sglang/srt/lora/lora_manager.py +24 -9
  209. sglang/srt/lora/lora_registry.py +1 -1
  210. sglang/srt/lora/mem_pool.py +40 -16
  211. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
  212. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
  213. sglang/srt/managers/cache_controller.py +48 -17
  214. sglang/srt/managers/data_parallel_controller.py +146 -42
  215. sglang/srt/managers/detokenizer_manager.py +40 -13
  216. sglang/srt/managers/io_struct.py +66 -16
  217. sglang/srt/managers/mm_utils.py +20 -18
  218. sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
  219. sglang/srt/managers/overlap_utils.py +96 -19
  220. sglang/srt/managers/schedule_batch.py +241 -511
  221. sglang/srt/managers/schedule_policy.py +15 -2
  222. sglang/srt/managers/scheduler.py +399 -499
  223. sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
  224. sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
  225. sglang/srt/managers/scheduler_pp_mixin.py +341 -0
  226. sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
  227. sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
  228. sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
  229. sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
  230. sglang/srt/managers/tokenizer_manager.py +378 -90
  231. sglang/srt/managers/tp_worker.py +212 -161
  232. sglang/srt/managers/utils.py +78 -2
  233. sglang/srt/mem_cache/allocator.py +7 -2
  234. sglang/srt/mem_cache/allocator_ascend.py +2 -2
  235. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  236. sglang/srt/mem_cache/chunk_cache.py +13 -2
  237. sglang/srt/mem_cache/common.py +480 -0
  238. sglang/srt/mem_cache/evict_policy.py +16 -1
  239. sglang/srt/mem_cache/hicache_storage.py +4 -1
  240. sglang/srt/mem_cache/hiradix_cache.py +16 -3
  241. sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
  242. sglang/srt/mem_cache/memory_pool.py +435 -219
  243. sglang/srt/mem_cache/memory_pool_host.py +0 -1
  244. sglang/srt/mem_cache/multimodal_cache.py +0 -1
  245. sglang/srt/mem_cache/radix_cache.py +53 -19
  246. sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
  247. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
  248. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
  249. sglang/srt/mem_cache/storage/backend_factory.py +2 -2
  250. sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
  251. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
  252. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
  253. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
  254. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
  255. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
  256. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
  257. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
  258. sglang/srt/mem_cache/swa_radix_cache.py +92 -26
  259. sglang/srt/metrics/collector.py +31 -0
  260. sglang/srt/metrics/func_timer.py +1 -1
  261. sglang/srt/model_executor/cuda_graph_runner.py +43 -5
  262. sglang/srt/model_executor/forward_batch_info.py +28 -23
  263. sglang/srt/model_executor/model_runner.py +379 -139
  264. sglang/srt/model_executor/npu_graph_runner.py +2 -3
  265. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
  266. sglang/srt/model_loader/__init__.py +1 -1
  267. sglang/srt/model_loader/loader.py +424 -27
  268. sglang/srt/model_loader/utils.py +0 -1
  269. sglang/srt/model_loader/weight_utils.py +47 -28
  270. sglang/srt/models/apertus.py +2 -3
  271. sglang/srt/models/arcee.py +2 -2
  272. sglang/srt/models/bailing_moe.py +13 -52
  273. sglang/srt/models/bailing_moe_nextn.py +3 -4
  274. sglang/srt/models/bert.py +1 -1
  275. sglang/srt/models/deepseek_nextn.py +19 -3
  276. sglang/srt/models/deepseek_ocr.py +1516 -0
  277. sglang/srt/models/deepseek_v2.py +273 -98
  278. sglang/srt/models/dots_ocr.py +0 -2
  279. sglang/srt/models/dots_vlm.py +0 -1
  280. sglang/srt/models/dots_vlm_vit.py +1 -1
  281. sglang/srt/models/falcon_h1.py +13 -19
  282. sglang/srt/models/gemma3_mm.py +16 -0
  283. sglang/srt/models/gemma3n_mm.py +1 -2
  284. sglang/srt/models/glm4_moe.py +14 -37
  285. sglang/srt/models/glm4_moe_nextn.py +2 -2
  286. sglang/srt/models/glm4v.py +2 -1
  287. sglang/srt/models/glm4v_moe.py +5 -5
  288. sglang/srt/models/gpt_oss.py +5 -5
  289. sglang/srt/models/grok.py +10 -23
  290. sglang/srt/models/hunyuan.py +2 -7
  291. sglang/srt/models/interns1.py +0 -1
  292. sglang/srt/models/kimi_vl.py +1 -7
  293. sglang/srt/models/kimi_vl_moonvit.py +3 -1
  294. sglang/srt/models/llama.py +2 -2
  295. sglang/srt/models/llama_eagle3.py +1 -1
  296. sglang/srt/models/longcat_flash.py +5 -22
  297. sglang/srt/models/longcat_flash_nextn.py +3 -14
  298. sglang/srt/models/mimo.py +2 -13
  299. sglang/srt/models/mimo_mtp.py +1 -2
  300. sglang/srt/models/minicpmo.py +7 -5
  301. sglang/srt/models/mixtral.py +1 -4
  302. sglang/srt/models/mllama.py +1 -1
  303. sglang/srt/models/mllama4.py +13 -3
  304. sglang/srt/models/nemotron_h.py +511 -0
  305. sglang/srt/models/olmo2.py +31 -4
  306. sglang/srt/models/opt.py +5 -5
  307. sglang/srt/models/phi.py +1 -1
  308. sglang/srt/models/phi4mm.py +1 -1
  309. sglang/srt/models/phimoe.py +0 -1
  310. sglang/srt/models/pixtral.py +0 -3
  311. sglang/srt/models/points_v15_chat.py +186 -0
  312. sglang/srt/models/qwen.py +0 -1
  313. sglang/srt/models/qwen2_5_vl.py +3 -3
  314. sglang/srt/models/qwen2_audio.py +2 -15
  315. sglang/srt/models/qwen2_moe.py +15 -12
  316. sglang/srt/models/qwen2_vl.py +5 -2
  317. sglang/srt/models/qwen3_moe.py +19 -35
  318. sglang/srt/models/qwen3_next.py +7 -12
  319. sglang/srt/models/qwen3_next_mtp.py +3 -4
  320. sglang/srt/models/qwen3_omni_moe.py +661 -0
  321. sglang/srt/models/qwen3_vl.py +37 -33
  322. sglang/srt/models/qwen3_vl_moe.py +57 -185
  323. sglang/srt/models/roberta.py +55 -3
  324. sglang/srt/models/sarashina2_vision.py +0 -1
  325. sglang/srt/models/step3_vl.py +3 -5
  326. sglang/srt/models/utils.py +11 -1
  327. sglang/srt/multimodal/processors/base_processor.py +6 -2
  328. sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
  329. sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
  330. sglang/srt/multimodal/processors/dots_vlm.py +0 -1
  331. sglang/srt/multimodal/processors/glm4v.py +1 -5
  332. sglang/srt/multimodal/processors/internvl.py +0 -2
  333. sglang/srt/multimodal/processors/janus_pro.py +0 -1
  334. sglang/srt/multimodal/processors/mllama4.py +0 -8
  335. sglang/srt/multimodal/processors/phi4mm.py +0 -1
  336. sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
  337. sglang/srt/multimodal/processors/qwen_vl.py +75 -16
  338. sglang/srt/multimodal/processors/step3_vl.py +1 -1
  339. sglang/srt/parser/conversation.py +41 -0
  340. sglang/srt/parser/reasoning_parser.py +0 -1
  341. sglang/srt/sampling/custom_logit_processor.py +77 -2
  342. sglang/srt/sampling/sampling_batch_info.py +17 -22
  343. sglang/srt/sampling/sampling_params.py +70 -2
  344. sglang/srt/server_args.py +577 -73
  345. sglang/srt/server_args_config_parser.py +1 -1
  346. sglang/srt/single_batch_overlap.py +38 -28
  347. sglang/srt/speculative/base_spec_worker.py +34 -0
  348. sglang/srt/speculative/draft_utils.py +226 -0
  349. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
  350. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
  351. sglang/srt/speculative/eagle_info.py +57 -18
  352. sglang/srt/speculative/eagle_info_v2.py +458 -0
  353. sglang/srt/speculative/eagle_utils.py +138 -0
  354. sglang/srt/speculative/eagle_worker.py +83 -280
  355. sglang/srt/speculative/eagle_worker_v2.py +702 -0
  356. sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
  357. sglang/srt/speculative/ngram_worker.py +12 -11
  358. sglang/srt/speculative/spec_info.py +2 -0
  359. sglang/srt/speculative/spec_utils.py +38 -3
  360. sglang/srt/speculative/standalone_worker.py +4 -14
  361. sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
  362. sglang/srt/two_batch_overlap.py +28 -14
  363. sglang/srt/utils/__init__.py +1 -1
  364. sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
  365. sglang/srt/utils/common.py +192 -47
  366. sglang/srt/utils/hf_transformers_utils.py +40 -17
  367. sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
  368. sglang/srt/{offloader.py → utils/offloader.py} +4 -4
  369. sglang/srt/utils/profile_merger.py +199 -0
  370. sglang/test/attention/test_flashattn_backend.py +1 -1
  371. sglang/test/attention/test_flashattn_mla_backend.py +0 -1
  372. sglang/test/attention/test_prefix_chunk_info.py +0 -2
  373. sglang/test/attention/test_trtllm_mla_backend.py +221 -53
  374. sglang/test/few_shot_gsm8k_engine.py +2 -4
  375. sglang/test/kit_matched_stop.py +157 -0
  376. sglang/test/longbench_v2/__init__.py +1 -0
  377. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  378. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  379. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  380. sglang/test/run_eval.py +41 -0
  381. sglang/test/runners.py +2 -0
  382. sglang/test/send_one.py +42 -7
  383. sglang/test/simple_eval_common.py +3 -0
  384. sglang/test/simple_eval_gpqa.py +0 -1
  385. sglang/test/simple_eval_humaneval.py +0 -3
  386. sglang/test/simple_eval_longbench_v2.py +344 -0
  387. sglang/test/test_block_fp8.py +1 -2
  388. sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
  389. sglang/test/test_cutlass_moe.py +1 -2
  390. sglang/test/test_cutlass_w4a8_moe.py +10 -20
  391. sglang/test/test_deterministic.py +232 -99
  392. sglang/test/test_deterministic_utils.py +73 -0
  393. sglang/test/test_disaggregation_utils.py +81 -0
  394. sglang/test/test_marlin_moe.py +0 -1
  395. sglang/test/test_utils.py +85 -20
  396. sglang/version.py +1 -1
  397. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
  398. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
  399. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  400. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  401. sglang/srt/speculative/build_eagle_tree.py +0 -427
  402. sglang/test/test_block_fp8_ep.py +0 -358
  403. /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
  404. /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
  405. /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
  406. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
  407. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
  408. {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0

sglang/srt/configs/mamba_utils.py (new file)
@@ -0,0 +1,117 @@
+ # Copyright 2025 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """Common config utils for mamba2 - NemotronH, FalconH1, Qwen3Next, etc."""
+
+ import os
+ from dataclasses import dataclass, field
+
+ import numpy as np
+ import torch
+
+ from sglang.srt.distributed.utils import divide
+
+
+ def extra_groups_for_head_shards(ngroups: int, tp_size: int):
+     """Compute the increase in group numbers to account for
+     replication in order to accompany the head shards."""
+
+     # in the case ngoups % tp_size == 0, this will be zero
+     if ngroups % tp_size == 0:
+         return 0
+
+     # for n_groups == 1, this is exactly tp_size - n_groups
+     return tp_size - ngroups
+
+
+ @dataclass(kw_only=True, frozen=True)
+ class Mamba2StateShape:
+     conv: tuple[int, int]
+     temporal: tuple[int, int, int]
+
+     intermediate_size: int
+     conv_dim: int
+     ssm_state_size: int
+     num_heads: int
+     head_dim: int
+     state_size: int
+     conv_kernel: int
+
+     @staticmethod
+     def create(
+         *,
+         tp_world_size: int,
+         intermediate_size: int,
+         n_groups: int,
+         num_heads: int,
+         head_dim: int,
+         state_size: int,
+         conv_kernel: int,
+     ) -> "Mamba2StateShape":
+         # if n_groups is not divisible by world_size, need to extend the shards
+         # to ensure all groups needed by a head is sharded along with it
+         if n_groups % tp_world_size != 0:
+             extra_groups = extra_groups_for_head_shards(n_groups, tp_world_size)
+             n_groups += extra_groups
+         # heads and n_groups are TP-ed
+         conv_dim = intermediate_size + 2 * n_groups * state_size
+
+         # contiguous along 'dim' axis
+         conv_state_shape = divide(conv_dim, tp_world_size), conv_kernel - 1
+
+         # These are not TP-ed as they depend on A, dt_bias, D
+         # - they are typically small
+         # e.g., QWen3-Next: (32, 128, 128)
+         temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size)
+         return Mamba2StateShape(
+             conv=conv_state_shape,
+             temporal=temporal_state_shape,
+             intermediate_size=intermediate_size,
+             conv_dim=conv_dim,
+             ssm_state_size=state_size,
+             num_heads=num_heads,
+             head_dim=head_dim,
+             state_size=state_size,
+             conv_kernel=conv_kernel,
+         )
+
+
+ @dataclass(kw_only=True, frozen=True)
+ class Mamba2StateDType:
+     conv: torch.dtype
+     temporal: torch.dtype
+
+
+ CONV_DTYPE = torch.bfloat16
+
+
+ def mamba2_state_dtype() -> Mamba2StateDType:
+     dtype_map = {
+         "float32": torch.float32,
+         "bfloat16": torch.bfloat16,
+     }
+     ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]]
+     return Mamba2StateDType(conv=CONV_DTYPE, temporal=ssm_dtype)
+
+
+ @dataclass(kw_only=True, frozen=True)
+ class Mamba2CacheParams:
+     shape: Mamba2StateShape
+     dtype: Mamba2StateDType = field(default_factory=mamba2_state_dtype)
+     layers: list[int]
+
+     @property
+     def mamba_cache_per_req(self) -> int:
+         return (
+             int(np.prod(self.shape.conv)) * self.dtype.conv.itemsize
+             + int(np.prod(self.shape.temporal)) * self.dtype.temporal.itemsize
+         ) * len(self.layers)
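
The new Mamba2 cache helpers above compose as follows. The sketch below is illustrative only and is not part of the diff: the shape numbers, TP world size, and layer count are made up, it assumes the sglang 0.5.4 wheel is installed, and it pre-sets SGLANG_MAMBA_SSM_DTYPE because mamba2_state_dtype() reads that variable from the environment.

import os

# mamba2_state_dtype() looks this up; "bfloat16" is an assumed choice for the sketch.
os.environ.setdefault("SGLANG_MAMBA_SSM_DTYPE", "bfloat16")

from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape

# Illustrative shape: 128 heads of dim 128, 16 state groups, state size 128,
# conv kernel 4, sharded over a hypothetical TP world size of 4.
shape = Mamba2StateShape.create(
    tp_world_size=4,
    intermediate_size=128 * 128,
    n_groups=16,
    num_heads=128,
    head_dim=128,
    state_size=128,
    conv_kernel=4,
)

# 48 hypothetical Mamba layers; dtype falls back to the default factory above.
params = Mamba2CacheParams(shape=shape, layers=list(range(48)))

print(shape.conv)                  # per-rank conv state shape, here (5120, 3)
print(shape.temporal)              # per-rank temporal state shape, here (32, 128, 128)
print(params.mamba_cache_per_req)  # bytes of Mamba state reserved per request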

sglang/srt/configs/model_config.py
@@ -17,7 +17,7 @@ import logging
  import math
  import os
  from enum import Enum, IntEnum, auto
- from typing import List, Optional, Set, Union
+ from typing import Any, List, Optional, Set, Union

  import torch
  from transformers import PretrainedConfig
@@ -53,7 +53,11 @@ def is_deepseek_nsa(config: PretrainedConfig) -> bool:
      return (
          config.architectures is not None
          and config.architectures[0]
-         in ["DeepseekV3ForCausalLM", "DeepseekV32ForCausalLM"]
+         in [
+             "DeepseekV3ForCausalLM",
+             "DeepseekV32ForCausalLM",
+             "DeepseekV3ForCausalLMNextN",
+         ]
          and getattr(config, "index_topk", None) is not None
      )

@@ -87,8 +91,12 @@ class ModelConfig:
          quantization: Optional[str] = None,
          override_config_file: Optional[str] = None,
          is_draft_model: bool = False,
-         hybrid_kvcache_ratio: Optional[float] = None,
+         hybrid_kvcache_ratio: Optional[
+             float
+         ] = None,  # TODO: remove this, it is not a model config
          model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
+         sampling_defaults: str = "openai",
+         quantize_and_serve: bool = False,
      ) -> None:
          # Parse args
          self.model_path = model_path
@@ -96,6 +104,11 @@ class ModelConfig:
          self.quantization = quantization
          self.is_draft_model = is_draft_model
          self.model_impl = model_impl
+         self.sampling_defaults = sampling_defaults
+         self.quantize_and_serve = quantize_and_serve
+
+         # Validate quantize_and_serve configuration
+         self._validate_quantize_and_serve_config()

          # Get hf config
          self._maybe_pull_model_tokenizer_from_remote()
@@ -211,6 +224,8 @@ class ModelConfig:
              quantization=server_args.quantization,
              hybrid_kvcache_ratio=server_args.hybrid_kvcache_ratio,
              model_impl=server_args.model_impl,
+             sampling_defaults=server_args.sampling_defaults,
+             quantize_and_serve=server_args.quantize_and_serve,
              **kwargs,
          )

@@ -477,31 +492,32 @@ class ModelConfig:
          # example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
          # example: https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/tree/main
          is_local = os.path.exists(self.model_path)
-         modelopt_quant_config = {"quant_method": "modelopt"}
          if not is_local:
              import huggingface_hub

              try:
-                 from huggingface_hub import HfApi
+                 from huggingface_hub import HfApi, hf_hub_download

                  hf_api = HfApi()
-
-                 def check_hf_quant_config():
-                     return hf_api.file_exists(
-                         self.model_path, "hf_quant_config.json"
-                     )
-
                  # Retry HF API call up to 3 times
                  file_exists = retry(
-                     check_hf_quant_config,
+                     lambda: hf_api.file_exists(
+                         self.model_path, "hf_quant_config.json"
+                     ),
                      max_retry=2,
                      initial_delay=1.0,
                      max_delay=5.0,
                  )
-
                  if file_exists:
-                     quant_cfg = modelopt_quant_config
-
+                     # Download and parse the quantization config for remote models
+                     quant_config_file = hf_hub_download(
+                         repo_id=self.model_path,
+                         filename="hf_quant_config.json",
+                         revision=self.revision,
+                     )
+                     with open(quant_config_file) as f:
+                         quant_config_dict = json.load(f)
+                     quant_cfg = self._parse_modelopt_quant_config(quant_config_dict)
              except huggingface_hub.errors.OfflineModeIsEnabled:
                  logger.warning(
                      "Offline mode is enabled, skipping hf_quant_config.json check"
@@ -510,21 +526,80 @@ class ModelConfig:
                  logger.warning(
                      f"Failed to check hf_quant_config.json: {self.model_path} {e}"
                  )
-
          elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
              quant_config_file = os.path.join(
                  self.model_path, "hf_quant_config.json"
              )
              with open(quant_config_file) as f:
                  quant_config_dict = json.load(f)
-             json_quant_configs = quant_config_dict["quantization"]
-             quant_algo = json_quant_configs.get("quant_algo", None)
-             if quant_algo == "MIXED_PRECISION":
-                 quant_cfg = {"quant_method": "w4afp8"}
-             else:
-                 quant_cfg = modelopt_quant_config
+             quant_cfg = self._parse_modelopt_quant_config(quant_config_dict)
          return quant_cfg

+     def _parse_modelopt_quant_config(self, quant_config_dict: dict) -> dict:
+         """Parse ModelOpt quantization config and return the appropriate quant_method."""
+         json_quant_configs = quant_config_dict["quantization"]
+         quant_algo = json_quant_configs.get("quant_algo", None)
+
+         if quant_algo == "MIXED_PRECISION":
+             return {"quant_method": "w4afp8"}
+         elif quant_algo and ("FP4" in quant_algo or "NVFP4" in quant_algo):
+             return {"quant_method": "modelopt_fp4"}
+         elif quant_algo and "FP8" in quant_algo:
+             return {"quant_method": "modelopt_fp8"}
+         else:
+             # Default to FP8 for backward compatibility
+             return {"quant_method": "modelopt_fp8"}
+
+     def _is_already_quantized(self) -> bool:
+         """Check if the model is already quantized based on config files."""
+         # Check for HuggingFace quantization config
+         from sglang.srt.utils import has_hf_quant_config
+
+         return has_hf_quant_config(self.model_path)
+
+     def _get_modelopt_quant_type(self) -> str:
+         """Extract ModelOpt quantization type from unified quantization flag."""
+         if self.quantization == "modelopt_fp8":
+             return "fp8"
+         elif self.quantization == "modelopt_fp4":
+             return "nvfp4"
+         elif self.quantization == "modelopt":
+             # Auto-detect from model config
+             quant_cfg = self._parse_quant_hf_config()
+             if quant_cfg:
+                 quant_method = quant_cfg.get("quant_method", "").lower()
+                 if "fp4" in quant_method:
+                     return "fp4"
+                 elif "fp8" in quant_method:
+                     return "fp8"
+             # Default to fp8 if can't detect
+             return "fp8"
+         else:
+             return "fp8"  # Default fallback
+
+     def _validate_quantize_and_serve_config(self):
+         """Validate quantize_and_serve configuration."""
+         if not self.quantize_and_serve:
+             return
+
+         # Check if ModelOpt quantization is specified
+         modelopt_quantization_specified = self.quantization in [
+             "modelopt",
+             "modelopt_fp8",
+             "modelopt_fp4",
+         ]
+
+         if not modelopt_quantization_specified:
+             raise ValueError("quantize_and_serve requires ModelOpt quantization")
+
+         # quantize_and_serve is disabled due to compatibility issues
+         raise NotImplementedError(
+             "quantize_and_serve functionality is currently disabled due to compatibility issues. "
+             "Please use the separate quantize-then-deploy workflow instead. "
+             "Step 1: Quantize and export model. "
+             "Step 2: Deploy the exported model."
+         )
+
      # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
      def _verify_quantization(self) -> None:
          supported_quantization = [*QUANTIZATION_METHODS]
@@ -543,7 +618,8 @@ class ModelConfig:
          optimized_quantization_methods = [
              "fp8",
              "marlin",
-             "modelopt",
+             "modelopt_fp8",
+             "modelopt_fp4",
              "gptq_marlin_24",
              "gptq_marlin",
              "awq_marlin",
@@ -657,6 +733,38 @@ class ModelConfig:
              eos_ids = eos_ids | generation_eos_ids
          return eos_ids

+     def get_default_sampling_params(self) -> dict[str, Any]:
+         """
+         Get default sampling parameters from the model's generation config.
+
+         This method returns non-default sampling parameters from the model's
+         generation_config.json when sampling_defaults is set to "model".
+
+         Returns:
+             A dictionary containing the non-default sampling parameters.
+         """
+         if self.sampling_defaults != "model":
+             return {}
+
+         if self.hf_generation_config is None:
+             return {}
+
+         config = self.hf_generation_config.to_dict()
+
+         available_params = [
+             "repetition_penalty",
+             "temperature",
+             "top_k",
+             "top_p",
+             "min_p",
+         ]
+
+         default_sampling_params = {
+             p: config.get(p) for p in available_params if config.get(p) is not None
+         }
+
+         return default_sampling_params
+
      def _maybe_pull_model_tokenizer_from_remote(self) -> None:
          """
          Pull the model config files to a temporary
@@ -802,15 +910,18 @@ multimodal_model_archs = [
      "Qwen2_5_VLForConditionalGeneration",
      "Qwen3VLForConditionalGeneration",
      "Qwen3VLMoeForConditionalGeneration",
+     "Qwen3OmniMoeForConditionalGeneration",
      "KimiVLForConditionalGeneration",
      "InternVLChatModel",
      "InternS1ForConditionalGeneration",
      "Phi4MMForCausalLM",
      "VILAForConditionalGeneration",
      "Step3VLForConditionalGeneration",
+     "POINTSV15ChatModel",
      "DotsVLMForCausalLM",
      "DotsOCRForCausalLM",
      "Sarashina2VisionForCausalLM",
+     "DeepseekOCRForCausalLM",
  ]


sglang/srt/configs/modelopt_config.py (new file)
@@ -0,0 +1,30 @@
+ # Configuration for NVIDIA ModelOpt quantization integration
+ from dataclasses import dataclass
+ from typing import Optional
+
+
+ @dataclass
+ class ModelOptConfig:
+     """Configuration for NVIDIA ModelOpt quantization operations.
+
+     This configuration class holds parameters for ModelOpt quantization,
+     checkpoint management, and model export operations.
+
+     Args:
+         quant: Quantization method/type (e.g., "fp8", "fp4")
+         checkpoint_restore_path: Path to restore ModelOpt checkpoint from
+         checkpoint_save_path: Path to save ModelOpt checkpoint to
+         export_path: Path to export quantized model in HuggingFace format
+         quantize_and_serve: Whether to quantize and serve in one step
+     """
+
+     quant: Optional[str] = None
+     checkpoint_restore_path: Optional[str] = None
+     checkpoint_save_path: Optional[str] = None
+     export_path: Optional[str] = None
+     quantize_and_serve: bool = False
+
+     def __post_init__(self):
+         """Validate configuration after initialization."""
+         # Add any validation logic if needed
+         pass
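
Taken together, the ModelConfig changes above route hf_quant_config.json through _parse_modelopt_quant_config and split the old "modelopt" method into "modelopt_fp8" / "modelopt_fp4", while ModelOptConfig carries the related paths and flags. The sketch below mirrors that quant_algo mapping in a standalone function instead of importing ModelConfig (which needs a real model path); the example dictionaries and paths are hypothetical, and the ModelOptConfig import assumes the 0.5.4 wheel is installed.

from sglang.srt.configs.modelopt_config import ModelOptConfig


def parse_modelopt_quant_config(quant_config_dict: dict) -> dict:
    # Mirrors the mapping in ModelConfig._parse_modelopt_quant_config above.
    quant_algo = quant_config_dict["quantization"].get("quant_algo", None)
    if quant_algo == "MIXED_PRECISION":
        return {"quant_method": "w4afp8"}
    if quant_algo and ("FP4" in quant_algo or "NVFP4" in quant_algo):
        return {"quant_method": "modelopt_fp4"}
    # FP8 and unrecognized algorithms both fall back to modelopt_fp8
    return {"quant_method": "modelopt_fp8"}


print(parse_modelopt_quant_config({"quantization": {"quant_algo": "NVFP4"}}))
# {'quant_method': 'modelopt_fp4'}
print(parse_modelopt_quant_config({"quantization": {"quant_algo": "MIXED_PRECISION"}}))
# {'quant_method': 'w4afp8'}

# ModelOptConfig is a plain dataclass; the path here is a placeholder.
cfg = ModelOptConfig(quant="fp8", export_path="/tmp/modelopt_export")
print(cfg.quant, cfg.quantize_and_serve)  # fp8 False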

sglang/srt/configs/nemotron_h.py (new file)
@@ -0,0 +1,286 @@
+ # Copyright 2025 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/transformers_utils/configs/nemotron_h.py
+
+ """NemotronH model configuration"""
+
+ import regex as re
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.utils import logging
+
+ from sglang.srt.configs.mamba_utils import Mamba2CacheParams, Mamba2StateShape
+ from sglang.srt.layers.dp_attention import get_attention_tp_size
+
+ logger = logging.get_logger(__name__)
+
+ MAMBA = "M"
+ ATTENTION = "*"
+ MLP = "-"
+
+
+ class NemotronHConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a
+     [`NemotronHModel`]. It is used to instantiate a NemotronH model according
+     to the specified arguments, defining the model architecture. Instantiating
+     a configuration with the defaults will yield a similar configuration to
+     that of the NemotronH-v0.1 model.
+     Args:
+         vocab_size (`int`, *optional*, defaults to 131072):
+             Vocabulary size of the NemotronH model. Defines the number of
+             different tokens that can be represented by the `inputs_ids`
+             passed when calling [`NemotronHModel`]
+         tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+             Whether the model's input and output word embeddings should be
+             tied. Note that this is only relevant if the model has an output
+             word embedding layer.
+         hidden_size (`int`, *optional*, defaults to 4096):
+             Dimension of the hidden representations.
+         intermediate_size (`int`, *optional*, defaults to 21504):
+             Dimension of the MLP representations.
+         num_hidden_layers (`int`, *optional*, defaults to 52):
+             Number of hidden layers in the Transformer encoder.
+         hybrid_override_pattern (`str`, *optional*, defaults to
+             `"M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-"`):
+             The pattern of the hybrid model. The pattern is a string of
+             characters where each character represents
+             M: Mamba2, *: Attention, -: MLP
+         num_attention_heads (`int`, *optional*, defaults to 32):
+             Number of attention heads for each attention layer in the
+             Transformer encoder.
+         attention_head_dim (`int`, *optional*, defaults to 128):
+             Dimension of each attention head.
+         num_key_value_heads (`int`, *optional*, defaults to 8):
+             This is the number of key_value heads that should be used to
+             implement Grouped Query Attention. If
+             `num_key_value_heads=num_attention_heads`, the model will use
+             Multi Head Attention (MHA), if `num_key_value_heads=1` the model
+             will use Multi Query Attention (MQA) otherwise GQA is used.
+         mlp_hidden_act (`str`, *optional*, defaults to "relu2"):
+             The non-linear activation function in the MLP layers.
+         attention_bias (`bool`, *optional*, defaults to `False`):
+             Whether to use bias in attention layers.
+         mlp_bias (`bool`, *optional*, defaults to `False`):
+             Whether to use bias in MLP layers.
+         use_bias (`bool`, *optional*, defaults to `False`):
+             Whether to use bias in the model.
+         initializer_range (`float`, *optional*, defaults to 0.02):
+             The standard deviation of the truncated_normal_initializer for
+             initializing all weight matrices.
+         layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
+             The epsilon used by the layer normalization layers.
+         residual_in_fp32 (`bool`, *optional*, defaults to `False`):
+             Whether or not residuals should be in `float32`. If set to `False`
+             residuals will keep the same `dtype` as the rest of the model.
+         use_cache (`bool`, *optional*, defaults to `True`):
+             Whether or not the model should return the last key/values
+             attentions (not used by all models). Only relevant if
+             `config.is_decoder=True`.
+         num_logits_to_keep (`int` or `None`, *optional*, defaults to 1):
+             Number of prompt logits to calculate during generation. If `None`,
+             all logits will be calculated. If an integer value, only last
+             `num_logits_to_keep` logits will be calculated.
+         pad_token_id (`int`, *optional*, defaults to 0):
+             The id of the padding token.
+         bos_token_id (`int`, *optional*, defaults to 1):
+             The id of the "beginning-of-sequence" token.
+         eos_token_id (`int`, *optional*, defaults to 2):
+             The id of the "end-of-sequence" token.
+         sliding_window (`int`, *optional*, defaults to None):
+             Sliding window attention window size.
+         max_position_embeddings (`int`, *optional*, defaults to 4096):
+             The maximum sequence length that this model might ever be used
+             with.
+         attention_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout ratio for the attention probabilities.
+         hidden_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout ratio for the hidden states.
+         use_mamba_kernels (`bool`, *optional*, defaults to `True`):
+             Flag indicating whether or not to use the fast mamba kernels.
+             These are available only if `mamba-ssm` and `causal-conv1d`
+             are installed, and the mamba modules are running on a CUDA device.
+         ssm_state_size (`int`, *optional*, defaults to 128):
+             The dimension of the mamba state space latents.
+         mamba_num_heads (`int`, *optional*, defaults to 128):
+             Number of heads in Mamba layers.
+         mamba_n_groups (`int`, *optional*, defaults to 8):
+             Number of groups in Mamba layers.
+         mamba_head_dim (`int`, *optional*, defaults to 64):
+             Dimension of each Mamba head.
+         mamba_d_conv (`int`, *optional*, defaults to 4):
+             The size of the mamba convolution kernel.
+         mamba_expand (`int`, *optional*, defaults to 2):
+             Expanding factor used to determine the mamba intermediate size.
+         mamba_hidden_act (`str`, *optional*, defaults to "silu"):
+             The non-linear activation function in the Mamba layers.
+         mamba_dt_min (`float`, *optional*, defaults to 0.001):
+             Minimum value for the time step in Mamba.
+         mamba_dt_max (`float`, *optional*, defaults to 0.1):
+             Maximum value for the time step in Mamba.
+         mamba_dt_limit (`tuple`, *optional*, defaults to (0.0, float("inf"))):
+             Limits for the time step in Mamba.
+         mamba_dt_init_floor (`float`, *optional*, defaults to 1e-4):
+             Floor value for time step initialization in Mamba.
+         mamba_conv_bias (`bool`, *optional*, defaults to `True`):
+             Whether to use bias in the convolution layer of the mamba mixer
+             block.
+         mamba_proj_bias (`bool`, *optional*, defaults to `False`):
+             Whether to use bias in the input and output projections of the
+             mamba mixer block.
+         mamba_chunk_size (`int`, *optional*, defaults to 256):
+             Size of chunks for Mamba processing.
+         rescale_prenorm_residual (`bool`, *optional*, defaults to `True`):
+             Whether to rescale the pre-normalization residual connections.
+     """
+
+     model_type = "nemotron_h"
+     keys_to_ignore_at_inference = ["past_key_values"]
+
+     def __init__(
+         self,
+         vocab_size=131072,
+         tie_word_embeddings=False,
+         hidden_size=4096,
+         intermediate_size=21504,
+         num_hidden_layers=52,
+         hybrid_override_pattern="M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-",
+         num_attention_heads=32,
+         head_dim=128,
+         num_key_value_heads=8,  # nemo: num_query_groups
+         mlp_hidden_act="relu2",
+         attention_bias=False,
+         mlp_bias=False,
+         use_bias=False,
+         initializer_range=0.02,  # nemo: init_method_std
+         layer_norm_epsilon=1e-5,  # nemo: layernorm_epsilon
+         residual_in_fp32=False,  # Megatron Core default value
+         use_cache=True,
+         num_logits_to_keep=1,
+         pad_token_id=0,
+         bos_token_id=1,
+         eos_token_id=2,
+         sliding_window=None,
+         max_position_embeddings=4096,
+         attention_dropout=0.0,
+         hidden_dropout=0.0,  # * ADDED
+         use_mamba_kernels=True,
+         ssm_state_size=128,  # mamba_state_size
+         mamba_num_heads=128,
+         mamba_n_groups=8,  # nemo: mamba_ssm_ngroups = num_heads
+         mamba_head_dim=64,
+         mamba_d_conv=4,
+         mamba_expand=2,
+         mamba_hidden_act="silu",
+         mamba_dt_min=0.001,
+         mamba_dt_max=0.1,
+         mamba_dt_limit=(0.0, float("inf")),
+         mamba_dt_init_floor=1e-4,
+         mamba_conv_bias=True,
+         mamba_proj_bias=False,
+         mamba_chunk_size=256,
+         rescale_prenorm_residual=True,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         self.tie_word_embeddings = tie_word_embeddings
+         self.hidden_size = hidden_size
+         self.intermediate_size = intermediate_size
+         self.num_hidden_layers = num_hidden_layers
+         self.hybrid_override_pattern = hybrid_override_pattern
+         self.num_attention_heads = num_attention_heads
+         self.head_dim = head_dim
+         self.sliding_window = sliding_window
+         self.max_position_embeddings = max_position_embeddings
+         self.attention_dropout = attention_dropout
+         self.hidden_dropout = hidden_dropout
+
+         # Validate hybrid_override_pattern
+         # M: Mamba2, *: Attention, -: MLP
+         assert len(self.hybrid_override_pattern) == self.num_hidden_layers, (
+             "hybrid_override_pattern must have same length as " "num_hidden_layers"
+         )
+         assert re.match(r"^[*-M]+$", self.hybrid_override_pattern), (
+             "hybrid_override_pattern must only contain characters " "'M', '*', or '-'"
+         )
+
+         # for backward compatibility
+         if num_key_value_heads is None:
+             num_key_value_heads = num_attention_heads
+
+         self.num_key_value_heads = num_key_value_heads
+         self.mlp_hidden_act = mlp_hidden_act
+         self.attention_bias = attention_bias
+         self.mlp_bias = mlp_bias
+         self.use_bias = use_bias
+         self.initializer_range = initializer_range
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.residual_in_fp32 = residual_in_fp32
+
+         self.use_cache = use_cache
+         self.num_logits_to_keep = num_logits_to_keep
+
+         self.use_mamba_kernels = use_mamba_kernels
+         self.mamba_n_groups = mamba_n_groups
+         self.mamba_head_dim = mamba_head_dim
+         self.ssm_state_size = ssm_state_size
+         self.mamba_num_heads = mamba_num_heads
+         self.conv_kernel = mamba_d_conv
+         self.expand = mamba_expand
+         self.mamba_hidden_act = mamba_hidden_act
+         self.time_step_min = mamba_dt_min
+         self.time_step_max = mamba_dt_max
+         self.time_step_limit = mamba_dt_limit
+         self.time_step_floor = mamba_dt_init_floor
+         self.use_conv_bias = mamba_conv_bias
+         self.mamba_proj_bias = mamba_proj_bias
+         self.mamba_chunk_size = mamba_chunk_size
+         self.rescale_prenorm_residual = rescale_prenorm_residual
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             bos_token_id=bos_token_id,
+             eos_token_id=eos_token_id,
+             tie_word_embeddings=tie_word_embeddings,
+             **kwargs,
+         )
+
+     @property
+     def mamba_layer_ids(self):
+         return [
+             i
+             for i in range(self.num_hidden_layers)
+             if self.hybrid_override_pattern[i] == MAMBA
+         ]
+
+     @property
+     def full_attention_layer_ids(self):
+         return [
+             i
+             for i in range(self.num_hidden_layers)
+             if self.hybrid_override_pattern[i] == ATTENTION
+         ]
+
+     @property
+     def mamba2_cache_params(self) -> Mamba2CacheParams:
+         shape = Mamba2StateShape.create(
+             tp_world_size=get_attention_tp_size(),
+             intermediate_size=self.mamba_num_heads * self.mamba_head_dim,
+             n_groups=self.n_groups,
+             num_heads=self.mamba_num_heads,
+             head_dim=self.mamba_head_dim,
+             state_size=self.ssm_state_size,
+             conv_kernel=self.conv_kernel,
+         )
+
+         return Mamba2CacheParams(shape=shape, layers=self.mamba_layer_ids)
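
The hybrid_override_pattern string drives the mamba_layer_ids and full_attention_layer_ids properties above. The standalone sketch below replicates that character-to-layer mapping without importing NemotronHConfig, using the documented default 52-layer pattern.

# M: Mamba2 layer, *: attention layer, -: MLP layer, per the config docstring above.
pattern = "M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M*-M-M-M-M-M-"
assert len(pattern) == 52  # must equal num_hidden_layers, as NemotronHConfig asserts

mamba_layer_ids = [i for i, c in enumerate(pattern) if c == "M"]
attention_layer_ids = [i for i, c in enumerate(pattern) if c == "*"]
mlp_layer_ids = [i for i, c in enumerate(pattern) if c == "-"]

print(len(mamba_layer_ids), len(attention_layer_ids), len(mlp_layer_ids))  # 24 4 24
print(attention_layer_ids)  # the four full-attention layers in the default pattern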