sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,39 @@
-from typing import Callable, List, Tuple
+from typing import Callable, List, Optional, Tuple, Union
 
 import torch
+import torch.nn as nn
+
+from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.custom_op import CustomOp
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_gather,
+    tensor_model_parallel_all_reduce,
+)
+from sglang.srt.distributed.utils import divide
+from sglang.srt.layers.attention.fla.layernorm_gated import layernorm_fn
+from sglang.srt.layers.attention.mamba.causal_conv1d import (
+    causal_conv1d_fn,
+    causal_conv1d_update,
+)
+from sglang.srt.layers.attention.mamba.mamba_utils import MambaStateShapeCalculator
+from sglang.srt.layers.attention.mamba.ops import (
+    mamba_chunk_scan_combined,
+    selective_state_update,
+)
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    MergedColumnParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import (
+    composed_weight_loader,
+    sharded_weight_loader,
+)
+from sglang.srt.utils import set_weight_attrs
 
 LoaderFunction = Callable[[torch.Tensor, torch.Tensor], None]
 
@@ -62,3 +95,535 @@ def mamba_v2_sharded_weight_loader(
             loaded_boundary += full_dim - extra
 
     return loader
+
+
+class Mixer2RMSNormGated(CustomOp):
+
+    def __init__(
+        self,
+        full_hidden_size: int,
+        full_n_groups: int,
+        use_rms_norm: bool = True,
+        eps: float = 1e-6,
+    ):
+        super().__init__()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.full_hidden_size = full_hidden_size
+        self.group_size = full_hidden_size // full_n_groups
+        self.per_rank_hidden_size = full_hidden_size // self.tp_size
+        self.n_groups = full_hidden_size // self.group_size
+
+        self.variance_epsilon = eps
+        self.use_rms_norm = use_rms_norm
+        if self.use_rms_norm:
+            # Register norm weight only if we're actually applying RMSNorm
+            self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
+            set_weight_attrs(self.weight, {"weight_loader": sharded_weight_loader(0)})
+        else:
+            # Avoid checkpoint mismatch by skipping unused parameter
+            self.register_parameter("weight", None)
+        assert (
+            self.full_hidden_size % self.tp_size == 0
+        ), "Tensor parallel world size must divide hidden size."
+
+    def forward_native(
+        self,
+        x: torch.Tensor,
+        gate: torch.Tensor,
+    ):
+        # Three tensor-parallel cases:
+        #   1. n_groups is 1
+        #      In this case we parallelize along the reduction dim.
+        #      Each rank computes a local sum of squares followed by AllReduce
+        #   2. tp_size divides n_groups
+        #      Each rank only reduces within its local group(s).
+        #      No collective ops necessary.
+        #   3. The general case can be pretty complicated so we AllGather
+        #      the input and then redundantly compute the RMSNorm.
+        input_dtype = x.dtype
+        x = x * nn.functional.silu(gate.to(torch.float32))
+        if not self.use_rms_norm:
+            return x.to(input_dtype)
+
+        if self.n_groups == 1:
+            if self.tp_size > 1:
+                # Compute local sum and then reduce to obtain global sum
+                local_sums = x.pow(2).sum(dim=-1, keepdim=True)
+                global_sums = tensor_model_parallel_all_reduce(local_sums)
+                # Calculate the variance
+                count = self.tp_size * x.shape[-1]
+                variance = global_sums / count
+
+            else:
+                variance = x.pow(2).mean(-1, keepdim=True)
+            x = x * torch.rsqrt(variance + self.variance_epsilon)
+        else:
+            redundant_tp: bool = self.n_groups % self.tp_size != 0
+            if redundant_tp:
+                # To handle the general case, redundantly apply the variance
+                x = tensor_model_parallel_all_gather(x, -1)
+
+            *prefix_dims, hidden_dim = x.shape
+            group_count = hidden_dim // self.group_size
+            x_grouped = x.view(*prefix_dims, group_count, self.group_size)
+            variance = x_grouped.pow(2).mean(-1, keepdim=True)
+            x_grouped = x_grouped * torch.rsqrt(variance + self.variance_epsilon)
+            x = x_grouped.view(*prefix_dims, hidden_dim)
+
+            if redundant_tp:
+                start = self.per_rank_hidden_size * self.tp_rank
+                end = start + self.per_rank_hidden_size
+                x = x[..., start:end]
+
+        return self.weight * x.to(input_dtype)
+
+    def forward_cuda(
+        self,
+        x: torch.Tensor,
+        gate: torch.Tensor,
+    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
+        input_dtype = x.dtype
+        if not self.use_rms_norm:
+            # Keep gate in float32 for numerical stability during silu
+            return x * nn.functional.silu(gate.to(torch.float32)).to(input_dtype)
+
+        if ((self.n_groups % self.tp_size) != 0) or self.n_groups != 1:
+            return self.forward_native(x, gate)
+
+        return layernorm_fn(
+            x,
+            self.weight.data,
+            bias=None,
+            z=gate,
+            eps=self.variance_epsilon,
+            norm_before_gate=False,
+        )
+
+
+class MambaMixer2(torch.nn.Module):
+    """
+    Compute ∆, A, B, C, and D the state space parameters and compute
+    the `contextualized_states`. A, D are input independent
+    (see Mamba paper [1] Section 3.5.2 "Interpretation of A"
+    for why A isn't selective) ∆, B, C are input-dependent
+    (this is a key difference between Mamba and the linear time
+    invariant S4, and is why Mamba is called
+    **selective** state spaces)
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        ssm_state_size: int,
+        conv_kernel_size: int,
+        intermediate_size: int,
+        use_conv_bias: bool,
+        use_bias: bool,
+        chunk_size: int,
+        layer_id: int,
+        n_groups: int = 1,
+        num_heads: int = 128,
+        head_dim: int = 64,
+        rms_norm_eps: float = 1e-5,
+        activation: str = "silu",
+        use_rms_norm: bool = True,
+        model_config: Optional[ModelConfig] = None,
+        # cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        # For TP, the sharding plan is as follows:
+        # - for the conv modules, since
+        #   conv_dim = intermediate_size * 2 * n_groups * ssm_state_size,
+        #   we shard intermediate_size and n_groups
+        # - since intermediate_size = n_heads * head_dim, sharding on
+        #   intermediate_size is achieved by sharding on n_heads.
+        # - IF, world_size divides groups, then sharding
+        #   (n_groups / world_size, n_heads / world_size)
+        #   also maintains the invariant n_heads % n_groups == 0
+        # - HOWEVER IF, world_size DOES NOT divide groups, then we need
+        #   to allocate extra space in the shard, such that groups
+        #   may be replicated to follow the head shard.
+        # - NOTE: currently for the world size DOES NOT divide groups
+        #   case, we only support the case when n_groups == 1
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+
+        assert (
+            num_heads % self.tp_size == 0
+        ), "Tensor parallel world size must divide num heads."
+
+        assert (n_groups % self.tp_size) == 0 or n_groups == 1, (
+            "If tensor parallel world size does not divide num_groups, "
+            "then num_groups must equal 1."
+        )
+
+        self.ssm_state_size = ssm_state_size
+        self.conv_kernel_size = conv_kernel_size
+        self.activation = activation
+        self.layer_id = layer_id
+
+        self.intermediate_size = intermediate_size
+        self.head_dim = head_dim
+        self.num_heads = num_heads
+        self.chunk_size = chunk_size
+
+        self.n_groups = n_groups
+        if n_groups % self.tp_size != 0:
+            # - for TP we shard conv_dim by sharding on n_groups,
+            # - but if n_groups cannot divide tp_size, we need to
+            #   extend some extra groups
+            groups = MambaStateShapeCalculator.extra_groups_for_head_shards(
+                n_groups, self.tp_size
+            )
+            self.n_groups = n_groups + groups
+
+        self.groups_ssm_state_size = self.n_groups * self.ssm_state_size
+        self.conv_dim = intermediate_size + 2 * self.groups_ssm_state_size
+
+        self.conv1d = MergedColumnParallelLinear(
+            input_size=conv_kernel_size,
+            output_sizes=[
+                intermediate_size,
+                self.groups_ssm_state_size,
+                self.groups_ssm_state_size,
+            ],
+            bias=use_conv_bias,
+            quant_config=None,
+            prefix=f"{prefix}.conv1d",
+        )
+
+        self.in_proj = MergedColumnParallelLinear(
+            input_size=hidden_size,
+            output_sizes=[
+                intermediate_size,
+                intermediate_size,
+                self.groups_ssm_state_size,
+                self.groups_ssm_state_size,
+                self.num_heads,
+            ],
+            bias=use_bias,
+            prefix=f"{prefix}.in_proj",
+        )
+        if n_groups % self.tp_size != 0:
+            # This is the n_groups == 1 case,
+            # where we need to duplicate groups if TP>1.
+
+            # - because in_proj is a concatenation of 3 weights, we
+            #   need to interleave them before sharding
+            # - use the custom weight loader mamba_v2_sharded_weight_loader
+            #   for conv1d.bias, covn1d.weight and in_proj.weight
+            # - need to set these settings, to assign the groups
+            #   to the head shards
+            group_shard_settings = (
+                self.groups_ssm_state_size,  # expected model size
+                (self.n_groups - n_groups) * self.ssm_state_size,  # extra dims assigned
+                n_groups == 1,  # if there was only one group
+            )
+            intermediate_settings = (intermediate_size, 0, False)
+            head_settings = (self.num_heads, 0, False)
+
+            # - the weight already has a "weight_loader" attribute
+            #   which set_weight_attrs will raise if we do not
+            #   delete before trying to override it
+            # - ditto for the other two weights below
+            delattr(self.conv1d.bias, "weight_loader")
+            set_weight_attrs(
+                self.conv1d.bias,
+                {
+                    "weight_loader": mamba_v2_sharded_weight_loader(
+                        [
+                            intermediate_settings,
+                            group_shard_settings,
+                            group_shard_settings,
+                        ],
+                        self.tp_size,
+                        self.tp_rank,
+                    )
+                },
+            )
+
+            delattr(self.conv1d.weight, "weight_loader")
+            set_weight_attrs(
+                self.conv1d.weight,
+                {
+                    "weight_loader": mamba_v2_sharded_weight_loader(
+                        [
+                            intermediate_settings,
+                            group_shard_settings,
+                            group_shard_settings,
+                        ],
+                        self.tp_size,
+                        self.tp_rank,
+                    )
+                },
+            )
+
+            if quant_config is None:
+                # - quant layers do not have a weight loader
+                delattr(self.in_proj.weight, "weight_loader")
+                set_weight_attrs(
+                    self.in_proj.weight,
+                    {
+                        "weight_loader": mamba_v2_sharded_weight_loader(
+                            [
+                                intermediate_settings,  # for gate
+                                intermediate_settings,
+                                group_shard_settings,
+                                group_shard_settings,
+                                head_settings,  # for dt
+                            ],
+                            self.tp_size,
+                            self.tp_rank,
+                        )
+                    },
+                )
+
+        # unsqueeze to fit conv1d weights shape into the linear weights shape.
+        # Can't do this in `weight_loader` since it already exists in
+        # `ColumnParallelLinear` and `MergedColumnParallelLinear`,
+        # and `set_weight_attrs` doesn't allow to override it
+        self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
+
+        # - these are TPed by heads to reduce the size of the
+        #   temporal shape
+        self.A = nn.Parameter(
+            torch.empty(
+                divide(num_heads, self.tp_size),
+                dtype=torch.float32,
+            )
+        )
+        self.D = nn.Parameter(torch.ones(num_heads // self.tp_size))
+        self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size))
+        self.use_rms_norm = use_rms_norm
+
+        set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
+        a_weight_loader = composed_weight_loader(
+            sharded_weight_loader(0), lambda x: -torch.exp(x.float())
+        )
+        set_weight_attrs(self.A, {"weight_loader": a_weight_loader})
+        set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})
+
+        self.out_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=use_bias,
+            input_is_parallel=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
+            reduce_results=False,
+        )
+
+        self.norm = Mixer2RMSNormGated(
+            intermediate_size, n_groups, self.use_rms_norm, eps=rms_norm_eps
+        )
+
+        # The tuple is (conv_state, ssm_state)
+        self.kv_cache = (torch.tensor([]), torch.tensor([]))
+
+        self.model_config = model_config
+        self.prefix = prefix
+
+    def forward_native(
+        self,
+        hidden_states: torch.Tensor,
+        output: torch.Tensor,
+        mup_vector: Optional[torch.Tensor] = None,
+    ):
+        pass
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        output: torch.Tensor,
+        forward_batch: ForwardBatch,
+        mup_vector: Optional[torch.Tensor] = None,
+    ):
+        # attn_backend_list[-1] gives access to MambaAttnBackend
+        mamba_backend = forward_batch.attn_backend.attn_backend_list[-1]
+        attn_metadata = mamba_backend.forward_metadata
+        state_indices_tensor = attn_metadata.mamba_cache_indices
+        chunk_size = self.chunk_size
+
+        conv_state, ssm_state, *rest = mamba_backend.req_to_token_pool.get_mamba_params(
+            self.layer_id
+        )
+
+        assert (
+            ssm_state.size(1) == self.ssm_state_size
+        ), f"dstate must be {self.ssm_state_size}, got {ssm_state.size(1)}"
+
+        query_start_loc = attn_metadata.query_start_loc
+
+        chunk_size = self.chunk_size
+
+        # TODO: properly support this
+        prep_initial_states = False
+
+        # 1. Gated MLP's linear projection
+        projected_states, _ = self.in_proj(hidden_states)
+
+        if mup_vector is not None:
+            projected_states = projected_states * mup_vector
+
+        gate, hidden_states_B_C, dt = torch.split(
+            projected_states,
+            [
+                self.intermediate_size // self.tp_size,
+                self.conv_dim // self.tp_size,
+                self.num_heads // self.tp_size,
+            ],
+            dim=-1,
+        )
+        conv_weights = self.conv1d.weight.view(
+            self.conv1d.weight.size(0), self.conv1d.weight.size(2)
+        )
+
+        # - get hidden_states, B and C after depthwise convolution.
+        split_hidden_states_B_C_fn = lambda hidden_states_B_C: torch.split(
+            hidden_states_B_C,
+            [
+                self.intermediate_size // self.tp_size,
+                self.groups_ssm_state_size // self.tp_size,
+                self.groups_ssm_state_size // self.tp_size,
+            ],
+            dim=-1,
+        )
+
+        preallocated_ssm_out = torch.empty(
+            [
+                projected_states.shape[0],
+                (self.num_heads * self.head_dim) // self.tp_size,
+            ],
+            dtype=hidden_states.dtype,
+            device=hidden_states.device,
+        )
+
+        # Process prefill requests
+        if forward_batch.forward_mode.is_extend():
+            # 2. Convolution sequence transformation
+            # - "cache_indices" updates the conv_state cache in positions
+            #   pointed to by "state_indices_tensor"
+            num_prefill_tokens = forward_batch.extend_num_tokens or 0
+            has_initial_states = forward_batch.extend_prefix_lens > 0
+            cache_indices = attn_metadata.mamba_cache_indices
+
+            x = hidden_states_B_C.transpose(
+                0, 1
+            )  # this is the form that causal-conv see
+            hidden_states_B_C = causal_conv1d_fn(
+                x,
+                conv_weights,
+                self.conv1d.bias,
+                activation=self.activation,
+                conv_states=conv_state,
+                has_initial_state=has_initial_states,
+                cache_indices=cache_indices,
+                query_start_loc=query_start_loc,
+                seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
+            ).transpose(0, 1)
+
+            hidden_states, B, C = split_hidden_states_B_C_fn(hidden_states_B_C)
+
+            # 3. State Space Model sequence transformation
+            initial_states = None
+
+            if has_initial_states is not None and prep_initial_states:
+                initial_states = torch.where(
+                    has_initial_states[:, None, None, None],
+                    ssm_state[state_indices_tensor],
+                    0,
+                )
+
+            # NOTE: final output is an in-place update of out tensor
+            varlen_state = mamba_chunk_scan_combined(
+                hidden_states.view(
+                    1, num_prefill_tokens, self.num_heads // self.tp_size, self.head_dim
+                ),
+                dt.unsqueeze(0),
+                self.A,
+                B.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
+                C.view(1, num_prefill_tokens, self.n_groups // self.tp_size, -1),
+                chunk_size=chunk_size,
+                D=self.D,
+                z=None,
+                dt_bias=self.dt_bias,
+                cu_seqlens=query_start_loc,
+                initial_states=initial_states,
+                return_varlen_states=True,
+                return_final_states=False,
+                dt_softplus=True,
+                dt_limit=(0.0, float("inf")),
+                out=preallocated_ssm_out.view(1, num_prefill_tokens, -1, self.head_dim),
+                state_dtype=ssm_state.dtype,
+            )
+
+            # update ssm states
+            # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
+            ssm_state[state_indices_tensor] = varlen_state.permute(0, 3, 2, 1)
+        elif forward_batch.forward_mode.is_decode():
+            num_decodes = len(query_start_loc) - 1
+            # 2. Convolution sequence transformation
+            hidden_states_B_C = causal_conv1d_update(
+                hidden_states_B_C,
+                conv_state,
+                conv_weights,
+                self.conv1d.bias,
+                self.activation,
+                conv_state_indices=state_indices_tensor,
+            )
+
+            hidden_states, B, C = split_hidden_states_B_C_fn(hidden_states_B_C)
+
+            # 3. State Space Model sequence transformation
+            n_groups = self.n_groups // self.tp_size
+            A = (
+                self.A[:, None, ...][:, :, None]
+                .expand(-1, self.head_dim, self.ssm_state_size)
+                .to(dtype=torch.float32)
+            )
+            dt = dt[:, :, None].expand(-1, -1, self.head_dim)
+            dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
+            D = self.D[:, None, ...].expand(-1, self.head_dim)
+            B = B.view(-1, n_groups, B.shape[1] // n_groups)
+            C = C.view(-1, n_groups, C.shape[1] // n_groups)
+            hidden_states = hidden_states.view(
+                -1, self.num_heads // self.tp_size, self.head_dim
+            )
+
+            # - the hidden is reshaped into (bs, num_heads, head_dim)
+            # - mamba_cache_params.ssm_state's slots will be selected
+            #   using state_indices_tensor_d
+            # NOTE: final output is an in-place update of out tensor
+            selective_state_update(
+                ssm_state.permute(0, 3, 2, 1),
+                hidden_states,
+                dt,
+                A,
+                B,
+                C,
+                D,
+                z=None,
+                dt_bias=dt_bias,
+                dt_softplus=True,
+                state_batch_indices=state_indices_tensor,
+                out=preallocated_ssm_out.view(num_decodes, -1, self.head_dim),
+            )
+        elif forward_batch.forward_mode.is_idle():
+            preallocated_ssm_out = preallocated_ssm_out
+
+        # 4. gated MLP
+        # GatedRMSNorm internally applying SiLU to the gate
+        # SiLU is applied internally before normalization, unlike standard
+        # norm usage
+        hidden_states = self.norm(preallocated_ssm_out, gate)
+
+        # 5. Final linear projection
+        output[:], _ = self.out_proj(hidden_states)
+
+    @property
+    def mamba_type(self) -> str:
+        return "mamba2"
@@ -0,0 +1,81 @@
+# Adapted from: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/mamba/mamba_utils.py
+from sglang.srt.distributed.utils import divide
+
+
+class MambaStateShapeCalculator:
+
+    @classmethod
+    def linear_attention_state_shape(
+        cls,
+        num_heads: int,
+        tp_size: int,
+        head_dim: int,
+    ) -> tuple[tuple[int, int, int], ...]:
+
+        state_shape = (num_heads // tp_size, head_dim, head_dim)
+        return (state_shape,)
+
+    @classmethod
+    def mamba1_state_shape(
+        cls,
+        tp_world_size: int,
+        intermediate_size: int,
+        state_size: int,
+        conv_kernel: int,
+    ) -> tuple[tuple[int, int], tuple[int, int]]:
+        conv_state_shape = (divide(intermediate_size, tp_world_size), conv_kernel - 1)
+
+        temporal_state_shape = (divide(intermediate_size, tp_world_size), state_size)
+
+        conv_state_shape = conv_state_shape[1], conv_state_shape[0]
+
+        return conv_state_shape, temporal_state_shape
+
+    @classmethod
+    def mamba2_state_shape(
+        cls,
+        tp_world_size: int,
+        intermediate_size: int,
+        n_groups: int,
+        num_heads: int,
+        head_dim: int,
+        state_size: int,
+        conv_kernel: int,
+    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
+        # if n_groups is not divisible by world_size, need to extend the shards
+        # to ensure all groups needed by a head is sharded along with it
+        n_groups = n_groups + cls.extra_groups_for_head_shards(n_groups, tp_world_size)
+        # heads and n_groups are TP-ed
+        conv_dim = intermediate_size + 2 * n_groups * state_size
+
+        # contiguous along 'dim' axis
+        conv_state_shape = (conv_kernel - 1, divide(conv_dim, tp_world_size))
+
+        # These are not TP-ed as they depend on A, dt_bias, D
+        # - they are typically small
+        #   e.g., (h_heads, head_dim, state_size) = (128, 64, 128)
+        temporal_state_shape = (divide(num_heads, tp_world_size), head_dim, state_size)
+        return conv_state_shape, temporal_state_shape
+
+    @classmethod
+    def short_conv_state_shape(
+        cls,
+        tp_world_size: int,
+        intermediate_size: int,
+        conv_kernel: int,
+    ) -> tuple[tuple[int, int]]:
+        conv_dim = divide(intermediate_size, tp_world_size)
+        conv_state_shape = (conv_kernel - 1, conv_dim)
+        return (conv_state_shape,)
+
+    @classmethod
+    def extra_groups_for_head_shards(cls, ngroups: int, tp_size: int):
+        """Compute the increase in group numbers to account for
+        replication in order to accompany the head shards."""
+
+        # in the case ngoups % tp_size == 0, this will be zero
+        if ngroups % tp_size == 0:
+            return 0
+
+        # for n_groups == 1, this is exactly tp_size - n_groups
+        return tp_size - ngroups
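
As a hedged usage sketch of this new helper (it assumes sglang 0.5.3rc2 is installed; the numeric arguments are made-up example values, and only the formulas come from the class above):

    from sglang.srt.layers.attention.mamba.mamba_utils import MambaStateShapeCalculator

    conv_shape, ssm_shape = MambaStateShapeCalculator.mamba2_state_shape(
        tp_world_size=1,
        intermediate_size=128 * 64,  # num_heads * head_dim
        n_groups=8,
        num_heads=128,
        head_dim=64,
        state_size=128,
        conv_kernel=4,
    )
    print(conv_shape)  # (3, 10240): (conv_kernel - 1, intermediate + 2 * groups * state)
    print(ssm_shape)   # (128, 64, 128): (heads, head_dim, state_size)

The first shape holds the rolling convolution window of conv_kernel - 1 steps; the second holds the per-head temporal SSM state, matching the example given in the comment above.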
@@ -0,0 +1,2 @@
+from .mamba_ssm import selective_state_update
+from .ssd_combined import mamba_chunk_scan_combined