sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/configs/model_config.py CHANGED
@@ -132,6 +132,13 @@ class ModelConfig:
         if is_draft_model and self.hf_config.architectures[0] == "Glm4MoeForCausalLM":
             self.hf_config.architectures[0] = "Glm4MoeForCausalLMNextN"
 
+        if (
+            is_draft_model
+            and self.hf_config.architectures[0] == "LongcatFlashForCausalLM"
+        ):
+            self.hf_config.architectures[0] = "LongcatFlashForCausalLMNextN"
+            self.hf_config.num_hidden_layers = self.hf_config.num_nextn_predict_layers
+
         if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM":
             self.hf_config.architectures[0] = "MiMoMTP"
         if (
@@ -140,6 +147,9 @@ class ModelConfig:
         ):
             self.hf_config.architectures[0] = "Ernie4_5_MoeForCausalLMMTP"
 
+        if is_draft_model and self.hf_config.architectures[0] == "Qwen3NextForCausalLM":
+            self.hf_config.architectures[0] = "Qwen3NextForCausalLMMTP"
+
         # Check model type
         self.is_generation = is_generation_model(
             self.hf_config.architectures, is_embedding
@@ -199,6 +209,8 @@ class ModelConfig:
             "DeepseekV2ForCausalLM" in self.hf_config.architectures
             or "DeepseekV3ForCausalLM" in self.hf_config.architectures
             or "DeepseekV3ForCausalLMNextN" in self.hf_config.architectures
+            or "LongcatFlashForCausalLM" in self.hf_config.architectures
+            or "LongcatFlashForCausalLMNextN" in self.hf_config.architectures
         ):
             self.head_dim = 256
             self.attention_arch = AttentionArch.MLA
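The draft-model remapping above follows one pattern: rewrite `architectures[0]` so the loader picks the NextN/MTP variant of the model class. A minimal sketch of that pattern (not the actual ModelConfig code; the condensed mapping table and the SimpleNamespace stand-in for `hf_config` are illustrative only):

```python
# Sketch of the draft-model architecture remapping, under the assumption
# that each base architecture maps to exactly one NextN/MTP variant.
from types import SimpleNamespace

# Hypothetical condensed form of the per-architecture overrides in the diff.
_DRAFT_ARCH_MAP = {
    "Glm4MoeForCausalLM": "Glm4MoeForCausalLMNextN",
    "LongcatFlashForCausalLM": "LongcatFlashForCausalLMNextN",
    "MiMoForCausalLM": "MiMoMTP",
    "Qwen3NextForCausalLM": "Qwen3NextForCausalLMMTP",
}

def remap_draft_architecture(hf_config, is_draft_model: bool) -> None:
    """Rewrite architectures[0] so a draft model loads its NextN/MTP variant."""
    arch = hf_config.architectures[0]
    if is_draft_model and arch in _DRAFT_ARCH_MAP:
        hf_config.architectures[0] = _DRAFT_ARCH_MAP[arch]
        if arch == "LongcatFlashForCausalLM":
            # Per the hunk, a LongcatFlash draft keeps only the NextN layers.
            hf_config.num_hidden_layers = hf_config.num_nextn_predict_layers

cfg = SimpleNamespace(
    architectures=["LongcatFlashForCausalLM"],
    num_hidden_layers=28,          # illustrative values
    num_nextn_predict_layers=1,
)
remap_draft_architecture(cfg, is_draft_model=True)
print(cfg.architectures[0], cfg.num_hidden_layers)
# LongcatFlashForCausalLMNextN 1
```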
@@ -270,6 +282,9 @@ class ModelConfig:
             self.num_key_value_heads = self.num_attention_heads
         self.hidden_size = self.hf_text_config.hidden_size
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
+        self.num_attention_layers = self.num_hidden_layers
+        if "LongcatFlashForCausalLM" in self.hf_config.architectures:
+            self.num_attention_layers = self.num_hidden_layers * 2
         self.num_nextn_predict_layers = getattr(
             self.hf_text_config, "num_nextn_predict_layers", None
         )
@@ -290,11 +305,16 @@ class ModelConfig:
         ) or getattr(self.hf_config, "image_token_index", None)
 
     @staticmethod
-    def from_server_args(server_args: ServerArgs, model_path: str = None, **kwargs):
+    def from_server_args(
+        server_args: ServerArgs,
+        model_path: str = None,
+        model_revision: str = None,
+        **kwargs,
+    ):
         return ModelConfig(
             model_path=model_path or server_args.model_path,
             trust_remote_code=server_args.trust_remote_code,
-            revision=server_args.revision,
+            revision=model_revision or server_args.revision,
             context_length=server_args.context_length,
             model_override_args=server_args.json_model_override_args,
             is_embedding=server_args.is_embedding,
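A hedged usage sketch of the new `model_revision` parameter, which lets a second config (for example, a draft model in a speculative-decoding setup) pin a different Hugging Face revision than the serving model. The ServerArgs construction is simplified and the revision string is made up:

```python
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.server_args import ServerArgs

server_args = ServerArgs(model_path="Qwen/Qwen3-Next-80B-A3B-Instruct")

# Same behavior as before: falls back to server_args.revision.
base_cfg = ModelConfig.from_server_args(server_args)

# Overrides both path and revision without mutating server_args.
draft_cfg = ModelConfig.from_server_args(
    server_args,
    model_path="Qwen/Qwen3-Next-80B-A3B-Instruct",
    model_revision="refs/pr/1",  # hypothetical revision
)
```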
@@ -393,17 +413,27 @@ class ModelConfig:
         # compressed-tensors uses a "compression_config" key
         quant_cfg = getattr(self.hf_config, "compression_config", None)
         if quant_cfg is None:
-            # check if is modelopt model -- modelopt doesn't have corresponding field
+            # check if it is a modelopt or mixed-precision model -- neither has a corresponding field
             # in hf `config.json` but has a standalone `hf_quant_config.json` in the root directory
             # example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
+            # example: https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/tree/main
             is_local = os.path.exists(self.model_path)
             modelopt_quant_config = {"quant_method": "modelopt"}
             if not is_local:
-                from huggingface_hub import HfApi
+                import huggingface_hub
+
+                try:
+                    from huggingface_hub import HfApi
+
+                    hf_api = HfApi()
+                    if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
+                        quant_cfg = modelopt_quant_config
+                except huggingface_hub.errors.OfflineModeIsEnabled:
+                    logger.warning(
+                        "Offline mode is enabled, skipping hf_quant_config.json check"
+                    )
+                    pass
 
-                hf_api = HfApi()
-                if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
-                    quant_cfg = modelopt_quant_config
         elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
             quant_config_file = os.path.join(
                 self.model_path, "hf_quant_config.json"
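The same offline-safe probe, extracted as a standalone sketch. `HfApi.file_exists` and `huggingface_hub.errors.OfflineModeIsEnabled` are real huggingface_hub APIs used by the hunk above; the helper name is hypothetical and the repo id is just the example from the comment:

```python
import huggingface_hub
from huggingface_hub import HfApi

def detect_hf_quant_config(repo_id: str):
    """Return a modelopt-style quant config if the repo ships hf_quant_config.json."""
    try:
        if HfApi().file_exists(repo_id, "hf_quant_config.json"):
            return {"quant_method": "modelopt"}
    except huggingface_hub.errors.OfflineModeIsEnabled:
        # With HF_HUB_OFFLINE=1 any hub lookup raises; treat the result as unknown
        # instead of crashing, mirroring the logger.warning in the diff.
        pass
    return None

print(detect_hf_quant_config("nvidia/Llama-3.1-8B-Instruct-FP8"))
```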
sglang/srt/configs/qwen3_next.py ADDED
@@ -0,0 +1,326 @@
+# coding=utf-8
+# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Qwen3Hybrid model configuration"""
+
+import enum
+import os
+
+import numpy as np
+import torch
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_rope_utils import rope_config_validation
+from transformers.utils import logging
+
+from sglang.srt.distributed.utils import divide
+from sglang.srt.layers.dp_attention import get_attention_tp_size
+
+logger = logging.get_logger(__name__)
+
+
+# NOTE: HybridLayerType
+class HybridLayerType(enum.Enum):
+    full_attention = "attention"
+    swa_attention = "swa_attention"
+    linear_attention = "linear_attention"
+    mamba2 = "mamba"
+
+
+class Qwen3NextConfig(PretrainedConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`Qwen3NextModel`]. It is used to instantiate a
+    Qwen3-Next model according to the specified arguments, defining the model architecture.
+    Instantiating a configuration with the defaults will yield a similar configuration to that of
+    Qwen3-Next-80B-A3B-Instruct [Qwen/Qwen3-Next-80B-A3B-Instruct](https://huggingface.co/Qwen/Qwen3-Next-80B-A3B-Instruct).
+
+    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PretrainedConfig`] for more information.
+
+
+    Args:
+        vocab_size (`int`, *optional*, defaults to 151936):
+            Vocabulary size of the model. Defines the number of different tokens that can be represented by the
+            `inputs_ids`.
+        hidden_size (`int`, *optional*, defaults to 2048):
+            Dimension of the hidden representations.
+        intermediate_size (`int`, *optional*, defaults to 5632):
+            Dimension of the MLP representations.
+        num_hidden_layers (`int`, *optional*, defaults to 48):
+            Number of hidden layers in the Transformer encoder.
+        num_attention_heads (`int`, *optional*, defaults to 16):
+            Number of attention heads for each attention layer in the Transformer encoder.
+        num_key_value_heads (`int`, *optional*, defaults to 2):
+            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
+            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
+            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
+            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
+            by meanpooling all the original heads within that group. For more details checkout [this
+            paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to `32`.
+        hidden_act (`str`, *optional*, defaults to `"silu"`):
+            The non-linear activation function in the decoder.
+        max_position_embeddings (`int`, *optional*, defaults to 32768):
+            The maximum sequence length that this model might ever be used with.
+        initializer_range (`float`, *optional*, defaults to 0.02):
+            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
+        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
+            The epsilon used by the rms normalization layers.
+        use_cache (`bool`, *optional*, defaults to `True`):
+            Whether or not the model should return the last key/values attentions (not used by all models). Only
+            relevant if `config.is_decoder=True`.
+        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
+            Whether the model's input and output word embeddings should be tied.
+        rope_theta (`float`, *optional*, defaults to 10000.0):
+            The base period of the RoPE embeddings.
+        rope_scaling (`Dict`, *optional*):
+            Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
+            and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
+            accordingly.
+            Expected contents:
+                `rope_type` (`str`):
+                    The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
+                    'llama3'], with 'default' being the original RoPE implementation.
+                `factor` (`float`, *optional*):
+                    Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
+                    most scaling types, a `factor` of x will enable the model to handle sequences of length x *
+                    original maximum pre-trained length.
+                `original_max_position_embeddings` (`int`, *optional*):
+                    Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
+                    pretraining.
+                `attention_factor` (`float`, *optional*):
+                    Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
+                    computation. If unspecified, it defaults to value recommended by the implementation, using the
+                    `factor` field to infer the suggested value.
+                `beta_fast` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 32.
+                `beta_slow` (`float`, *optional*):
+                    Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
+                    ramp function. If unspecified, it defaults to 1.
+                `short_factor` (`List[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to short contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `long_factor` (`List[float]`, *optional*):
+                    Only used with 'longrope'. The scaling factor to be applied to long contexts (<
+                    `original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
+                    size divided by the number of attention heads divided by 2
+                `low_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
+                `high_freq_factor` (`float`, *optional*):
+                    Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
+        partial_rotary_factor (`float`, *optional*, defaults to 0.25):
+            Percentage of the query and keys which will have rotary embedding.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Whether to use a bias in the query, key, value and output projection layers during self-attention.
+        attention_dropout (`float`, *optional*, defaults to 0.0):
+            The dropout ratio for the attention probabilities.
+        head_dim (`int`, *optional*, defaults to 256):
+            Projection weights dimension in multi-head attention.
+        linear_conv_kernel_dim (`int`, *optional*, defaults to 4):
+            Kernel size of the convolution used in linear attention layers.
+        linear_key_head_dim (`int`, *optional*, defaults to 128):
+            Dimension of each key head in linear attention.
+        linear_value_head_dim (`int`, *optional*, defaults to 128):
+            Dimension of each value head in linear attention.
+        linear_num_key_heads (`int`, *optional*, defaults to 16):
+            Number of key heads used in linear attention layers.
+        linear_num_value_heads (`int`, *optional*, defaults to 32):
+            Number of value heads used in linear attention layers.
+        decoder_sparse_step (`int`, *optional*, defaults to 1):
+            The frequency of the MoE layer.
+        moe_intermediate_size (`int`, *optional*, defaults to 512):
+            Intermediate size of the routed expert.
+        shared_expert_intermediate_size (`int`, *optional*, defaults to 512):
+            Intermediate size of the shared expert.
+        num_experts_per_tok (`int`, *optional*, defaults to 10):
+            Number of selected experts.
+        num_experts (`int`, *optional*, defaults to 512):
+            Number of routed experts.
+        norm_topk_prob (`bool`, *optional*, defaults to `True`):
+            Whether to normalize the topk probabilities.
+        output_router_logits (`bool`, *optional*, defaults to `False`):
+            Whether or not the router logits should be returned by the model. Enabling this will also
+            allow the model to output the auxiliary loss, including load balancing loss and router z-loss.
+        router_aux_loss_coef (`float`, *optional*, defaults to 0.001):
+            The aux loss factor for the total loss.
+        mlp_only_layers (`list[int]`, *optional*, defaults to `[]`):
+            Indicate which layers use Qwen3NextMLP rather than Qwen3NextSparseMoeBlock
+            The list contains layer index, from 0 to num_layers-1 if we have num_layers layers
+            If `mlp_only_layers` is empty, `decoder_sparse_step` is used to determine the sparsity.
+        layer_types (`list[str]`, *optional*, defaults to None):
+            Types of each layer (attention or linear).
+
+    ```python
+    >>> from transformers import Qwen3NextModel, Qwen3NextConfig
+
+    >>> # Initializing a Qwen3Next style configuration
+    >>> configuration = Qwen3NextConfig()
+
+    >>> # Initializing a model from the Qwen3-Next-80B-A3B style configuration
+    >>> model = Qwen3NextModel(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```
+    """
+
+    model_type = "qwen3_next"
+    keys_to_ignore_at_inference = ["past_key_values"]
+
+    def __init__(
+        self,
+        vocab_size=151936,
+        hidden_size=2048,
+        intermediate_size=5632,
+        num_hidden_layers=48,
+        num_attention_heads=16,
+        num_key_value_heads=2,
+        hidden_act="silu",
+        max_position_embeddings=32768,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        tie_word_embeddings=False,
+        rope_theta=10000.0,
+        rope_scaling=None,
+        partial_rotary_factor=0.25,
+        attention_bias=False,
+        attention_dropout=0.0,
+        head_dim=256,
+        linear_conv_kernel_dim=4,
+        linear_key_head_dim=128,
+        linear_value_head_dim=128,
+        linear_num_key_heads=16,
+        linear_num_value_heads=32,
+        decoder_sparse_step=1,
+        moe_intermediate_size=512,
+        shared_expert_intermediate_size=512,
+        num_experts_per_tok=10,
+        num_experts=512,
+        norm_topk_prob=True,
+        output_router_logits=False,
+        router_aux_loss_coef=0.001,
+        mlp_only_layers=[],
+        layer_types=None,
+        **kwargs,
+    ):
+        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.use_cache = use_cache
+        self.rope_theta = rope_theta
+        self.rope_scaling = rope_scaling
+        self.partial_rotary_factor = partial_rotary_factor
+        self.attention_bias = attention_bias
+        self.attention_dropout = attention_dropout
+        self.head_dim = head_dim
+        rope_config_validation(self)
+
+        # linear attention (gdn now part)
+        self.linear_conv_kernel_dim = linear_conv_kernel_dim
+        self.linear_key_head_dim = linear_key_head_dim
+        self.linear_value_head_dim = linear_value_head_dim
+        self.linear_num_key_heads = linear_num_key_heads
+        self.linear_num_value_heads = linear_num_value_heads
+
+        # MoE arguments
+        self.decoder_sparse_step = decoder_sparse_step
+        self.moe_intermediate_size = moe_intermediate_size
+        self.shared_expert_intermediate_size = shared_expert_intermediate_size
+        self.num_experts_per_tok = num_experts_per_tok
+        self.num_experts = num_experts
+        self.norm_topk_prob = norm_topk_prob
+        self.output_router_logits = output_router_logits
+        self.router_aux_loss_coef = router_aux_loss_coef
+        self.mlp_only_layers = mlp_only_layers
+
+    @property
+    def layers_block_type(self):
+        layer_type_list = []
+
+        for l in range(self.num_hidden_layers):
+            if (l + 1) % self.full_attention_interval == 0:
+                layer_type_list.append(HybridLayerType.full_attention.value)
+            else:
+                layer_type_list.append(HybridLayerType.linear_attention.value)
+
+        return layer_type_list
+
+    @property
+    def linear_layer_ids(self):
+        return [
+            i
+            for i, type_value in enumerate(self.layers_block_type)
+            if type_value == HybridLayerType.linear_attention.value
+        ]
+
+    @property
+    def full_attention_layer_ids(self):
+        return [
+            i
+            for i, type_value in enumerate(self.layers_block_type)
+            if type_value == HybridLayerType.full_attention.value
+        ]
+
+    @property
+    def hybrid_gdn_params(self):
+        world_size = get_attention_tp_size()
+        conv_dim = (
+            self.linear_key_head_dim * self.linear_num_key_heads * 2
+            + self.linear_value_head_dim * self.linear_num_value_heads
+        )
+        conv_state_shape = (
+            divide(conv_dim, world_size),
+            self.linear_conv_kernel_dim - 1,
+        )
+
+        temporal_state_shape = (
+            divide(self.linear_num_value_heads, world_size),
+            self.linear_key_head_dim,
+            self.linear_value_head_dim,
+        )
+        conv_dtype = torch.bfloat16
+        dtype_map = {
+            "float32": torch.float32,
+            "bfloat16": torch.bfloat16,
+        }
+        ssm_dtype = dtype_map[os.environ["SGLANG_MAMBA_SSM_DTYPE"]]
+        mamba_layers = self.linear_layer_ids
+        return (
+            conv_state_shape,
+            temporal_state_shape,
+            conv_dtype,
+            ssm_dtype,
+            mamba_layers,
+        )
+
+    @property
+    def mamba_cache_per_req(self):
+        conv_state_shape, temporal_state_shape, conv_dtype, ssm_dtype, mamba_layers = (
+            self.hybrid_gdn_params
+        )
+        mamba_layers_len = len(mamba_layers)
+
+        return (
+            int(np.prod(conv_state_shape)) * conv_dtype.itemsize
+            + int(np.prod(temporal_state_shape)) * ssm_dtype.itemsize
+        ) * mamba_layers_len
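To make the cache math concrete, here is a back-of-envelope sketch of what `hybrid_gdn_params` and `mamba_cache_per_req` work out to with the constructor defaults above, assuming attention-TP world size 1, bfloat16 for both states, and `full_attention_interval=4` (the interval is read from the checkpoint config and is not a default shown above, so 4 is only an illustrative value):

```python
import numpy as np

# Constructor defaults from Qwen3NextConfig.
linear_key_head_dim = 128
linear_value_head_dim = 128
linear_num_key_heads = 16
linear_num_value_heads = 32
linear_conv_kernel_dim = 4
num_hidden_layers = 48
full_attention_interval = 4  # assumption, not a shown default
world_size = 1

conv_dim = (
    linear_key_head_dim * linear_num_key_heads * 2
    + linear_value_head_dim * linear_num_value_heads
)  # 8192
conv_state_shape = (conv_dim // world_size, linear_conv_kernel_dim - 1)  # (8192, 3)
temporal_state_shape = (
    linear_num_value_heads // world_size,
    linear_key_head_dim,
    linear_value_head_dim,
)  # (32, 128, 128)

# Every full_attention_interval-th layer is full attention; the rest are linear.
num_linear_layers = sum(
    (l + 1) % full_attention_interval != 0 for l in range(num_hidden_layers)
)  # 36

itemsize = 2  # bf16; the SSM dtype actually comes from SGLANG_MAMBA_SSM_DTYPE
per_layer = (
    int(np.prod(conv_state_shape)) * itemsize
    + int(np.prod(temporal_state_shape)) * itemsize
)  # 49152 + 1048576 bytes
print(per_layer * num_linear_layers / 2**20, "MiB per request")  # ~37.7 MiB
```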
sglang/srt/connector/__init__.py CHANGED
@@ -20,7 +20,7 @@ class ConnectorType(str, enum.Enum):
     KV = "KV"
 
 
-def create_remote_connector(url, device="cpu") -> BaseConnector:
+def create_remote_connector(url, **kwargs) -> BaseConnector:
     connector_type = parse_connector_type(url)
     if connector_type == "redis":
         return RedisConnector(url)
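A short usage sketch under the new signature: extra options now travel through `**kwargs` rather than a dedicated `device` argument, and the URL scheme selects the backend. The URL below is illustrative:

```python
from sglang.srt.connector import create_remote_connector

# "redis" scheme yields a RedisConnector; device placement is no longer a
# connector concern (see the base_connector and safe_serde hunks below).
connector = create_remote_connector("redis://localhost:6379/my-model")
```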
sglang/srt/connector/base_connector.py CHANGED
@@ -20,9 +20,8 @@ class BaseConnector(ABC):
         <connector_type://<host>:<port>/<model_name>/files/<filename>
     """
 
-    def __init__(self, url: str, device: torch.device = "cpu"):
+    def __init__(self, url: str):
         self.url = url
-        self.device = device
         self.closed = False
         self.local_dir = tempfile.mkdtemp()
         for sig in (signal.SIGINT, signal.SIGTERM):
sglang/srt/connector/redis.py CHANGED
@@ -15,10 +15,10 @@ logger = logging.getLogger(__name__)
 
 
 class RedisConnector(BaseKVConnector):
 
-    def __init__(self, url: str, device: torch.device = "cpu"):
+    def __init__(self, url: str):
         import redis
 
-        super().__init__(url, device)
+        super().__init__(url)
         parsed_url = urlparse(url)
         self.connection = redis.Redis(host=parsed_url.hostname, port=parsed_url.port)
         self.model_name = parsed_url.path.lstrip("/")
sglang/srt/connector/serde/__init__.py CHANGED
@@ -15,7 +15,7 @@ def create_serde(serde_type: str) -> Tuple[Serializer, Deserializer]:
 
     if serde_type == "safe":
         s = SafeSerializer()
-        d = SafeDeserializer(torch.uint8)
+        d = SafeDeserializer()
     else:
         raise ValueError(f"Unknown serde type: {serde_type}")
 
sglang/srt/connector/serde/safe_serde.py CHANGED
@@ -19,11 +19,12 @@ class SafeSerializer(Serializer):
 
 class SafeDeserializer(Deserializer):
 
-    def __init__(self, dtype):
-        super().__init__(dtype)
+    def __init__(self):
+        # TODO: dtype options
+        super().__init__(torch.float32)
 
     def from_bytes_normal(self, b: Union[bytearray, bytes]) -> torch.Tensor:
-        return load(bytes(b))["tensor_bytes"].to(dtype=self.dtype)
+        return load(bytes(b))["tensor_bytes"]
 
     def from_bytes(self, b: Union[bytearray, bytes]) -> torch.Tensor:
         return self.from_bytes_normal(b)
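Dropping the `.to(dtype=self.dtype)` cast is safe because safetensors records each tensor's dtype in its header, so `load()` already returns the original dtype. A minimal round-trip sketch (the `"tensor_bytes"` key mirrors the serializer's payload format):

```python
import torch
from safetensors.torch import load, save

t = torch.randn(4, 4, dtype=torch.bfloat16)
blob = save({"tensor_bytes": t})           # bytes; dtype recorded in the header
restored = load(blob)["tensor_bytes"]
assert restored.dtype == torch.bfloat16    # no explicit cast needed
```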
sglang/srt/custom_op.py CHANGED
@@ -1,12 +1,20 @@
 from torch import nn
 
-from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    is_cpu,
+    is_cuda,
+    is_hip,
+    is_npu,
+    is_xpu,
+)
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
 _is_cpu = is_cpu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_npu = is_npu()
+_is_xpu = is_xpu()
 
 
 class CustomOp(nn.Module):
@@ -88,5 +96,7 @@ class CustomOp(nn.Module):
             return self.forward_cpu
         elif _is_npu:
             return self.forward_npu
+        elif _is_xpu:
+            return self.forward_xpu
         else:
             return self.forward_native
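A hedged sketch of how the new branch is exercised: CustomOp selects one `forward_*` implementation per platform, and on Intel XPU builds dispatch now lands on `forward_xpu`. The toy op below is illustrative, not a real sglang kernel, and assumes CustomOp routes calls through the dispatched method, as the hunk suggests:

```python
import torch
from sglang.srt.custom_op import CustomOp

class ToySiluAndMul(CustomOp):
    """Gated activation: split the last dim in two, silu(gate) * up."""

    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = x.chunk(2, dim=-1)
        return torch.nn.functional.silu(gate) * up

    # Provide the same math on every backend the dispatcher may pick;
    # a real op would call platform-specific kernels here.
    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_native(x)

    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward_native(x)

op = ToySiluAndMul()
print(op(torch.randn(2, 8)).shape)  # torch.Size([2, 4])
```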