sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/models/gpt_oss.py
CHANGED
@@ -66,6 +66,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.utils import (
+    create_fused_set_kv_buffer_arg,
+    enable_fused_set_kv_buffer,
+)
 from sglang.srt.utils import (
     LazyValue,
     add_prefix,
@@ -193,33 +197,6 @@ class GptOssSparseMoeBlock(nn.Module):
         return ans


-def _enable_fused_set_kv_buffer(forward_batch: ForwardBatch):
-    """Enable fused set_kv_buffer only on CUDA with bfloat16 KV cache."""
-    return _is_cuda and forward_batch.token_to_kv_pool.dtype == torch.bfloat16
-
-
-# TODO maybe move to a model-common utils
-def _create_fused_set_kv_buffer_arg(
-    value: torch.Tensor,
-    layer: RadixAttention,
-    forward_batch: ForwardBatch,
-):
-    layer_id = layer.layer_id
-    token_to_kv_pool = forward_batch.token_to_kv_pool
-
-    k_buffer = token_to_kv_pool.get_key_buffer(layer_id)
-    v_buffer = token_to_kv_pool.get_value_buffer(layer_id)
-
-    return FusedSetKVBufferArg(
-        value=value,
-        k_buffer=k_buffer.view(k_buffer.shape[0], -1),
-        v_buffer=v_buffer.view(v_buffer.shape[0], -1),
-        k_scale=layer.k_scale,
-        v_scale=layer.v_scale,
-        cache_loc=forward_batch.out_cache_loc,
-    )
-
-
 class GptOssAttention(nn.Module):
     def __init__(
         self,
@@ -337,12 +314,12 @@ class GptOssAttention(nn.Module):
             q,
             k,
             fused_set_kv_buffer_arg=(
-                _create_fused_set_kv_buffer_arg(
+                create_fused_set_kv_buffer_arg(
                     value=v,
                     layer=self.attn,
                     forward_batch=forward_batch,
                 )
-                if _enable_fused_set_kv_buffer(forward_batch)
+                if enable_fused_set_kv_buffer(forward_batch)
                 else None
             ),
         )
@@ -356,7 +333,7 @@ class GptOssAttention(nn.Module):
         attn_output = self.attn(
             *inner_state,
             sinks=self.sinks,
-            save_kv_cache=not _enable_fused_set_kv_buffer(forward_batch),
+            save_kv_cache=not enable_fused_set_kv_buffer(forward_batch),
         )
         output, _ = self.o_proj(attn_output)
         return output
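The two deleted private helpers now live in the new shared module sglang/srt/models/utils.py (added in this release, +51 lines per the file list above), so gpt_oss.py and qwen3_moe.py can import the same logic. As a rough illustration of the gating rule they encode, here is a minimal self-contained sketch; the stand-in classes below are hypothetical and only mirror the shape of sglang's real ForwardBatch / KV-pool objects.

```python
# Minimal sketch of the enable/create split shown in the diff above.
# FakeKVPool / FakeForwardBatch are hypothetical stand-ins, not sglang types.
from dataclasses import dataclass

import torch


@dataclass
class FakeKVPool:
    dtype: torch.dtype


@dataclass
class FakeForwardBatch:
    token_to_kv_pool: FakeKVPool


def enable_fused_set_kv_buffer_sketch(forward_batch: FakeForwardBatch) -> bool:
    # Mirrors the removed helper: fuse the KV-cache write into the rotary kernel
    # only on CUDA and only when the KV cache is stored in bfloat16.
    return torch.cuda.is_available() and forward_batch.token_to_kv_pool.dtype == torch.bfloat16


batch = FakeForwardBatch(FakeKVPool(dtype=torch.bfloat16))
print(enable_fused_set_kv_buffer_sketch(batch))
```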
sglang/srt/models/kimi_vl_moonvit.py
CHANGED
@@ -49,7 +49,7 @@ from typing import List, Optional, Sequence, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from transformers.activations import ACT2FN,
+from transformers.activations import ACT2FN, GELUTanh
 from transformers.modeling_utils import PreTrainedModel

 try:
@@ -614,7 +614,7 @@ class MoonVitPretrainedModel(PreTrainedModel):
                 "num_heads": config.num_attention_heads,
                 "hidden_dim": config.hidden_size,
                 "mlp_dim": config.intermediate_size,
-                "activation":
+                "activation": GELUTanh(),
                 "attn_bias": True,
                 "attn_implementation": config._attn_implementation,
             },
sglang/srt/models/llama.py
CHANGED
@@ -385,6 +385,10 @@ class LlamaModel(nn.Module):
                     "Self attention has no KV cache scaling " "factor attribute!"
                 )

+    def get_input_embeddings(self) -> nn.Embedding:
+        """Get input embeddings from the model."""
+        return self.embed_tokens
+

 class LlamaForCausalLM(nn.Module):
     # BitandBytes specific attributes
sglang/srt/models/mllama4.py
CHANGED
@@ -291,7 +291,7 @@ class Llama4UnfoldConvolution(nn.Module):

     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.unfold(hidden_states)
-        hidden_states = hidden_states.permute(0, 2, 1)
+        hidden_states = hidden_states.permute(0, 2, 1).contiguous()
         hidden_states, _ = self.linear(hidden_states)
         return hidden_states

@@ -446,9 +446,20 @@ class Llama4ForConditionalGeneration(nn.Module):
         )

         if self.has_vision:
+            # TODO: make this more general
+            ignore_quant_layers = getattr(config, "quantization_config", {}).get(
+                "ignore", {}
+            )
+            if (
+                "model.layers.vision_model*" in ignore_quant_layers
+                and "model.layers.multi_modal_projector*" in ignore_quant_layers
+            ):
+                vision_quant_config = None
+            else:
+                vision_quant_config = quant_config
             self.vision_model = Llama4VisionModel(
                 config.vision_config,
-                quant_config=
+                quant_config=vision_quant_config,
                 prefix=add_prefix("vision_model", prefix),
             )

@@ -560,7 +571,7 @@ class Llama4ForConditionalGeneration(nn.Module):
             forward_batch=forward_batch,
             language_model=self.language_model,
             data_embedding_funcs={
-                Modality.IMAGE:
+                Modality.IMAGE: image_embedding_func,
             },
             positions=positions,
         )
@@ -689,7 +700,7 @@ class Llama4ForConditionalGeneration(nn.Module):
         """Handle scale parameter remapping. Returns True if handled."""
         if "scale" in name and "expert" not in name:
             remapped_name = maybe_remap_kv_scale_name(name, params_dict)
-            return remapped_name is None
+            return remapped_name is not None and remapped_name != name
         return False

     def _handle_stacked_params(
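The new vision-tower branch above simply checks whether the checkpoint's quantization_config lists the vision model and multi-modal projector in its ignore patterns, and skips quantizing them if so. A hedged, stand-alone rendering of that check (the config object here is a fabricated example, not a real Llama 4 checkpoint config):

```python
# Stand-alone rendering of the ignore-pattern check from the diff above.
# `cfg` is a made-up example config; real checkpoints carry this structure
# inside their quantization_config.
from types import SimpleNamespace

cfg = SimpleNamespace(
    quantization_config={
        "ignore": [
            "model.layers.vision_model*",
            "model.layers.multi_modal_projector*",
        ]
    }
)
quant_config = "fp8-quant-config"  # placeholder for a real QuantizationConfig

ignore_quant_layers = getattr(cfg, "quantization_config", {}).get("ignore", {})
if (
    "model.layers.vision_model*" in ignore_quant_layers
    and "model.layers.multi_modal_projector*" in ignore_quant_layers
):
    vision_quant_config = None  # leave the vision tower unquantized
else:
    vision_quant_config = quant_config

print("quantize vision tower:", vision_quant_config is not None)
```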
sglang/srt/models/qwen2.py
CHANGED
@@ -454,9 +454,6 @@ class Qwen2ForCausalLM(nn.Module):
         # For EAGLE3 support
         self.capture_aux_hidden_states = False

-        # For EAGLE3 support
-        self.capture_aux_hidden_states = False
-
     def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embedding(input_ids)

@@ -484,10 +481,6 @@ class Qwen2ForCausalLM(nn.Module):
         if self.capture_aux_hidden_states:
             hidden_states, aux_hidden_states = hidden_states

-        aux_hidden_states = None
-        if self.capture_aux_hidden_states:
-            hidden_states, aux_hidden_states = hidden_states
-
         if self.pp_group.is_last_rank:
             if not get_embedding:
                 return self.logits_processor(
sglang/srt/models/qwen2_5_vl.py
CHANGED
@@ -40,7 +40,6 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
     Qwen2_5_VisionRotaryEmbedding,
 )

-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -61,6 +60,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
 from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor

 logger = logging.getLogger(__name__)

@@ -265,7 +265,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         self.fullatt_block_indexes = vision_config.fullatt_block_indexes
         self.window_size = vision_config.window_size
         self.patch_size = vision_config.patch_size
-        mlp_hidden_size: int = vision_config.intermediate_size
+        mlp_hidden_size: int = ((vision_config.intermediate_size + 7) // 8) * 8
         self.patch_embed = Qwen2_5_VisionPatchEmbed(
             patch_size=patch_size,
             temporal_patch_size=temporal_patch_size,
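The mlp_hidden_size change above is a round-up to the next multiple of 8, presumably so the vision MLP width divides evenly when sharded across tensor-parallel ranks. The arithmetic in isolation:

```python
# ((x + 7) // 8) * 8 rounds an integer x up to the nearest multiple of 8.
def round_up_to_multiple_of_8(x: int) -> int:
    return ((x + 7) // 8) * 8


assert round_up_to_multiple_of_8(1234) == 1240  # not aligned: bumped up
assert round_up_to_multiple_of_8(1240) == 1240  # already aligned: unchanged
```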
sglang/srt/models/qwen2_audio.py
CHANGED
@@ -39,7 +39,6 @@ from transformers.models.qwen2_audio.modeling_qwen2_audio import (
     Qwen2AudioMultiModalProjector,
 )

-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.activation import QuickGELU
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
@@ -61,6 +60,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
 from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor

 logger = logging.getLogger(__name__)

sglang/srt/models/qwen2_moe.py
CHANGED
@@ -25,12 +25,14 @@ from torch import nn
 from transformers import PretrainedConfig

 from sglang.srt.distributed import (
+    get_moe_expert_parallel_world_size,
     get_pp_group,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
+from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
@@ -50,6 +52,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe import get_moe_a2a_backend
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
@@ -82,6 +85,8 @@ class Qwen2MoeMLP(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
         prefix: str = "",
+        tp_rank: Optional[int] = None,
+        tp_size: Optional[int] = None,
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -90,6 +95,8 @@ class Qwen2MoeMLP(nn.Module):
             bias=False,
             quant_config=quant_config,
             prefix=add_prefix("gate_up_proj", prefix),
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
         self.down_proj = RowParallelLinear(
             intermediate_size,
@@ -98,6 +105,8 @@ class Qwen2MoeMLP(nn.Module):
             quant_config=quant_config,
             reduce_results=reduce_results,
             prefix=add_prefix("down_proj", prefix),
+            tp_rank=tp_rank,
+            tp_size=tp_size,
         )
         if hidden_act != "silu":
             raise ValueError(
@@ -146,7 +155,8 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         self.experts = get_moe_impl_class(quant_config)(
             layer_id=self.layer_id,
             top_k=config.num_experts_per_tok,
-            num_experts=config.num_experts
+            num_experts=config.num_experts
+            + global_server_args_dict["ep_num_redundant_experts"],
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             quant_config=quant_config,
@@ -168,11 +178,31 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
                 quant_config=quant_config,
                 reduce_results=False,
                 prefix=add_prefix("shared_expert", prefix),
+                **(
+                    dict(tp_rank=0, tp_size=1)
+                    if get_moe_a2a_backend().is_deepep()
+                    else {}
+                ),
             )
         else:
             self.shared_expert = None
         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)

+        if get_moe_a2a_backend().is_deepep():
+            # TODO: we will support tp < ep in the future
+            self.ep_size = get_moe_expert_parallel_world_size()
+            self.num_experts = (
+                config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
+            )
+            self.top_k = config.num_experts_per_tok
+
+    def get_moe_weights(self):
+        return [
+            x.data
+            for name, x in self.experts.named_parameters()
+            if name not in ["correction_bias"]
+        ]
+
     def _forward_shared_experts(self, hidden_states: torch.Tensor):
         shared_output = None
         if self.shared_expert is not None:
@@ -183,6 +213,36 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             )
         return shared_output

+    def _forward_deepep(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch):
+        shared_output = None
+        if hidden_states.shape[0] > 0:
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+            shared_output = self._forward_shared_experts(hidden_states)
+            topk_weights, topk_idx, _ = self.topk(
+                hidden_states,
+                router_logits,
+                num_token_non_padded=forward_batch.num_token_non_padded,
+                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                    layer_id=self.layer_id,
+                ),
+            )
+        else:
+            topk_weights, topk_idx, _ = self.topk.empty_topk_output(
+                hidden_states.device
+            )
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            forward_batch=forward_batch,
+        )
+
+        if shared_output is not None:
+            final_hidden_states.add_(shared_output)
+
+        return final_hidden_states
+
     def _forward_router_experts(self, hidden_states: torch.Tensor):
         # router_logits: (num_tokens, n_experts)
         router_logits, _ = self.gate(hidden_states)
@@ -213,6 +273,9 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)

+        if get_moe_a2a_backend().is_deepep():
+            return self._forward_deepep(hidden_states, forward_batch)
+
         DUAL_STREAM_TOKEN_THRESHOLD = 1024
         if (
             self.alt_stream is not None
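The new _forward_deepep path above keeps the standard MoE control flow: compute router logits, pick the top-k experts, run the expert computation, then add the shared-expert output back in. A condensed, framework-free sketch of that flow in plain PyTorch (dense per-token expert selection stands in for DeepEP's expert-parallel all-to-all dispatch, which is not reproduced here):

```python
# Condensed sketch of the MoE routing flow used by _forward_deepep above.
# Plain dense PyTorch ops only; the real code dispatches tokens across
# expert-parallel ranks via DeepEP.
import torch
import torch.nn.functional as F


def moe_forward_sketch(hidden_states, gate_w, expert_ws, shared_expert=None, top_k=2):
    # hidden_states: (num_tokens, hidden), gate_w: (hidden, n_experts)
    # expert_ws: one (hidden, hidden) weight matrix per expert
    if hidden_states.shape[0] == 0:  # empty batch: nothing to route
        return hidden_states
    router_logits = hidden_states @ gate_w  # (num_tokens, n_experts)
    topk_weights, topk_idx = torch.topk(
        F.softmax(router_logits, dim=-1), k=top_k, dim=-1
    )
    out = torch.zeros_like(hidden_states)
    for slot in range(top_k):
        for e, w in enumerate(expert_ws):
            mask = topk_idx[:, slot] == e
            if mask.any():
                out[mask] += topk_weights[mask, slot, None] * (hidden_states[mask] @ w)
    if shared_expert is not None:  # mirrors final_hidden_states.add_(shared_output)
        out += shared_expert(hidden_states)
    return out


tokens, hidden, n_experts = 4, 8, 3
x = torch.randn(tokens, hidden)
gate = torch.randn(hidden, n_experts)
experts = [torch.randn(hidden, hidden) for _ in range(n_experts)]
print(moe_forward_sketch(x, gate, experts).shape)  # torch.Size([4, 8])
```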
sglang/srt/models/qwen2_vl.py
CHANGED
@@ -33,7 +33,6 @@ from einops import rearrange
 from transformers import Qwen2VLConfig
 from transformers.models.qwen2_vl.configuration_qwen2_vl import Qwen2VLVisionConfig

-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.activation import QuickGELU
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
@@ -50,6 +49,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
 from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor

 logger = logging.getLogger(__name__)

sglang/srt/models/qwen3.py
CHANGED
@@ -1,6 +1,5 @@
 # Adapted from qwen2.py
 import logging
-from functools import partial
 from typing import Any, Dict, Iterable, List, Optional, Tuple

 import torch
@@ -30,12 +29,19 @@ from sglang.srt.model_loader.weight_utils import (
 )
 from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
 from sglang.srt.models.qwen2 import Qwen2Model
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    add_prefix,
+    get_cmo_stream,
+    is_cuda,
+    is_npu,
+    wait_cmo_stream,
+)

 Qwen3Config = None

 logger = logging.getLogger(__name__)
 _is_cuda = is_cuda()
+_is_npu = is_npu()


 class Qwen3Attention(nn.Module):
@@ -235,9 +241,18 @@ class Qwen3DecoderLayer(nn.Module):

         # Fully Connected
         hidden_states, residual = self.layer_communicator.prepare_mlp(
-            hidden_states,
+            hidden_states,
+            residual,
+            forward_batch,
+            cache=(
+                [self.mlp.gate_up_proj.weight, self.mlp.down_proj.weight]
+                if _is_npu
+                else None
+            ),
         )
         hidden_states = self.mlp(hidden_states)
+        if _is_npu and get_cmo_stream():
+            wait_cmo_stream()
         hidden_states, residual = self.layer_communicator.postprocess_layer(
             hidden_states, residual, forward_batch
         )
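For context, get_cmo_stream() / wait_cmo_stream() above gate an Ascend-NPU side stream: prepare_mlp is handed the MLP weights (presumably so they can be staged on that stream while the main stream is still busy), and the layer waits on the side stream right after the MLP call, before post-processing. A generic CUDA-stream analogue of the same overlap-then-wait pattern, purely as an illustration and not the sglang CMO implementation:

```python
# Generic overlap-then-wait pattern with CUDA streams, as an analogy for the
# NPU CMO stream usage above (not the actual sglang implementation).
import torch


def overlapped_step(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    if not torch.cuda.is_available():
        return (x * 2.0) @ weight  # CPU fallback: nothing to overlap

    side_stream = torch.cuda.Stream()
    with torch.cuda.stream(side_stream):
        # Side-stream work, e.g. staging a weight onto the device.
        weight_dev = weight.to("cuda", non_blocking=True)
    y = x.cuda() * 2.0                                    # main-stream work runs concurrently
    torch.cuda.current_stream().wait_stream(side_stream)  # analogous to wait_cmo_stream()
    return y @ weight_dev


print(overlapped_step(torch.randn(4, 8), torch.randn(8, 8)).shape)  # torch.Size([4, 8])
```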
sglang/srt/models/qwen3_moe.py
CHANGED
@@ -51,7 +51,7 @@ from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.rotary_embedding import MRotaryEmbedding, get_rope
 from sglang.srt.layers.utils import get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -60,6 +60,10 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
+from sglang.srt.models.utils import (
+    create_fused_set_kv_buffer_arg,
+    enable_fused_set_kv_buffer,
+)
 from sglang.srt.utils import (
     add_prefix,
     is_cuda,
@@ -354,6 +358,10 @@ class Qwen3MoeAttention(nn.Module):
             rope_scaling=rope_scaling,
             dual_chunk_attention_config=dual_chunk_attention_config,
         )
+        self.compatible_with_fused_kv_buffer = (
+            False if isinstance(self.rotary_emb, MRotaryEmbedding) else True
+        )
+
         self.attn = RadixAttention(
             self.num_heads,
             self.head_dim,
@@ -412,7 +420,21 @@ class Qwen3MoeAttention(nn.Module):
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self._apply_qk_norm(q, k)
-        q, k = self.rotary_emb(
+        q, k = self.rotary_emb(
+            positions,
+            q,
+            k,
+            fused_set_kv_buffer_arg=(
+                create_fused_set_kv_buffer_arg(
+                    value=v,
+                    layer=self.attn,
+                    forward_batch=forward_batch,
+                )
+                if enable_fused_set_kv_buffer(forward_batch)
+                and self.compatible_with_fused_kv_buffer
+                else None
+            ),
+        )
         inner_state = q, k, v, forward_batch
         return None, forward_batch, inner_state

@@ -420,7 +442,13 @@ class Qwen3MoeAttention(nn.Module):
         hidden_states, forward_batch, inner_state = intermediate_state
         if inner_state is None:
             return hidden_states
-        attn_output = self.attn(
+        attn_output = self.attn(
+            *inner_state,
+            save_kv_cache=not (
+                enable_fused_set_kv_buffer(forward_batch)
+                and self.compatible_with_fused_kv_buffer
+            ),
+        )
         output, _ = self.o_proj(attn_output)
         return output

sglang/srt/models/qwen3_next.py
CHANGED
@@ -13,6 +13,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.layers.attention.fla.layernorm_gated import RMSNorm as RMSNormGated
 from sglang.srt.layers.attention.mamba.mamba import mamba_v2_sharded_weight_loader
@@ -46,7 +47,14 @@ from sglang.srt.model_loader.weight_utils import (
     sharded_weight_loader,
 )
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP, Qwen2MoeSparseMoeBlock
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    LazyValue,
+    add_prefix,
+    is_cuda,
+    is_npu,
+    make_layers,
+    set_weight_attrs,
+)

 logger = logging.getLogger(__name__)
 _is_cuda = is_cuda()
@@ -239,6 +247,7 @@ class Qwen3GatedDeltaNet(nn.Module):
         self,
         config: Qwen3NextConfig,
         layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
         alt_stream: Optional[torch.cuda.Stream] = None,
     ) -> None:
         super().__init__()
@@ -278,6 +287,7 @@ class Qwen3GatedDeltaNet(nn.Module):
             input_size=self.hidden_size,
             output_size=projection_size_qkvz,
             bias=False,
+            quant_config=quant_config,
             tp_rank=self.attn_tp_rank,
             tp_size=self.attn_tp_size,
         )
@@ -285,6 +295,7 @@ class Qwen3GatedDeltaNet(nn.Module):
             input_size=self.hidden_size,
             output_size=projection_size_ba,
             bias=False,
+            quant_config=None,
             tp_rank=self.attn_tp_rank,
             tp_size=self.attn_tp_size,
         )
@@ -336,6 +347,7 @@ class Qwen3GatedDeltaNet(nn.Module):
             self.value_dim,
             self.hidden_size,
             bias=False,
+            quant_config=quant_config,
             input_is_parallel=True,
             reduce_results=False,
             tp_rank=self.attn_tp_rank,
@@ -493,7 +505,9 @@ class Qwen3HybridLinearDecoderLayer(nn.Module):
     ) -> None:
         super().__init__()
         self.config = config
-        self.linear_attn = Qwen3GatedDeltaNet(
+        self.linear_attn = Qwen3GatedDeltaNet(
+            config, layer_id, quant_config, alt_stream
+        )

         # Qwen3Next all layers are sparse and have no nextn now
         self.is_layer_sparse = True
@@ -843,13 +857,14 @@ class Qwen3NextModel(nn.Module):
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
-            hidden_states, residual = layer(
-                layer_id=i,
-                positions=positions,
-                hidden_states=hidden_states,
-                residual=residual,
-                forward_batch=forward_batch,
-            )
+            with get_global_expert_distribution_recorder().with_current_layer(i):
+                hidden_states, residual = layer(
+                    layer_id=i,
+                    positions=positions,
+                    hidden_states=hidden_states,
+                    residual=residual,
+                    forward_batch=forward_batch,
+                )

         if not forward_batch.forward_mode.is_idle():
             if residual is None:
@@ -895,6 +910,18 @@ class Qwen3NextForCausalLM(nn.Module):
         self.lm_head = self.lm_head.float()
         self.logits_processor = LogitsProcessor(config)

+        self._routed_experts_weights_of_layer = LazyValue(
+            lambda: {
+                layer_id: layer.mlp.get_moe_weights()
+                for layer_id, layer in enumerate(self.model.layers)
+                if isinstance(layer.mlp, Qwen2MoeSparseMoeBlock)
+            }
+        )
+
+    @property
+    def routed_experts_weights_of_layer(self):
+        return self._routed_experts_weights_of_layer.value
+
     @torch.no_grad()
     def forward(
         self,