sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0

sglang/srt/models/gpt_oss.py
@@ -66,6 +66,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.utils import (
+    create_fused_set_kv_buffer_arg,
+    enable_fused_set_kv_buffer,
+)
 from sglang.srt.utils import (
     LazyValue,
     add_prefix,
@@ -193,33 +197,6 @@ class GptOssSparseMoeBlock(nn.Module):
         return ans
 
 
-def _enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
-    """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
-    return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
-
-
-# TODO maybe move to a model-common utils
-def _create_fused_set_kv_buffer_arg(
-    value: torch.Tensor,
-    layer: RadixAttention,
-    forward_batch: ForwardBatch,
-):
-    layer_id = layer.layer_id
-    token_to_kv_pool = forward_batch.token_to_kv_pool
-
-    k_buffer = token_to_kv_pool.get_key_buffer(layer_id)
-    v_buffer = token_to_kv_pool.get_value_buffer(layer_id)
-
-    return FusedSetKVBufferArg(
-        value=value,
-        k_buffer=k_buffer.view(k_buffer.shape[0], -1),
-        v_buffer=v_buffer.view(v_buffer.shape[0], -1),
-        k_scale=layer.k_scale,
-        v_scale=layer.v_scale,
-        cache_loc=forward_batch.out_cache_loc,
-    )
-
-
 class GptOssAttention(nn.Module):
     def __init__(
         self,
@@ -337,12 +314,12 @@ class GptOssAttention(nn.Module):
             q,
             k,
             fused_set_kv_buffer_arg=(
-                _create_fused_set_kv_buffer_arg(
+                create_fused_set_kv_buffer_arg(
                     value=v,
                     layer=self.attn,
                     forward_batch=forward_batch,
                 )
-                if _enable_fused_set_kv_buffer(forward_batch)
+                if enable_fused_set_kv_buffer(forward_batch)
                 else None
             ),
         )
@@ -356,7 +333,7 @@ class GptOssAttention(nn.Module):
         attn_output = self.attn(
             *inner_state,
             sinks=self.sinks,
-            save_kv_cache=not _enable_fused_set_kv_buffer(forward_batch),
+            save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
         )
         output, _ = self.o_proj(attn_output)
         return output

sglang/srt/models/kimi_vl_moonvit.py
@@ -49,7 +49,7 @@ from typing import List, Optional, Sequence, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from transformers.activations import ACT2FN, PytorchGELUTanh
+from transformers.activations import ACT2FN, GELUTanh
 from transformers.modeling_utils import PreTrainedModel
 
 try:
@@ -614,7 +614,7 @@ class MoonVitPretrainedModel(PreTrainedModel):
                 "num_heads": config.num_attention_heads,
                 "hidden_dim": config.hidden_size,
                 "mlp_dim": config.intermediate_size,
-                "activation": PytorchGELUTanh(),
+                "activation": GELUTanh(),
                 "attn_bias": True,
                 "attn_implementation": config._attn_implementation,
             },

sglang/srt/models/llama.py
@@ -385,6 +385,10 @@ class LlamaModel(nn.Module):
                     "Self attention has no KV cache scaling " "factor attribute!"
                 )
 
+    def get_input_embeddings(self) -> nn.Embedding:
+        """Get input embeddings from the model."""
+        return self.embed_tokens
+
 
 class LlamaForCausalLM(nn.Module):
     # BitandBytes specific attributes

sglang/srt/models/longcat_flash.py
@@ -131,7 +131,7 @@ elif _is_hip:
         awq_dequantize_triton as awq_dequantize,
     )
 else:
-    from vllm._custom_ops import awq_dequantize
+    pass
 logger = logging.getLogger(__name__)
 
 

sglang/srt/models/longcat_flash_nextn.py
@@ -111,7 +111,7 @@ elif _is_hip:
         awq_dequantize_triton as awq_dequantize,
     )
 else:
-    from vllm._custom_ops import awq_dequantize
+    pass
 
 
 logger = logging.getLogger(__name__)

sglang/srt/models/mllama4.py
@@ -291,7 +291,7 @@ class Llama4UnfoldConvolution(nn.Module):
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.unfold(hidden_states)
-        hidden_states = hidden_states.permute(0, 2, 1)
+        hidden_states = hidden_states.permute(0, 2, 1).contiguous()
         hidden_states, _ = self.linear(hidden_states)
         return hidden_states
 
@@ -446,9 +446,20 @@ class Llama4ForConditionalGeneration(nn.Module):
         )
 
         if self.has_vision:
+            # TODO: make this more general
+            ignore_quant_layers = getattr(config, "quantization_config", {}).get(
+                "ignore", {}
+            )
+            if (
+                "model.layers.vision_model*" in ignore_quant_layers
+                and "model.layers.multi_modal_projector*" in ignore_quant_layers
+            ):
+                vision_quant_config = None
+            else:
+                vision_quant_config = quant_config
             self.vision_model = Llama4VisionModel(
                 config.vision_config,
-                quant_config=quant_config,
+                quant_config=vision_quant_config,
                 prefix=add_prefix("vision_model", prefix),
             )
 
@@ -560,7 +571,7 @@ class Llama4ForConditionalGeneration(nn.Module):
             forward_batch=forward_batch,
             language_model=self.language_model,
             data_embedding_funcs={
-                Modality.IMAGE: self.get_image_feature,
+                Modality.IMAGE: image_embedding_func,
             },
             positions=positions,
         )
@@ -689,7 +700,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         """Handle scale parameter remapping. Returns True if handled."""
         if "scale" in name and "expert" not in name:
             remapped_name = maybe_remap_kv_scale_name(name, params_dict)
-            return remapped_name is None
+            return remapped_name is not None and remapped_name != name
         return False
 
     def _handle_stacked_params(

sglang/srt/models/qwen2.py
@@ -454,9 +454,6 @@ class Qwen2ForCausalLM(nn.Module):
         # For EAGLE3 support
         self.capture_aux_hidden_states = False
 
-        # For EAGLE3 support
-        self.capture_aux_hidden_states = False
-
     def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embedding(input_ids)
 
@@ -484,10 +481,6 @@ class Qwen2ForCausalLM(nn.Module):
         if self.capture_aux_hidden_states:
             hidden_states, aux_hidden_states = hidden_states
 
-        aux_hidden_states = None
-        if self.capture_aux_hidden_states:
-            hidden_states, aux_hidden_states = hidden_states
-
         if self.pp_group.is_last_rank:
             if not get_embedding:
                 return self.logits_processor(

sglang/srt/models/qwen2_5_vl.py
@@ -40,7 +40,6 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
     Qwen2_5_VisionRotaryEmbedding,
 )
 
-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -61,6 +60,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
 from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor
 
 logger = logging.getLogger(__name__)
 
@@ -265,7 +265,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         self.fullatt_block_indexes = vision_config.fullatt_block_indexes
         self.window_size = vision_config.window_size
         self.patch_size = vision_config.patch_size
-        mlp_hidden_size: int = vision_config.intermediate_size
+        mlp_hidden_size: int = ((vision_config.intermediate_size + 7) // 8) * 8
         self.patch_embed = Qwen2_5_VisionPatchEmbed(
             patch_size=patch_size,
             temporal_patch_size=temporal_patch_size,

sglang/srt/models/qwen2_audio.py
@@ -39,7 +39,6 @@ from transformers.models.qwen2_audio.modeling_qwen2_audio import (
     Qwen2AudioMultiModalProjector,
 )
 
-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.activation import QuickGELU
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
@@ -61,6 +60,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
 from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor
 
 logger = logging.getLogger(__name__)
 

sglang/srt/models/qwen2_moe.py
@@ -25,12 +25,14 @@ from torch import nn
 from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
+    get_moe_expert_parallel_world_size,
     get_pp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
@@ -50,6 +52,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
@@ -82,6 +85,8 @@ class Qwen2MoeMLP(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -90,6 +95,8 @@ class Qwen2MoeMLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=add_prefix("gate_up_proj", prefix),
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
@@ -98,6 +105,8 @@ class Qwen2MoeMLP(nn.Module):
             quant_config=quant_config,
             reduce_results=reduce_results,
             prefix=add_prefix("down_proj", prefix),
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -146,7 +155,8 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         self.experts = get_moe_impl_class(quant_config)(
             layer_id=self.layer_id,
             top_k=config.num_experts_per_tok,
-            num_experts=config.num_experts,
+            num_experts=config.num_experts
+            + global_server_args_dict["ep_num_redundant_experts"],
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             quant_config=quant_config,
@@ -168,11 +178,31 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                 quant_config=quant_config,
                 reduce_results=False,
                 prefix=add_prefix("shared_expert", prefix),
+                **(
+                    dict(tp_rank=0, tp_size=1)
+                    if get_moe_a2a_backend().is_deepep()
+                    else {}
+                ),
             )
         else:
             self.shared_expert = None
         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
 
+        if get_moe_a2a_backend().is_deepep():
+            # TODO: we will support tp < ep in the future
+            self.ep_size = get_moe_expert_parallel_world_size()
+            self.num_experts = (
+                config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
+            )
+            self.top_k = config.num_experts_per_tok
+
+    def get_moe_weights(self):
+        return [
+            x.data
+            for name, x in self.experts.named_parameters()
+            if name not in ["correction_bias"]
+        ]
+
     def _forward_shared_experts(self, hidden_states: torch.Tensor):
         shared_output = None
         if self.shared_expert is not None:
@@ -183,6 +213,36 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                )
         return shared_output
 
+    def _forward_deepep(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch):
+        shared_output = None
+        if hidden_states.shape[0] > 0:
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+            shared_output = self._forward_shared_experts(hidden_states)
+            topk_weights, topk_idx, _ = self.topk(
+                hidden_states,
+                router_logits,
+                num_token_non_padded=forward_batch.num_token_non_padded,
+                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                    layer_id=self.layer_id,
+                ),
+            )
+        else:
+            topk_weights, topk_idx, _ = self.topk.empty_topk_output(
+                hidden_states.device
+            )
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            forward_batch=forward_batch,
+        )
+
+        if shared_output is not None:
+            final_hidden_states.add_(shared_output)
+
+        return final_hidden_states
+
     def _forward_router_experts(self, hidden_states: torch.Tensor):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
@@ -213,6 +273,9 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
+        if get_moe_a2a_backend().is_deepep():
+            return self._forward_deepep(hidden_states, forward_batch)
+
         DUAL_STREAM_TOKEN_THRESHOLD = 1024
         if (
             self.alt_stream is not None

sglang/srt/models/qwen2_vl.py
@@ -33,7 +33,6 @@ from einops import rearrange
 from transformers import Qwen2VLConfig
 from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig
 
-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.activation import QuickGELU
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
@@ -50,6 +49,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
 from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor
 
 logger = logging.getLogger(__name__)
 

sglang/srt/models/qwen3.py
@@ -1,6 +1,5 @@
 # Adapted from qwen2.py
 import logging
-from functools import partial
 from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 import torch
@@ -30,12 +29,19 @@ from sglang.srt.model_loader.weight_utils import (
 )
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
-from sglang.srt.utils import add_prefix, is_cuda
+from sglang.srt.utils import (
+    add_prefix,
+    get_cmo_stream,
+    is_cuda,
+    is_npu,
+    wait_cmo_stream,
+)
 
 Qwen3Config = None
 
 logger = logging.getLogger(__name__)
 _is_cuda = is_cuda()
+_is_npu = is_npu()
 
 
 class Qwen3Attention(nn.Module):
@@ -235,9 +241,18 @@ class Qwen3DecoderLayer(nn.Module):
 
         # Fully Connected
         hidden_states, residual = self.layer_communicator.prepare_mlp(
-            hidden_states, residual, forward_batch
+            hidden_states,
+            residual,
+            forward_batch,
+            cache=(
+                [self.mlp.gate_up_proj.weight, self.mlp.down_proj.weight]
+                if _is_npu
+                else None
+            ),
         )
         hidden_states = self.mlp(hidden_states)
+        if _is_npu and get_cmo_stream():
+            wait_cmo_stream()
         hidden_states, residual = self.layer_communicator.postprocess_layer(
             hidden_states, residual, forward_batch
         )

sglang/srt/models/qwen3_moe.py
@@ -51,7 +51,7 @@ from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.rotary_embedding import MRotaryEmbedding, get_rope
 from sglang.srt.layers.utils import get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -60,6 +60,10 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTe
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
+from sglang.srt.models.utils import (
+    create_fused_set_kv_buffer_arg,
+    enable_fused_set_kv_buffer,
+)
 from sglang.srt.utils import (
     add_prefix,
     is_cuda,
@@ -354,6 +358,10 @@ class Qwen3MoeAttention(nn.Module):
             rope_scaling=rope_scaling,
             dual_chunk_attention_config=dual_chunk_attention_config,
         )
+        self.compatible_with_fused_kv_buffer = (
+            False if isinstance(self.rotary_emb, MRotaryEmbedding) else True
+        )
+
         self.attn = RadixAttention(
             self.num_heads,
             self.head_dim,
@@ -412,7 +420,21 @@ class Qwen3MoeAttention(nn.Module):
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self._apply_qk_norm(q, k)
-        q, k = self.rotary_emb(positions, q, k)
+        q, k = self.rotary_emb(
+            positions,
+            q,
+            k,
+            fused_set_kv_buffer_arg=(
+                create_fused_set_kv_buffer_arg(
+                    value=v,
+                    layer=self.attn,
+                    forward_batch=forward_batch,
+                )
+                if enable_fused_set_kv_buffer(forward_batch)
+                and self.compatible_with_fused_kv_buffer
+                else None
+            ),
+        )
         inner_state = q, k, v, forward_batch
         return None, forward_batch, inner_state
 
@@ -420,7 +442,13 @@ class Qwen3MoeAttention(nn.Module):
         hidden_states, forward_batch, inner_state = intermediate_state
         if inner_state is None:
             return hidden_states
-        attn_output = self.attn(*inner_state)
+        attn_output = self.attn(
+            *inner_state,
+            save_kv_cache=not (
+                enable_fused_set_kv_buffer(forward_batch)
+                and self.compatible_with_fused_kv_buffer
+            ),
+        )
         output, _ = self.o_proj(attn_output)
         return output
 

sglang/srt/models/qwen3_next.py
@@ -13,6 +13,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.layers.attention.fla.layernorm_gated import RMSNorm as RMSNormGated
 from sglang.srt.layers.attention.mamba.mamba import mamba_v2_sharded_weight_loader
@@ -46,7 +47,14 @@ from sglang.srt.model_loader.weight_utils import (
     sharded_weight_loader,
 )
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock
-from sglang.srt.utils import add_prefix, is_cuda, is_npu, make_layers, set_weight_attrs
+from sglang.srt.utils import (
+    LazyValue,
+    add_prefix,
+    is_cuda,
+    is_npu,
+    make_layers,
+    set_weight_attrs,
+)
 
 logger = logging.getLogger(__name__)
 _is_cuda = is_cuda()
@@ -239,6 +247,7 @@ class Qwen3GatedDeltaNet(nn.Module):
         self,
         config: Qwen3NextConfig,
         layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
         alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
@@ -278,6 +287,7 @@ class Qwen3GatedDeltaNet(nn.Module):
             input_size=self.hidden_size,
             output_size=projection_size_qkvz,
             bias=False,
+            quant_config=quant_config,
             tp_rank=self.attn_tp_rank,
             tp_size=self.attn_tp_size,
         )
@@ -285,6 +295,7 @@ class Qwen3GatedDeltaNet(nn.Module):
             input_size=self.hidden_size,
             output_size=projection_size_ba,
             bias=False,
+            quant_config=None,
             tp_rank=self.attn_tp_rank,
             tp_size=self.attn_tp_size,
         )
@@ -336,6 +347,7 @@ class Qwen3GatedDeltaNet(nn.Module):
             self.value_dim,
             self.hidden_size,
             bias=False,
+            quant_config=quant_config,
            input_is_parallel=True,
             reduce_results=False,
             tp_rank=self.attn_tp_rank,
@@ -493,7 +505,9 @@ class Qwen3HybridLinearDecoderLayer(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
-        self.linear_attn = Qwen3GatedDeltaNet(config, layer_id, alt_stream)
+        self.linear_attn = Qwen3GatedDeltaNet(
+            config, layer_id, quant_config, alt_stream
+        )
 
         # Qwen3Next all layers are sparse and have no nextn now
         self.is_layer_sparse = True
@@ -843,13 +857,14 @@ class Qwen3NextModel(nn.Module):
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
-            hidden_states, residual = layer(
-                layer_id=i,
-                positions=positions,
-                hidden_states=hidden_states,
-                residual=residual,
-                forward_batch=forward_batch,
-            )
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                hidden_states, residual = layer(
+                    layer_id=i,
+                    positions=positions,
+                    hidden_states=hidden_states,
+                    residual=residual,
+                    forward_batch=forward_batch,
+                )
 
         if not forward_batch.forward_mode.is_idle():
             if residual is None:
@@ -895,6 +910,18 @@ class Qwen3NextForCausalLM(nn.Module):
             self.lm_head = self.lm_head.float()
         self.logits_processor = LogitsProcessor(config)
 
+        self._routed_experts_weights_of_layer = LazyValue(
+            lambda: {
+                layer_id: layer.mlp.get_moe_weights()
+                for layer_id, layer in enumerate(self.model.layers)
+                if isinstance(layer.mlp, Qwen2MoeSparseMoeBlock)
+            }
+        )
+
+    @property
+    def routed_experts_weights_of_layer(self):
+        return self._routed_experts_weights_of_layer.value
+
     @torch.no_grad()
     def forward(
         self,